diff --git a/src/stream_engine.cpp b/src/stream_engine.cpp index c411e708..d24de27e 100644 --- a/src/stream_engine.cpp +++ b/src/stream_engine.cpp @@ -41,6 +41,8 @@ #include "config.hpp" #include "err.hpp" #include "ip.hpp" +#include "likely.hpp" +#include "wire.hpp" zmq::stream_engine_t::stream_engine_t (fd_t fd_, const options_t &options_, const std::string &endpoint_) : s (fd_), @@ -51,6 +53,9 @@ zmq::stream_engine_t::stream_engine_t (fd_t fd_, const options_t &options_, cons outpos (NULL), outsize (0), encoder (out_batch_size), + handshaking (true), + greeting_bytes_read (0), + greeting_size (0), session (NULL), options (options_), endpoint (endpoint_), @@ -112,13 +117,27 @@ void zmq::stream_engine_t::plug (io_thread_t *io_thread_, // Connect to session object. zmq_assert (!session); zmq_assert (session_); - encoder.set_session (session_); - decoder.set_session (session_); session = session_; // Connect to I/O threads poller object. io_object_t::plug (io_thread_); handle = add_fd (s); + + // We need to detect whether our peer is using the versioned + // protocol. The detection is done in two steps. First, we read + // first two bytes and check if the long format of length is in use. + // If so, we receive and check the 'flags' field. If the rightmost bit + // is 1, the peer is using versioned protocol. + greeting_size = 2; + + // Send the 'length' and 'flags' fields of the identity message. + // The 'length' field is encoded in the long format. + outpos = greeting_output_buffer; + outpos [outsize++] = 0xff; + put_uint64 (&outpos [outsize], options.identity_size + 1); + outsize += 8; + outpos [outsize++] = 0x7f; + set_pollin (handle); set_pollout (handle); // Flush all the data that may have been already received downstream. @@ -150,6 +169,11 @@ void zmq::stream_engine_t::terminate () void zmq::stream_engine_t::in_event () { + // If still handshaking, receive and prcess the greeting message. + if (unlikely (handshaking)) + if (!handshake ()) + return; + bool disconnection = false; // If there's no data to process in the buffer... @@ -235,6 +259,12 @@ void zmq::stream_engine_t::out_event () outpos += nbytes; outsize -= nbytes; + + // If we are still handshaking and there are no data + // to send, stop polling for output. + if (unlikely (handshaking)) + if (outsize == 0) + reset_pollout (handle); } void zmq::stream_engine_t::activate_out () @@ -267,6 +297,123 @@ void zmq::stream_engine_t::activate_in () in_event (); } +int zmq::stream_engine_t::receive_greeting () +{ + zmq_assert (greeting_bytes_read < greeting_size); + + while (greeting_bytes_read < greeting_size) { + const int n = read (greeting + greeting_bytes_read, + greeting_size - greeting_bytes_read); + if (n == -1) + return -1; + if (n == 0) + return 0; + + greeting_bytes_read += n; + + if (greeting_bytes_read < greeting_size) + continue; + + if (greeting_size == 2) { + // We have received the first two bytes from the peer. + // If the first byte is not 0xff, we know that the + // peer is using unversioned protocol. + if (greeting [0] != 0xff) + break; + + // This may still be a long identity message (either + // 254 or 255 bytes long). We need to receive 8 more + // bytes so we can inspect the potential 'flags' field. + greeting_size = 10; + } + else + if (greeting_size == 10) { + // Inspect the rightmost bit of the 10th byte (which coincides + // with the 'flags' field if a regular message was sent). + // Zero indicates this is a header of identity message + // (i.e. the peer is using the unversioned protocol). + if (!(greeting [9] & 0x01)) + break; + + // This is truly a handshake and we can now send the rest of + // the greeting message out. + + if (outsize == 0) + set_pollout (handle); + + zmq_assert (outpos != NULL); + + outpos [outsize++] = 1; // Protocol version + outpos [outsize++] = 1; // Remaining length (1 byte for v1) + outpos [outsize++] = options.type; // Socket type + + // Read the 'version' and 'remaining_length' fields. + greeting_size = 12; + } + else + if (greeting_size == 12) { + // We have received the greeting message up to + // the 'remaining_length' field. Receive the remaining + // bytes of the greeting. + greeting_size += greeting [11]; + } + } + + return 0; +} + +bool zmq::stream_engine_t::handshake () +{ + zmq_assert (handshaking); + zmq_assert (greeting_bytes_read < greeting_size); + + int rc = receive_greeting (); + if (rc == -1) { + error (); + return false; + } + + if (greeting_bytes_read < greeting_size) + return false; + + // We have received either a header of identity message + // or the whole greeting. + + encoder.set_session (session); + decoder.set_session (session); + + zmq_assert (greeting [0] != 0xff || greeting_bytes_read >= 10); + + // Is the peer using the unversioned protocol? + // If so, we send and receive rests of identity + // messages. + if (greeting [0] != 0xff || !(greeting [9] & 0x01)) { + // We have already sent the message header. + // Since there is no way to tell the encoder to + // skip the message header, we simply throw that + // header data away. + const size_t header_size = options.identity_size + 1 >= 255 ? 10 : 2; + unsigned char tmp [10], *bufferp = tmp; + size_t buffer_size = header_size; + encoder.get_data (&bufferp, &buffer_size); + zmq_assert (buffer_size == header_size); + + // Make sure the decoder sees the data we have already received. + inpos = greeting; + insize = greeting_bytes_read; + } + + // Start polling for output if necessary. + if (outsize == 0) + set_pollout (handle); + + // Handshaking was successful. + // Switch into the normal message flow. + handshaking = false; + + return true; +} + void zmq::stream_engine_t::error () { zmq_assert (session); diff --git a/src/stream_engine.hpp b/src/stream_engine.hpp index f4d15b6d..90b9221f 100644 --- a/src/stream_engine.hpp +++ b/src/stream_engine.hpp @@ -67,6 +67,12 @@ namespace zmq // Function to handle network disconnections. void error (); + // Receives the greeting message from the peer. + int receive_greeting (); + + // Detects the protocol used by the peer. + bool handshake (); + // Writes data to the socket. Returns the number of bytes actually // written (even zero is to be considered to be a success). In case // of error or orderly shutdown by the other peer -1 is returned. @@ -81,6 +87,16 @@ namespace zmq // Underlying socket. fd_t s; + // Maximum size of a greeting message: + // preamble (10 bytes) + version (1 byte) + remaining_length (1 byte) + + // up to 255 remaining bytes. + const static size_t maximum_greeting_size = 10 + 1 + 1 + 255; + + // Size of v1 greeting message: + // preamble (10 bytes) + version (1 byte) + remaining_length (1 byte) + + // socket_type (1) + const static size_t v1_greeting_size = 10 + 1 + 1 + 1; + handle_t handle; unsigned char *inpos; @@ -92,6 +108,26 @@ namespace zmq size_t outsize; encoder_t encoder; + // When true, we are still trying to determine whether + // the peer is using versioned protocol, and if so, which + // version. When false, normal message flow has started. + bool handshaking; + + // The receive buffer holding the greeting message + // that we are receiving from the peer. + unsigned char greeting [maximum_greeting_size]; + + // The number of bytes of the greeting message that + // we have already received. + unsigned int greeting_bytes_read; + + // The size of the greeting message. + unsigned int greeting_size; + + // The send buffer holding the greeting message + // that we are sending to the peer. + unsigned char greeting_output_buffer [v1_greeting_size]; + // The session this engine is attached to. zmq::session_base_t *session;