From 60ccf54fa632a9ff1547302f5dfe2ca6af5af3f7 Mon Sep 17 00:00:00 2001 From: Luca Boccassi Date: Sun, 3 May 2020 17:29:19 +0100 Subject: [PATCH] Problem: sub/cancel broken with CURVE Solution: handle downgrading sub/cancel messages in CURVE engine --- src/curve_client.cpp | 10 ++++-- src/curve_client.hpp | 4 ++- src/curve_mechanism_base.cpp | 46 +++++++++++++++++++++++---- src/curve_mechanism_base.hpp | 8 +++-- src/curve_server.cpp | 10 ++++-- src/curve_server.hpp | 3 +- src/ws_engine.cpp | 4 +-- src/zmtp_engine.cpp | 14 ++++---- src/zmtp_engine.hpp | 2 +- unittests/unittest_curve_encoding.cpp | 6 ++-- 10 files changed, 78 insertions(+), 29 deletions(-) diff --git a/src/curve_client.cpp b/src/curve_client.cpp index b4b7a380..5d1f64e8 100644 --- a/src/curve_client.cpp +++ b/src/curve_client.cpp @@ -41,10 +41,14 @@ #include "secure_allocator.hpp" zmq::curve_client_t::curve_client_t (session_base_t *session_, - const options_t &options_) : + const options_t &options_, + const bool downgrade_sub_) : mechanism_base_t (session_, options_), - curve_mechanism_base_t ( - session_, options_, "CurveZMQMESSAGEC", "CurveZMQMESSAGES"), + curve_mechanism_base_t (session_, + options_, + "CurveZMQMESSAGEC", + "CurveZMQMESSAGES", + downgrade_sub_), _state (send_hello), _tools (options_.curve_public_key, options_.curve_secret_key, diff --git a/src/curve_client.hpp b/src/curve_client.hpp index c85d1b78..9d51d3a0 100644 --- a/src/curve_client.hpp +++ b/src/curve_client.hpp @@ -44,7 +44,9 @@ class session_base_t; class curve_client_t ZMQ_FINAL : public curve_mechanism_base_t { public: - curve_client_t (session_base_t *session_, const options_t &options_); + curve_client_t (session_base_t *session_, + const options_t &options_, + const bool downgrade_sub_); ~curve_client_t () ZMQ_FINAL; // mechanism implementation diff --git a/src/curve_mechanism_base.cpp b/src/curve_mechanism_base.cpp index c040d899..6f0173e1 100644 --- a/src/curve_mechanism_base.cpp +++ b/src/curve_mechanism_base.cpp @@ -49,9 +49,11 @@ zmq::curve_mechanism_base_t::curve_mechanism_base_t ( session_base_t *session_, const options_t &options_, const char *encode_nonce_prefix_, - const char *decode_nonce_prefix_) : + const char *decode_nonce_prefix_, + const bool downgrade_sub_) : mechanism_base_t (session_, options_), - curve_encoding_t (encode_nonce_prefix_, decode_nonce_prefix_) + curve_encoding_t ( + encode_nonce_prefix_, decode_nonce_prefix_, downgrade_sub_) { } @@ -77,11 +79,13 @@ int zmq::curve_mechanism_base_t::decode (msg_t *msg_) } zmq::curve_encoding_t::curve_encoding_t (const char *encode_nonce_prefix_, - const char *decode_nonce_prefix_) : + const char *decode_nonce_prefix_, + const bool downgrade_sub_) : _encode_nonce_prefix (encode_nonce_prefix_), _decode_nonce_prefix (decode_nonce_prefix_), _cn_nonce (1), - _cn_peer_nonce (1) + _cn_peer_nonce (1), + _downgrade_sub (downgrade_sub_) { } @@ -133,15 +137,26 @@ int zmq::curve_encoding_t::check_validity (msg_t *msg_, int *error_event_code_) int zmq::curve_encoding_t::encode (msg_t *msg_) { + size_t sub_cancel_len = 0; uint8_t message_nonce[crypto_box_NONCEBYTES]; memcpy (message_nonce, _encode_nonce_prefix, nonce_prefix_len); put_uint64 (message_nonce + nonce_prefix_len, get_and_inc_nonce ()); + if (msg_->is_subscribe () || msg_->is_cancel ()) { + if (_downgrade_sub) + sub_cancel_len = 1; + else + sub_cancel_len = msg_->is_cancel () + ? zmq::msg_t::cancel_cmd_name_size + : zmq::msg_t::sub_cmd_name_size; + } + #ifdef ZMQ_HAVE_CRYPTO_BOX_EASY_FNS - const size_t mlen = flags_len + msg_->size (); + const size_t mlen = flags_len + sub_cancel_len + msg_->size (); std::vector message_plaintext (mlen); #else - const size_t mlen = crypto_box_ZEROBYTES + flags_len + msg_->size (); + const size_t mlen = + crypto_box_ZEROBYTES + flags_len + sub_cancel_len + msg_->size (); std::vector message_plaintext_with_zerobytes (mlen); uint8_t *const message_plaintext = &message_plaintext_with_zerobytes[crypto_box_ZEROBYTES]; @@ -153,10 +168,27 @@ int zmq::curve_encoding_t::encode (msg_t *msg_) const uint8_t flags = msg_->flags () & flag_mask; message_plaintext[0] = flags; + + // For backward compatibility subscribe/cancel command messages are not stored with + // the message flags, and are encoded in the encoder, so that messages for < 3.0 peers + // can be encoded in the "old" 0/1 way rather than as commands. + if (sub_cancel_len == 1) + message_plaintext[flags_len] = msg_->is_subscribe () ? 1 : 0; + else if (sub_cancel_len == zmq::msg_t::sub_cmd_name_size) { + message_plaintext[0] |= zmq::msg_t::command; + memcpy (&message_plaintext[flags_len], zmq::sub_cmd_name, + zmq::msg_t::sub_cmd_name_size); + } else if (sub_cancel_len == zmq::msg_t::cancel_cmd_name_size) { + message_plaintext[0] |= zmq::msg_t::command; + memcpy (&message_plaintext[flags_len], zmq::cancel_cmd_name, + zmq::msg_t::cancel_cmd_name_size); + } + // this is copying the data from insecure memory, so there is no point in // using secure_allocator_t for message_plaintext if (msg_->size () > 0) - memcpy (&message_plaintext[flags_len], msg_->data (), msg_->size ()); + memcpy (&message_plaintext[flags_len + sub_cancel_len], msg_->data (), + msg_->size ()); #ifdef ZMQ_HAVE_CRYPTO_BOX_EASY_FNS msg_t msg_box; diff --git a/src/curve_mechanism_base.hpp b/src/curve_mechanism_base.hpp index a341e744..a72965e9 100644 --- a/src/curve_mechanism_base.hpp +++ b/src/curve_mechanism_base.hpp @@ -56,7 +56,8 @@ class curve_encoding_t { public: curve_encoding_t (const char *encode_nonce_prefix_, - const char *decode_nonce_prefix_); + const char *decode_nonce_prefix_, + const bool downgrade_sub_); int encode (msg_t *msg_); int decode (msg_t *msg_, int *error_event_code_); @@ -81,6 +82,8 @@ class curve_encoding_t // Intermediary buffer used to speed up boxing and unboxing. uint8_t _cn_precom[crypto_box_BEFORENMBYTES]; + const bool _downgrade_sub; + ZMQ_NON_COPYABLE_NOR_MOVABLE (curve_encoding_t) }; @@ -91,7 +94,8 @@ class curve_mechanism_base_t : public virtual mechanism_base_t, curve_mechanism_base_t (session_base_t *session_, const options_t &options_, const char *encode_nonce_prefix_, - const char *decode_nonce_prefix_); + const char *decode_nonce_prefix_, + const bool downgrade_sub_); // mechanism implementation int encode (msg_t *msg_) ZMQ_OVERRIDE; diff --git a/src/curve_server.cpp b/src/curve_server.cpp index 91ac85b9..6e34221e 100644 --- a/src/curve_server.cpp +++ b/src/curve_server.cpp @@ -41,12 +41,16 @@ zmq::curve_server_t::curve_server_t (session_base_t *session_, const std::string &peer_address_, - const options_t &options_) : + const options_t &options_, + const bool downgrade_sub_) : mechanism_base_t (session_, options_), zap_client_common_handshake_t ( session_, peer_address_, options_, sending_ready), - curve_mechanism_base_t ( - session_, options_, "CurveZMQMESSAGES", "CurveZMQMESSAGEC") + curve_mechanism_base_t (session_, + options_, + "CurveZMQMESSAGES", + "CurveZMQMESSAGEC", + downgrade_sub_) { int rc; // Fetch our secret key from socket options diff --git a/src/curve_server.hpp b/src/curve_server.hpp index 67b85e67..995efce0 100644 --- a/src/curve_server.hpp +++ b/src/curve_server.hpp @@ -48,7 +48,8 @@ class curve_server_t ZMQ_FINAL : public zap_client_common_handshake_t, public: curve_server_t (session_base_t *session_, const std::string &peer_address_, - const options_t &options_); + const options_t &options_, + const bool downgrade_sub_); ~curve_server_t (); // mechanism implementation diff --git a/src/ws_engine.cpp b/src/ws_engine.cpp index 26404905..20d0f8d4 100644 --- a/src/ws_engine.cpp +++ b/src/ws_engine.cpp @@ -268,10 +268,10 @@ bool zmq::ws_engine_t::select_protocol (const char *protocol_) && strcmp ("ZWS2.0/CURVE", protocol_) == 0) { if (_options.as_server) _mechanism = new (std::nothrow) - curve_server_t (session (), _peer_address, _options); + curve_server_t (session (), _peer_address, _options, false); else _mechanism = - new (std::nothrow) curve_client_t (session (), _options); + new (std::nothrow) curve_client_t (session (), _options, false); alloc_assert (_mechanism); return true; } diff --git a/src/zmtp_engine.cpp b/src/zmtp_engine.cpp index d9a186b1..a8615d4c 100644 --- a/src/zmtp_engine.cpp +++ b/src/zmtp_engine.cpp @@ -347,7 +347,7 @@ bool zmq::zmtp_engine_t::handshake_v2_0 () return true; } -bool zmq::zmtp_engine_t::handshake_v3_x () +bool zmq::zmtp_engine_t::handshake_v3_x (const bool downgrade_sub_) { if (_options.mechanism == ZMQ_NULL && memcmp (_greeting_recv + 12, "NULL\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", @@ -374,11 +374,11 @@ bool zmq::zmtp_engine_t::handshake_v3_x () "CURVE\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) == 0) { if (_options.as_server) - _mechanism = new (std::nothrow) - curve_server_t (session (), _peer_address, _options); + _mechanism = new (std::nothrow) curve_server_t ( + session (), _peer_address, _options, downgrade_sub_); else - _mechanism = - new (std::nothrow) curve_client_t (session (), _options); + _mechanism = new (std::nothrow) + curve_client_t (session (), _options, downgrade_sub_); alloc_assert (_mechanism); } #endif @@ -418,7 +418,7 @@ bool zmq::zmtp_engine_t::handshake_v3_0 () _options.in_batch_size, _options.maxmsgsize, _options.zero_copy); alloc_assert (_decoder); - return zmq::zmtp_engine_t::handshake_v3_x (); + return zmq::zmtp_engine_t::handshake_v3_x (true); } bool zmq::zmtp_engine_t::handshake_v3_1 () @@ -430,7 +430,7 @@ bool zmq::zmtp_engine_t::handshake_v3_1 () _options.in_batch_size, _options.maxmsgsize, _options.zero_copy); alloc_assert (_decoder); - return zmq::zmtp_engine_t::handshake_v3_x (); + return zmq::zmtp_engine_t::handshake_v3_x (false); } int zmq::zmtp_engine_t::routing_id_msg (msg_t *msg_) diff --git a/src/zmtp_engine.hpp b/src/zmtp_engine.hpp index 2d18bcfa..10f4134c 100644 --- a/src/zmtp_engine.hpp +++ b/src/zmtp_engine.hpp @@ -92,7 +92,7 @@ class zmtp_engine_t ZMQ_FINAL : public stream_engine_base_t bool handshake_v1_0_unversioned (); bool handshake_v1_0 (); bool handshake_v2_0 (); - bool handshake_v3_x (); + bool handshake_v3_x (bool downgrade_sub); bool handshake_v3_0 (); bool handshake_v3_1 (); diff --git a/unittests/unittest_curve_encoding.cpp b/unittests/unittest_curve_encoding.cpp index 7cccd06e..a331476e 100644 --- a/unittests/unittest_curve_encoding.cpp +++ b/unittests/unittest_curve_encoding.cpp @@ -48,9 +48,11 @@ void test_roundtrip (zmq::msg_t *msg_) + msg_->size ()); zmq::curve_encoding_t encoding_client ("CurveZMQMESSAGEC", - "CurveZMQMESSAGES"); + "CurveZMQMESSAGES", + false); zmq::curve_encoding_t encoding_server ("CurveZMQMESSAGES", - "CurveZMQMESSAGEC"); + "CurveZMQMESSAGEC", + false); uint8_t client_public[32]; uint8_t client_secret[32];