diff --git a/src/curve_client.cpp b/src/curve_client.cpp index 09107654..49e7c471 100644 --- a/src/curve_client.cpp +++ b/src/curve_client.cpp @@ -41,8 +41,7 @@ zmq::curve_client_t::curve_client_t (session_base_t *session_, const options_t &options_) : - mechanism_t (options_), - session (session_), + mechanism_base_t (session_, options_), state (send_hello), tools (options_.curve_public_key, options_.curve_secret_key, @@ -166,20 +165,23 @@ int zmq::curve_client_t::encode (msg_t *msg_) int zmq::curve_client_t::decode (msg_t *msg_) { zmq_assert (state == connected); + int rc = check_basic_command_structure (msg_); + if (rc == -1) + return rc; - if (msg_->size () < 33) { + const uint8_t *message = static_cast (msg_->data ()); + if (msg_->size() < 8 || memcmp (message, "\x07MESSAGE", 8)) { session->get_socket ()->event_handshake_failed_protocol ( session->get_endpoint (), - ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE); // TODO message may not be a MESSAGE at all + ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND); errno = EPROTO; return -1; } - const uint8_t *message = static_cast (msg_->data ()); - if (memcmp (message, "\x07MESSAGE", 8)) { + if (msg_->size () < 33) { session->get_socket ()->event_handshake_failed_protocol ( session->get_endpoint (), - ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND); + ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE); errno = EPROTO; return -1; } @@ -209,8 +211,8 @@ int zmq::curve_client_t::decode (msg_t *msg_) memcpy (message_box + crypto_box_BOXZEROBYTES, message + 16, msg_->size () - 16); - int rc = crypto_box_open_afternm (message_plaintext, message_box, clen, - message_nonce, tools.cn_precom); + rc = crypto_box_open_afternm (message_plaintext, message_box, clen, + message_nonce, tools.cn_precom); if (rc == 0) { rc = msg_->close (); zmq_assert (rc == 0); diff --git a/src/curve_client.hpp b/src/curve_client.hpp index 9b2d2175..f5be67e7 100644 --- a/src/curve_client.hpp +++ b/src/curve_client.hpp @@ -42,7 +42,7 @@ namespace zmq class msg_t; class session_base_t; - class curve_client_t : public mechanism_t + class curve_client_t : public mechanism_base_t { public: @@ -67,8 +67,6 @@ namespace zmq connected }; - session_base_t *session; - // Current FSM state state_t state; diff --git a/src/curve_server.cpp b/src/curve_server.cpp index ec0d1b49..aa9f5425 100644 --- a/src/curve_server.cpp +++ b/src/curve_server.cpp @@ -41,7 +41,7 @@ zmq::curve_server_t::curve_server_t (session_base_t *session_, const std::string &peer_address_, const options_t &options_) : - mechanism_t (options_), + mechanism_base_t (session_, options_), zap_client_common_handshake_t ( session_, peer_address_, options_, sending_ready), cn_nonce (1), diff --git a/src/gssapi_client.cpp b/src/gssapi_client.cpp index 0f73db3c..942c25fa 100644 --- a/src/gssapi_client.cpp +++ b/src/gssapi_client.cpp @@ -40,7 +40,9 @@ #include "gssapi_client.hpp" #include "wire.hpp" -zmq::gssapi_client_t::gssapi_client_t (const options_t &options_) : +zmq::gssapi_client_t::gssapi_client_t (session_base_t *session_, + const options_t &options_) : + mechanism_base_t (session_, options_), gssapi_mechanism_base_t (options_), state (call_next_init), token_ptr (GSS_C_NO_BUFFER), diff --git a/src/gssapi_client.hpp b/src/gssapi_client.hpp index 0a89e149..6bfdf50c 100644 --- a/src/gssapi_client.hpp +++ b/src/gssapi_client.hpp @@ -38,13 +38,12 @@ namespace zmq { class msg_t; + class session_base_t; - class gssapi_client_t : - public gssapi_mechanism_base_t + class gssapi_client_t : public gssapi_mechanism_base_t { - public: - - gssapi_client_t (const options_t &options_); + public: + gssapi_client_t (session_base_t *session_, const options_t &options_); virtual ~gssapi_client_t (); // mechanism implementation diff --git a/src/gssapi_mechanism_base.cpp b/src/gssapi_mechanism_base.cpp index 78106e7b..65d9b124 100644 --- a/src/gssapi_mechanism_base.cpp +++ b/src/gssapi_mechanism_base.cpp @@ -40,8 +40,11 @@ #include "gssapi_mechanism_base.hpp" #include "wire.hpp" -zmq::gssapi_mechanism_base_t::gssapi_mechanism_base_t (const options_t & options_) : - mechanism_t(options_), +zmq::gssapi_mechanism_base_t::gssapi_mechanism_base_t ( + session_base_t *session_, + const std::string &peer_address_, + const options_t &options_) : + mechanism_base_t (session_, options_), send_tok (), recv_tok (), /// FIXME remove? in_buf (), diff --git a/src/gssapi_mechanism_base.hpp b/src/gssapi_mechanism_base.hpp index b8d157b8..03c947bf 100644 --- a/src/gssapi_mechanism_base.hpp +++ b/src/gssapi_mechanism_base.hpp @@ -49,14 +49,15 @@ namespace zmq /// For example, clients and servers both need to produce and /// process context-level GSSAPI tokens (via INITIATE commands) /// and per-message GSSAPI tokens (via MESSAGE commands). - class gssapi_mechanism_base_t: - public virtual mechanism_t + class gssapi_mechanism_base_t : public virtual mechanism_base_t { - public: - gssapi_mechanism_base_t (const options_t &options_); + public: + gssapi_mechanism_base_t (session_base_t *session_, + const std::string &peer_address_, + const options_t &options_); virtual ~gssapi_mechanism_base_t () = 0; - protected: + protected: // Produce a context-level GSSAPI token (INITIATE command) // during security context initialization. int produce_initiate (msg_t *msg_, void *data_, size_t data_len_); diff --git a/src/gssapi_server.cpp b/src/gssapi_server.cpp index d665c7e8..936c4499 100644 --- a/src/gssapi_server.cpp +++ b/src/gssapi_server.cpp @@ -45,7 +45,7 @@ zmq::gssapi_server_t::gssapi_server_t (session_base_t *session_, const std::string &peer_address_, const options_t &options_) : - mechanism_t (options_), + mechanism_base_t (session_, options_), gssapi_mechanism_base_t (options_), zap_client_t (session_, peer_address_, options_), state (recv_next_token), diff --git a/src/gssapi_server.hpp b/src/gssapi_server.hpp index 62a4ff66..5aa272da 100644 --- a/src/gssapi_server.hpp +++ b/src/gssapi_server.hpp @@ -45,19 +45,18 @@ namespace zmq : public gssapi_mechanism_base_t, public zap_client_t { public: + gssapi_server_t (session_base_t *session_, + const std::string &peer_address, + const options_t &options_); + virtual ~gssapi_server_t (); - gssapi_server_t (session_base_t *session_, - const std::string &peer_address, - const options_t &options_); - virtual ~gssapi_server_t (); - - // mechanism implementation - virtual int next_handshake_command (msg_t *msg_); - virtual int process_handshake_command (msg_t *msg_); - virtual int encode (msg_t *msg_); - virtual int decode (msg_t *msg_); - virtual int zap_msg_available (); - virtual status_t status () const; + // mechanism implementation + virtual int next_handshake_command (msg_t *msg_); + virtual int process_handshake_command (msg_t *msg_); + virtual int encode (msg_t *msg_); + virtual int decode (msg_t *msg_); + virtual int zap_msg_available (); + virtual status_t status () const; private: diff --git a/src/mechanism.cpp b/src/mechanism.cpp index a7b17d92..39117359 100644 --- a/src/mechanism.cpp +++ b/src/mechanism.cpp @@ -35,6 +35,7 @@ #include "msg.hpp" #include "err.hpp" #include "wire.hpp" +#include "session_base.hpp" zmq::mechanism_t::mechanism_t (const options_t &options_) : options (options_) @@ -281,3 +282,24 @@ bool zmq::mechanism_t::check_socket_type (const std::string& type_) const } return false; } + +zmq::mechanism_base_t::mechanism_base_t (session_base_t *const session_, + const options_t &options_) : + mechanism_t (options_), + session (session_) +{ + +} + +int zmq::mechanism_base_t::check_basic_command_structure (msg_t *msg_) +{ + if (msg_->size () <= 1 || msg_->size () <= ((uint8_t *) msg_->data ())[0]) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED); + errno = EPROTO; + return -1; + } + return 0; +} + diff --git a/src/mechanism.hpp b/src/mechanism.hpp index 8b47d963..e07ce567 100644 --- a/src/mechanism.hpp +++ b/src/mechanism.hpp @@ -38,11 +38,12 @@ namespace zmq { + class msg_t; + class session_base_t; + // Abstract class representing security mechanism. // Different mechanism extends this class. - class msg_t; - class mechanism_t { public: @@ -146,6 +147,16 @@ namespace zmq bool check_socket_type (const std::string& type_) const; }; -} + class mechanism_base_t : public mechanism_t + { + protected: + mechanism_base_t (session_base_t *const session_, + const options_t &options_); + + session_base_t *const session; + + int check_basic_command_structure (msg_t *msg_); + }; + } #endif diff --git a/src/null_mechanism.cpp b/src/null_mechanism.cpp index c765bdee..8f181920 100644 --- a/src/null_mechanism.cpp +++ b/src/null_mechanism.cpp @@ -42,7 +42,7 @@ zmq::null_mechanism_t::null_mechanism_t (session_base_t *session_, const std::string &peer_address_, const options_t &options_) : - mechanism_t (options_), + mechanism_base_t (session_, options_), zap_client_t (session_, peer_address_, options_), ready_command_sent (false), error_command_sent (false), diff --git a/src/plain_server.cpp b/src/plain_server.cpp index f205c1ac..5eaf58aa 100644 --- a/src/plain_server.cpp +++ b/src/plain_server.cpp @@ -40,7 +40,7 @@ zmq::plain_server_t::plain_server_t (session_base_t *session_, const std::string &peer_address_, const options_t &options_) : - mechanism_t (options_), + mechanism_base_t (session_, options_), zap_client_common_handshake_t ( session_, peer_address_, options_, sending_welcome) { diff --git a/src/stream_engine.cpp b/src/stream_engine.cpp index d1731bbe..80f0e487 100644 --- a/src/stream_engine.cpp +++ b/src/stream_engine.cpp @@ -698,7 +698,8 @@ bool zmq::stream_engine_t::handshake () mechanism = new (std::nothrow) gssapi_server_t (session, peer_address, options); else - mechanism = new (std::nothrow) gssapi_client_t (options); + mechanism = + new (std::nothrow) gssapi_client_t (session, options); alloc_assert (mechanism); } #endif diff --git a/src/zap_client.cpp b/src/zap_client.cpp index 9a5b7145..1d820d37 100644 --- a/src/zap_client.cpp +++ b/src/zap_client.cpp @@ -38,8 +38,7 @@ namespace zmq zap_client_t::zap_client_t (session_base_t *const session_, const std::string &peer_address_, const options_t &options_) : - mechanism_t (options_), - session (session_), + mechanism_base_t (session_, options_), peer_address (peer_address_) { } @@ -248,24 +247,12 @@ void zap_client_t::handle_zap_status_code () session->get_endpoint (), status_code_numeric); } -int zap_client_t::check_basic_command_structure (msg_t *msg_) -{ - if (msg_->size () <= 1 || msg_->size () <= ((uint8_t *) msg_->data ())[0]) { - session->get_socket ()->event_handshake_failed_protocol ( - session->get_endpoint (), - ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_UNSPECIFIED); - errno = EPROTO; - return -1; - } - return 0; -} - zap_client_common_handshake_t::zap_client_common_handshake_t ( session_base_t *const session_, const std::string &peer_address_, const options_t &options_, state_t zap_reply_ok_state_) : - mechanism_t (options_), + mechanism_base_t (session_, options_), zap_client_t (session_, peer_address_, options_), state (waiting_for_hello), zap_reply_ok_state (zap_reply_ok_state_) diff --git a/src/zap_client.hpp b/src/zap_client.hpp index cb2837ed..c2652bb8 100644 --- a/src/zap_client.hpp +++ b/src/zap_client.hpp @@ -34,9 +34,7 @@ namespace zmq { -class session_base_t; - -class zap_client_t : public virtual mechanism_t +class zap_client_t : public virtual mechanism_base_t { public: zap_client_t (session_base_t *const session_, @@ -56,10 +54,7 @@ class zap_client_t : public virtual mechanism_t virtual int receive_and_process_zap_reply (); virtual void handle_zap_status_code (); - - int check_basic_command_structure (msg_t *msg_); protected: - session_base_t *const session; const std::string peer_address; // Status code as received from ZAP handler