libzmq/src/curve_client.cpp
Luca Boccassi da31917f4f Relicense from LGPL3 + exceptions to Mozilla Public License version 2.0
Relicense permission collected from all relevant authors as tallied at:
https://github.com/rlenferink/libzmq-relicense/blob/master/checklist.md
The relicense grants are collected under RELICENSE/ and will be moved
to the above repository in a later commit.

Fixes https://github.com/zeromq/libzmq/issues/2376
2023-06-05 20:31:47 +01:00

261 lines
7.8 KiB
C++

/* SPDX-License-Identifier: MPL-2.0 */
#include "precompiled.hpp"
#include "macros.hpp"
#ifdef ZMQ_HAVE_CURVE
#include "msg.hpp"
#include "session_base.hpp"
#include "err.hpp"
#include "curve_client.hpp"
#include "wire.hpp"
#include "curve_client_tools.hpp"
#include "secure_allocator.hpp"
//  Creates the client side of a CURVE handshake bound to session_.
//  The mechanism starts in the send_hello state, so the first call to
//  next_handshake_command () produces a HELLO command.
//  _tools is constructed from the client's public and secret keys and the
//  server's public key, all read from options_.
//  The two 16-byte strings and downgrade_sub_ are forwarded unchanged to
//  curve_mechanism_base_t (presumably the nonce prefixes for client- and
//  server-originated MESSAGE commands - confirm in curve_mechanism_base).
zmq::curve_client_t::curve_client_t (session_base_t *session_,
                                     const options_t &options_,
                                     const bool downgrade_sub_) :
    mechanism_base_t (session_, options_),
    curve_mechanism_base_t (session_,
                            options_,
                            "CurveZMQMESSAGEC",
                            "CurveZMQMESSAGES",
                            downgrade_sub_),
    _state (send_hello),
    _tools (options_.curve_public_key,
            options_.curve_secret_key,
            options_.curve_server_key)
{
}
//  Nothing is released explicitly here: members (including _tools) and the
//  base classes are cleaned up by their own destructors.
zmq::curve_client_t::~curve_client_t ()
{
}
//  Produces the next outgoing handshake command into msg_, if one is due.
//
//  - send_hello:    build HELLO, then wait for WELCOME (expect_welcome).
//  - send_initiate: build INITIATE, then wait for READY (expect_ready).
//
//  The state only advances when the command was built successfully; a
//  failing producer's return value is passed straight through.
//  In every other state there is nothing to send: errno is set to EAGAIN
//  and -1 is returned.
int zmq::curve_client_t::next_handshake_command (msg_t *msg_)
{
    if (_state == send_hello) {
        const int hello_rc = produce_hello (msg_);
        if (hello_rc == 0)
            _state = expect_welcome;
        return hello_rc;
    }

    if (_state == send_initiate) {
        const int initiate_rc = produce_initiate (msg_);
        if (initiate_rc == 0)
            _state = expect_ready;
        return initiate_rc;
    }

    //  Waiting on the peer (or already finished): nothing to emit yet.
    errno = EAGAIN;
    return -1;
}
//  Routes an incoming handshake command to its handler.
//
//  WELCOME, READY and ERROR (as recognised by curve_client_tools_t) go to
//  process_welcome, process_ready and process_error respectively. Anything
//  else is reported to the socket as an unexpected command and fails with
//  errno = EPROTO.
//
//  If the handler fails, its return value is passed through and msg_ is
//  left untouched. On success msg_ is closed and re-initialised as an empty
//  message, and 0 is returned.
int zmq::curve_client_t::process_handshake_command (msg_t *msg_)
{
    const unsigned char *const data =
      static_cast<unsigned char *> (msg_->data ());
    const size_t size = msg_->size ();

    int rc = -1;
    if (curve_client_tools_t::is_handshake_command_welcome (data, size)) {
        rc = process_welcome (data, size);
    } else if (curve_client_tools_t::is_handshake_command_ready (data, size)) {
        rc = process_ready (data, size);
    } else if (curve_client_tools_t::is_handshake_command_error (data, size)) {
        rc = process_error (data, size);
    } else {
        session->get_socket ()->event_handshake_failed_protocol (
          session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND);
        errno = EPROTO;
        return -1;
    }

    if (rc != 0)
        return rc;

    //  The command has been consumed; give the caller back an empty message.
    rc = msg_->close ();
    errno_assert (rc == 0);
    rc = msg_->init ();
    errno_assert (rc == 0);
    return rc;
}
//  Encrypts an outgoing application message via curve_mechanism_base_t.
//  Calling this before the handshake has reached `connected` is a
//  programming error and trips the assertion.
int zmq::curve_client_t::encode (msg_t *msg_)
{
    zmq_assert (connected == _state);
    return curve_mechanism_base_t::encode (msg_);
}
//  Decrypts an incoming application message via curve_mechanism_base_t.
//  Calling this before the handshake has reached `connected` is a
//  programming error and trips the assertion.
int zmq::curve_client_t::decode (msg_t *msg_)
{
    zmq_assert (connected == _state);
    return curve_mechanism_base_t::decode (msg_);
}
//  Maps the internal handshake state onto the generic mechanism status:
//  `connected` -> ready, `error_received` -> error, anything else is
//  still handshaking.
zmq::mechanism_t::status_t zmq::curve_client_t::status () const
{
    return _state == connected        ? mechanism_t::ready
           : _state == error_received ? mechanism_t::error
                                      : mechanism_t::handshaking;
}
//  Builds the HELLO command into msg_.
//
//  msg_ is sized to 200 bytes and filled by _tools using the next outgoing
//  nonce (get_and_inc_nonce () consumes it). On failure a cryptographic
//  handshake error is reported to the socket and -1 is returned; errno is
//  not modified here.
int zmq::curve_client_t::produce_hello (msg_t *msg_)
{
    const size_t hello_size = 200;
    const int init_rc = msg_->init_size (hello_size);
    errno_assert (init_rc == 0);

    if (_tools.produce_hello (msg_->data (), get_and_inc_nonce ()) != -1)
        return 0;

    session->get_socket ()->event_handshake_failed_protocol (
      session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
    //  TODO this is somewhat inconsistent: we call init_size, but we may
    //  not close msg_; i.e. we assume that msg_ is initialized but empty
    //  (if it were non-empty, calling init_size might cause a leak!)
    //  msg_->close ();
    return -1;
}
int zmq::curve_client_t::process_welcome (const uint8_t *msg_data_,
size_t msg_size_)
{
const int rc = _tools.process_welcome (msg_data_, msg_size_,
get_writable_precom_buffer ());
if (rc == -1) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
errno = EPROTO;
return -1;
}
_state = send_initiate;
return 0;
}
//  Builds the INITIATE command into msg_.
//
//  The client's basic properties are serialised into a buffer that uses
//  secure_allocator_t. That buffer is passed as plaintext to
//  _tools.produce_initiate, together with the next outgoing nonce.
//  msg_ is sized as 113 + 128 + crypto_box_BOXZEROBYTES + property length.
//
//  On failure a cryptographic handshake error is reported to the socket and
//  -1 is returned.
int zmq::curve_client_t::produce_initiate (msg_t *msg_)
{
    const size_t props_len = basic_properties_len ();
    std::vector<unsigned char, secure_allocator_t<unsigned char> > props (
      props_len);
    add_basic_properties (&props[0], props_len);

    const size_t initiate_size =
      113 + 128 + crypto_box_BOXZEROBYTES + props_len;
    const int init_rc = msg_->init_size (initiate_size);
    errno_assert (init_rc == 0);

    const int tools_rc =
      _tools.produce_initiate (msg_->data (), initiate_size,
                               get_and_inc_nonce (), &props[0], props_len);
    if (tools_rc != -1)
        return 0;

    session->get_socket ()->event_handshake_failed_protocol (
      session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
    //  TODO: as with HELLO, msg_ has been sized with init_size but is not
    //  closed on this failure path.
    return -1;
}
int zmq::curve_client_t::process_ready (const uint8_t *msg_data_,
size_t msg_size_)
{
if (msg_size_ < 30) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (),
ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_READY);
errno = EPROTO;
return -1;
}
const size_t clen = (msg_size_ - 14) + crypto_box_BOXZEROBYTES;
uint8_t ready_nonce[crypto_box_NONCEBYTES];
std::vector<uint8_t, secure_allocator_t<uint8_t> > ready_plaintext (
crypto_box_ZEROBYTES + clen);
std::vector<uint8_t> ready_box (crypto_box_BOXZEROBYTES + 16 + clen);
std::fill (ready_box.begin (), ready_box.begin () + crypto_box_BOXZEROBYTES,
0);
memcpy (&ready_box[crypto_box_BOXZEROBYTES], msg_data_ + 14,
clen - crypto_box_BOXZEROBYTES);
memcpy (ready_nonce, "CurveZMQREADY---", 16);
memcpy (ready_nonce + 16, msg_data_ + 6, 8);
set_peer_nonce (get_uint64 (msg_data_ + 6));
int rc = crypto_box_open_afternm (&ready_plaintext[0], &ready_box[0], clen,
ready_nonce, get_precom_buffer ());
if (rc != 0) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC);
errno = EPROTO;
return -1;
}
rc = parse_metadata (&ready_plaintext[crypto_box_ZEROBYTES],
clen - crypto_box_ZEROBYTES);
if (rc == 0)
_state = connected;
else {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_METADATA);
errno = EPROTO;
}
return rc;
}
int zmq::curve_client_t::process_error (const uint8_t *msg_data_,
size_t msg_size_)
{
if (_state != expect_welcome && _state != expect_ready) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (), ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND);
errno = EPROTO;
return -1;
}
if (msg_size_ < 7) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (),
ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
errno = EPROTO;
return -1;
}
const size_t error_reason_len = static_cast<size_t> (msg_data_[6]);
if (error_reason_len > msg_size_ - 7) {
session->get_socket ()->event_handshake_failed_protocol (
session->get_endpoint (),
ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_ERROR);
errno = EPROTO;
return -1;
}
const char *error_reason = reinterpret_cast<const char *> (msg_data_) + 7;
handle_error_reason (error_reason, error_reason_len);
_state = error_received;
return 0;
}
#endif