diff --git a/src/curve_client.cpp b/src/curve_client.cpp index fae7414d..e0641990 100644 --- a/src/curve_client.cpp +++ b/src/curve_client.cpp @@ -101,6 +101,113 @@ int zmq::curve_client_t::process_handshake_message (msg_t *msg_) return rc; } +int zmq::curve_client_t::encode (msg_t *msg_) +{ + zmq_assert (state == connected); + + uint8_t flags = 0; + if (msg_->flags () & msg_t::more) + flags |= 0x01; + + uint8_t message_nonce [crypto_box_NONCEBYTES]; + memcpy (message_nonce, "CurveZMQMESSAGEC", 16); + memcpy (message_nonce + 16, &cn_nonce, 8); + + const size_t mlen = crypto_box_ZEROBYTES + 1 + msg_->size (); + + uint8_t *message_plaintext = static_cast (malloc (mlen)); + alloc_assert (message_plaintext); + + memset (message_plaintext, 0, crypto_box_ZEROBYTES); + message_plaintext [crypto_box_ZEROBYTES] = flags; + memcpy (message_plaintext + crypto_box_ZEROBYTES + 1, + msg_->data (), msg_->size ()); + + uint8_t *message_box = static_cast (malloc (mlen)); + alloc_assert (message_box); + + int rc = crypto_box_afternm (message_box, message_plaintext, + mlen, message_nonce, cn_precom); + zmq_assert (rc == 0); + + rc = msg_->close (); + zmq_assert (rc == 0); + + rc = msg_->init_size (16 + mlen - crypto_box_BOXZEROBYTES); + zmq_assert (rc == 0); + + uint8_t *message = static_cast (msg_->data ()); + + memcpy (message, "MESSAGE ", 8); + memcpy (message + 8, &cn_nonce, 8); + memcpy (message + 16, message_box + crypto_box_BOXZEROBYTES, + mlen - crypto_box_BOXZEROBYTES); + + free (message_plaintext); + free (message_box); + + cn_nonce++; + + return 0; +} + +int zmq::curve_client_t::decode (msg_t *msg_) +{ + zmq_assert (state == connected); + + if (msg_->size () < 33) { + errno = EPROTO; + return -1; + } + + const uint8_t *message = static_cast (msg_->data ()); + if (memcmp (message, "MESSAGE ", 8)) { + errno = EPROTO; + return -1; + } + + uint8_t message_nonce [crypto_box_NONCEBYTES]; + memcpy (message_nonce, "CurveZMQMESSAGES", 16); + memcpy (message_nonce + 16, message + 8, 8); + + const size_t clen = crypto_box_BOXZEROBYTES + (msg_->size () - 16); + + uint8_t *message_plaintext = static_cast (malloc (clen)); + alloc_assert (message_plaintext); + + uint8_t *message_box = static_cast (malloc (clen)); + alloc_assert (message_box); + + memset (message_box, 0, crypto_box_BOXZEROBYTES); + memcpy (message_box + crypto_box_BOXZEROBYTES, + message + 16, msg_->size () - 16); + + int rc = crypto_box_open_afternm (message_plaintext, message_box, + clen, message_nonce, cn_precom); + if (rc == 0) { + rc = msg_->close (); + zmq_assert (rc == 0); + + rc = msg_->init_size (clen - 1 - crypto_box_ZEROBYTES); + zmq_assert (rc == 0); + + const uint8_t flags = message_plaintext [crypto_box_ZEROBYTES]; + if (flags & 0x01) + msg_->set_flags (msg_t::more); + + memcpy (msg_->data (), + message_plaintext + crypto_box_ZEROBYTES + 1, + msg_->size ()); + } + else + errno = EPROTO; + + free (message_plaintext); + free (message_box); + + return rc; +} + bool zmq::curve_client_t::is_handshake_complete () const { return state == connected; diff --git a/src/curve_client.hpp b/src/curve_client.hpp index 254d94f0..66a071ff 100644 --- a/src/curve_client.hpp +++ b/src/curve_client.hpp @@ -52,6 +52,8 @@ namespace zmq // mechanism implementation virtual int next_handshake_message (msg_t *msg_); virtual int process_handshake_message (msg_t *msg_); + virtual int encode (msg_t *msg_); + virtual int decode (msg_t *msg_); virtual bool is_handshake_complete () const; private: diff --git a/src/curve_server.cpp b/src/curve_server.cpp index 87e56706..708694de 100644 --- a/src/curve_server.cpp +++ b/src/curve_server.cpp @@ -111,6 +111,113 @@ int zmq::curve_server_t::process_handshake_message (msg_t *msg_) return rc; } +int zmq::curve_server_t::encode (msg_t *msg_) +{ + zmq_assert (state == connected); + + const size_t mlen = crypto_box_ZEROBYTES + 1 + msg_->size (); + + uint8_t message_nonce [crypto_box_NONCEBYTES]; + memcpy (message_nonce, "CurveZMQMESSAGES", 16); + memcpy (message_nonce + 16, &cn_nonce, 8); + + uint8_t flags = 0; + if (msg_->flags () & msg_t::more) + flags |= 0x01; + + uint8_t *message_plaintext = static_cast (malloc (mlen)); + alloc_assert (message_plaintext); + + memset (message_plaintext, 0, crypto_box_ZEROBYTES); + message_plaintext [crypto_box_ZEROBYTES] = flags; + memcpy (message_plaintext + crypto_box_ZEROBYTES + 1, + msg_->data (), msg_->size ()); + + uint8_t *message_box = static_cast (malloc (mlen)); + alloc_assert (message_box); + + int rc = crypto_box_afternm (message_box, message_plaintext, + mlen, message_nonce, cn_precom); + zmq_assert (rc == 0); + + rc = msg_->close (); + zmq_assert (rc == 0); + + rc = msg_->init_size (16 + mlen - crypto_box_BOXZEROBYTES); + zmq_assert (rc == 0); + + uint8_t *message = static_cast (msg_->data ()); + + memcpy (message, "MESSAGE ", 8); + memcpy (message + 8, &cn_nonce, 8); + memcpy (message + 16, message_box + crypto_box_BOXZEROBYTES, + mlen - crypto_box_BOXZEROBYTES); + + free (message_plaintext); + free (message_box); + + cn_nonce++; + + return 0; +} + +int zmq::curve_server_t::decode (msg_t *msg_) +{ + zmq_assert (state == connected); + + if (msg_->size () < 33) { + errno = EPROTO; + return -1; + } + + const uint8_t *message = static_cast (msg_->data ()); + if (memcmp (message, "MESSAGE ", 8)) { + errno = EPROTO; + return -1; + } + + uint8_t message_nonce [crypto_box_NONCEBYTES]; + memcpy (message_nonce, "CurveZMQMESSAGEC", 16); + memcpy (message_nonce + 16, message + 8, 8); + + const size_t clen = crypto_box_BOXZEROBYTES + msg_->size () - 16; + + uint8_t *message_plaintext = static_cast (malloc (clen)); + alloc_assert (message_plaintext); + + uint8_t *message_box = static_cast (malloc (clen)); + alloc_assert (message_box); + + memset (message_box, 0, crypto_box_BOXZEROBYTES); + memcpy (message_box + crypto_box_BOXZEROBYTES, + message + 16, msg_->size () - 16); + + int rc = crypto_box_open_afternm (message_plaintext, message_box, + clen, message_nonce, cn_precom); + if (rc == 0) { + rc = msg_->close (); + zmq_assert (rc == 0); + + rc = msg_->init_size (clen - 1 - crypto_box_ZEROBYTES); + zmq_assert (rc == 0); + + const uint8_t flags = message_plaintext [crypto_box_ZEROBYTES]; + if (flags & 0x01) + msg_->set_flags (msg_t::more); + + memcpy (msg_->data (), + message_plaintext + crypto_box_ZEROBYTES + 1, + msg_->size ()); + } + else + errno = EPROTO; + + free (message_plaintext); + free (message_box); + + return rc; +} + int zmq::curve_server_t::zap_msg_available () { if (state != expect_zap_reply) { diff --git a/src/curve_server.hpp b/src/curve_server.hpp index 2d6b6cd0..1eafc3c7 100644 --- a/src/curve_server.hpp +++ b/src/curve_server.hpp @@ -56,6 +56,8 @@ namespace zmq // mechanism implementation virtual int next_handshake_message (msg_t *msg_); virtual int process_handshake_message (msg_t *msg_); + virtual int encode (msg_t *msg_); + virtual int decode (msg_t *msg_); virtual int zap_msg_available (); virtual bool is_handshake_complete () const; diff --git a/src/mechanism.hpp b/src/mechanism.hpp index a9b04172..a1127f30 100644 --- a/src/mechanism.hpp +++ b/src/mechanism.hpp @@ -46,6 +46,10 @@ namespace zmq // Process the handshake message received from the peer. virtual int process_handshake_message (msg_t *msg_) = 0; + virtual int encode (msg_t *msg_) { return 0; } + + virtual int decode (msg_t *msg_) { return 0; } + // Notifies mechanism about availability of ZAP message. virtual int zap_msg_available () { return 0; } diff --git a/src/stream_engine.cpp b/src/stream_engine.cpp index 16404e05..0df3ac0a 100644 --- a/src/stream_engine.cpp +++ b/src/stream_engine.cpp @@ -42,6 +42,8 @@ #include "v2_decoder.hpp" #include "null_mechanism.hpp" #include "plain_mechanism.hpp" +#include "curve_client.hpp" +#include "curve_server.hpp" #include "raw_decoder.hpp" #include "raw_encoder.hpp" #include "config.hpp" @@ -459,10 +461,18 @@ bool zmq::stream_engine_t::handshake () else { outpos [outsize++] = 0; // Minor version number memset (outpos + outsize, 0, 20); + + zmq_assert (options.mechanism == ZMQ_NULL + || options.mechanism == ZMQ_PLAIN + || options.mechanism == ZMQ_CURVE); + if (options.mechanism == ZMQ_NULL) memcpy (outpos + outsize, "NULL", 4); else + if (options.mechanism == ZMQ_PLAIN) memcpy (outpos + outsize, "PLAIN", 5); + else + memcpy (outpos + outsize, "CURVE", 5); outsize += 20; memset (outpos + outsize, 0, 32); outsize += 32; @@ -539,6 +549,16 @@ bool zmq::stream_engine_t::handshake () mechanism = new (std::nothrow) plain_mechanism_t (session, options); alloc_assert (mechanism); } +#ifdef HAVE_LIBSODIUM + else + if (memcmp (greeting_recv + 12, "CURVE\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) == 0) { + if (options.as_server) + mechanism = new (std::nothrow) curve_server_t (session, options); + else + mechanism = new (std::nothrow) curve_client_t (options); + alloc_assert (mechanism); + } +#endif else { error (); return false; @@ -643,8 +663,8 @@ void zmq::stream_engine_t::mechanism_ready () errno_assert (rc == 0); } - read_msg = &stream_engine_t::pull_msg_from_session; - write_msg = &stream_engine_t::push_msg_to_session; + read_msg = &stream_engine_t::pull_and_encode; + write_msg = &stream_engine_t::decode_and_push; } int zmq::stream_engine_t::pull_msg_from_session (msg_t *msg_) @@ -657,6 +677,39 @@ int zmq::stream_engine_t::push_msg_to_session (msg_t *msg_) return session->push_msg (msg_); } +int zmq::stream_engine_t::pull_and_encode (msg_t *msg_) +{ + zmq_assert (mechanism != NULL); + + if (session->pull_msg (msg_) == -1) + return -1; + if (mechanism->encode (msg_) == -1) + return -1; + return 0; +} + +int zmq::stream_engine_t::decode_and_push (msg_t *msg_) +{ + zmq_assert (mechanism != NULL); + + if (mechanism->decode (msg_) == -1) + return -1; + if (session->push_msg (msg_) == -1) { + if (errno == EAGAIN) + write_msg = &stream_engine_t::push_one_then_decode_and_push; + return -1; + } + return 0; +} + +int zmq::stream_engine_t::push_one_then_decode_and_push (msg_t *msg_) +{ + const int rc = session->push_msg (msg_); + if (rc == 0) + write_msg = &stream_engine_t::decode_and_push; + return rc; +} + int zmq::stream_engine_t::write_subscription_msg (msg_t *msg_) { msg_t subscription; diff --git a/src/stream_engine.hpp b/src/stream_engine.hpp index 1f337e02..d4d36b1e 100644 --- a/src/stream_engine.hpp +++ b/src/stream_engine.hpp @@ -102,6 +102,10 @@ namespace zmq int pull_msg_from_session (msg_t *msg_); int push_msg_to_session (msg_t *msg); + int pull_and_encode (msg_t *msg_); + int decode_and_push (msg_t *msg_); + int push_one_then_decode_and_push (msg_t *msg_); + void mechanism_ready (); int write_subscription_msg (msg_t *msg_);