diff --git a/src/gssapi_client.cpp b/src/gssapi_client.cpp index 942c25fa..68e33d9b 100644 --- a/src/gssapi_client.cpp +++ b/src/gssapi_client.cpp @@ -125,6 +125,9 @@ int zmq::gssapi_client_t::process_handshake_command (msg_t *msg_) } if (state != recv_next_token) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND); errno = EPROTO; return -1; } diff --git a/src/gssapi_mechanism_base.cpp b/src/gssapi_mechanism_base.cpp index 65d9b124..43d4f79b 100644 --- a/src/gssapi_mechanism_base.cpp +++ b/src/gssapi_mechanism_base.cpp @@ -128,8 +128,15 @@ int zmq::gssapi_mechanism_base_t::decode_message (msg_t *msg_) const uint8_t *ptr = static_cast (msg_->data ()); size_t bytes_left = msg_->size (); + int rc = check_basic_command_structure (msg_); + if (rc == -1) + return rc; + // Get command string if (bytes_left < 8 || memcmp (ptr, "\x07MESSAGE", 8)) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND); errno = EPROTO; return -1; } @@ -138,6 +145,9 @@ int zmq::gssapi_mechanism_base_t::decode_message (msg_t *msg_) // Get token length if (bytes_left < 4) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE); errno = EPROTO; return -1; } @@ -148,6 +158,9 @@ int zmq::gssapi_mechanism_base_t::decode_message (msg_t *msg_) // Get token value if (bytes_left < wrapped.length) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE); errno = EPROTO; return -1; } @@ -168,11 +181,16 @@ int zmq::gssapi_mechanism_base_t::decode_message (msg_t *msg_) maj_stat = gss_unwrap(&min_stat, context, &wrapped, &plaintext, &state, (gss_qop_t *) NULL); + // TODO I don't think it is a good idea to use zmq_assert here. If + // decryption fails, gss_unwrap returns GSS_S_BAD_SIG. This opens up + // to DoS attacks by clients! Instead, a + // ZMQ_PROTOCOL_ERROR_ZMTP_CRYPTOGRAPHIC event should be emitted. + zmq_assert(maj_stat == GSS_S_COMPLETE); zmq_assert(state); // Re-initialize msg_ for plaintext - int rc = msg_->close (); + rc = msg_->close (); zmq_assert (rc == 0); rc = msg_->init_size (plaintext.length-1); @@ -190,6 +208,9 @@ int zmq::gssapi_mechanism_base_t::decode_message (msg_t *msg_) free(wrapped.value); if (bytes_left > 0) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_MESSAGE); errno = EPROTO; return -1; } @@ -231,8 +252,15 @@ int zmq::gssapi_mechanism_base_t::process_initiate (msg_t *msg_, void **token_va const uint8_t *ptr = static_cast (msg_->data ()); size_t bytes_left = msg_->size (); + int rc = check_basic_command_structure (msg_); + if (rc == -1) + return rc; + // Get command string if (bytes_left < 9 || memcmp (ptr, "\x08INITIATE", 9)) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND); errno = EPROTO; return -1; } @@ -241,6 +269,9 @@ int zmq::gssapi_mechanism_base_t::process_initiate (msg_t *msg_, void **token_va // Get token length if (bytes_left < 4) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE); errno = EPROTO; return -1; } @@ -250,6 +281,9 @@ int zmq::gssapi_mechanism_base_t::process_initiate (msg_t *msg_, void **token_va // Get token value if (bytes_left < token_length_) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE); errno = EPROTO; return -1; } @@ -264,6 +298,9 @@ int zmq::gssapi_mechanism_base_t::process_initiate (msg_t *msg_, void **token_va } if (bytes_left > 0) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_MALFORMED_COMMAND_INITIATE); errno = EPROTO; return -1; } @@ -317,14 +354,28 @@ int zmq::gssapi_mechanism_base_t::process_ready (msg_t *msg_) const unsigned char *ptr = static_cast (msg_->data ()); size_t bytes_left = msg_->size (); + int rc = check_basic_command_structure (msg_); + if (rc == -1) + return rc; + if (bytes_left < 6 || memcmp (ptr, "\x05READY", 6)) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND); errno = EPROTO; return -1; } ptr += 6; bytes_left -= 6; - return parse_metadata (ptr, bytes_left); + rc = parse_metadata (ptr, bytes_left); + if (rc == -1) + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_INVALID_METADATA); + + return rc; } + const gss_OID zmq::gssapi_mechanism_base_t::convert_nametype (int zmq_nametype) { switch (zmq_nametype) { diff --git a/src/gssapi_server.cpp b/src/gssapi_server.cpp index 936c4499..dbb065d7 100644 --- a/src/gssapi_server.cpp +++ b/src/gssapi_server.cpp @@ -114,6 +114,9 @@ int zmq::gssapi_server_t::process_handshake_command (msg_t *msg_) } if (state != recv_next_token) { + session->get_socket ()->event_handshake_failed_protocol ( + session->get_endpoint (), + ZMQ_PROTOCOL_ERROR_ZMTP_UNEXPECTED_COMMAND); errno = EPROTO; return -1; }