diff --git a/src/stream_engine.cpp b/src/stream_engine.cpp index 16404e05..0df3ac0a 100644 --- a/src/stream_engine.cpp +++ b/src/stream_engine.cpp @@ -42,6 +42,8 @@ #include "v2_decoder.hpp" #include "null_mechanism.hpp" #include "plain_mechanism.hpp" +#include "curve_client.hpp" +#include "curve_server.hpp" #include "raw_decoder.hpp" #include "raw_encoder.hpp" #include "config.hpp" @@ -459,10 +461,18 @@ bool zmq::stream_engine_t::handshake () else { outpos [outsize++] = 0; // Minor version number memset (outpos + outsize, 0, 20); + + zmq_assert (options.mechanism == ZMQ_NULL + || options.mechanism == ZMQ_PLAIN + || options.mechanism == ZMQ_CURVE); + if (options.mechanism == ZMQ_NULL) memcpy (outpos + outsize, "NULL", 4); else + if (options.mechanism == ZMQ_PLAIN) memcpy (outpos + outsize, "PLAIN", 5); + else + memcpy (outpos + outsize, "CURVE", 5); outsize += 20; memset (outpos + outsize, 0, 32); outsize += 32; @@ -539,6 +549,16 @@ bool zmq::stream_engine_t::handshake () mechanism = new (std::nothrow) plain_mechanism_t (session, options); alloc_assert (mechanism); } +#ifdef HAVE_LIBSODIUM + else + if (memcmp (greeting_recv + 12, "CURVE\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0", 20) == 0) { + if (options.as_server) + mechanism = new (std::nothrow) curve_server_t (session, options); + else + mechanism = new (std::nothrow) curve_client_t (options); + alloc_assert (mechanism); + } +#endif else { error (); return false; @@ -643,8 +663,8 @@ void zmq::stream_engine_t::mechanism_ready () errno_assert (rc == 0); } - read_msg = &stream_engine_t::pull_msg_from_session; - write_msg = &stream_engine_t::push_msg_to_session; + read_msg = &stream_engine_t::pull_and_encode; + write_msg = &stream_engine_t::decode_and_push; } int zmq::stream_engine_t::pull_msg_from_session (msg_t *msg_) @@ -657,6 +677,39 @@ int zmq::stream_engine_t::push_msg_to_session (msg_t *msg_) return session->push_msg (msg_); } +int zmq::stream_engine_t::pull_and_encode (msg_t *msg_) +{ + zmq_assert (mechanism != NULL); + + if (session->pull_msg (msg_) == -1) + return -1; + if (mechanism->encode (msg_) == -1) + return -1; + return 0; +} + +int zmq::stream_engine_t::decode_and_push (msg_t *msg_) +{ + zmq_assert (mechanism != NULL); + + if (mechanism->decode (msg_) == -1) + return -1; + if (session->push_msg (msg_) == -1) { + if (errno == EAGAIN) + write_msg = &stream_engine_t::push_one_then_decode_and_push; + return -1; + } + return 0; +} + +int zmq::stream_engine_t::push_one_then_decode_and_push (msg_t *msg_) +{ + const int rc = session->push_msg (msg_); + if (rc == 0) + write_msg = &stream_engine_t::decode_and_push; + return rc; +} + int zmq::stream_engine_t::write_subscription_msg (msg_t *msg_) { msg_t subscription; diff --git a/src/stream_engine.hpp b/src/stream_engine.hpp index 1f337e02..d4d36b1e 100644 --- a/src/stream_engine.hpp +++ b/src/stream_engine.hpp @@ -102,6 +102,10 @@ namespace zmq int pull_msg_from_session (msg_t *msg_); int push_msg_to_session (msg_t *msg); + int pull_and_encode (msg_t *msg_); + int decode_and_push (msg_t *msg_); + int push_one_then_decode_and_push (msg_t *msg_); + void mechanism_ready (); int write_subscription_msg (msg_t *msg_);