diff --git a/src/curve_client.cpp b/src/curve_client.cpp index cdaefec1..1dd1dd61 100644 --- a/src/curve_client.cpp +++ b/src/curve_client.cpp @@ -82,30 +82,31 @@ int zmq::curve_client_t::next_handshake_command (msg_t *msg_) int zmq::curve_client_t::process_handshake_command (msg_t *msg_) { - int rc = 0; + const unsigned char *msg_data = + static_cast (msg_->data ()); + const size_t msg_size = msg_->size (); - switch (state) { - case expect_welcome: - rc = process_welcome (msg_); - if (rc == 0) - state = send_initiate; - break; - case expect_ready: - rc = process_ready (msg_); - if (rc == 0) - state = connected; - break; - default: - errno = EPROTO; - rc = -1; - break; + int rc = 0; + if (msg_size >= 8 && !memcmp (msg_data, "\7WELCOME", 8)) + rc = process_welcome (msg_data, msg_size); + else + if (msg_size >= 6 && !memcmp (msg_data, "\5READY", 6)) + rc = process_ready (msg_data, msg_size); + else + if (msg_size >= 6 && !memcmp (msg_data, "\5ERROR", 6)) + rc = process_error (msg_data, msg_size); + else { + errno = EPROTO; + rc = -1; } + if (rc == 0) { rc = msg_->close (); errno_assert (rc == 0); rc = msg_->init (); errno_assert (rc == 0); } + return rc; } @@ -218,7 +219,13 @@ int zmq::curve_client_t::decode (msg_t *msg_) zmq::mechanism_t::status_t zmq::curve_client_t::status () const { - return state == connected? mechanism_t::ready: mechanism_t::handshaking; + if (state == connected) + return mechanism_t::ready; + else + if (state == error_received) + return mechanism_t::error; + else + return mechanism_t::handshaking; } int zmq::curve_client_t::produce_hello (msg_t *msg_) @@ -260,15 +267,10 @@ int zmq::curve_client_t::produce_hello (msg_t *msg_) return 0; } -int zmq::curve_client_t::process_welcome (msg_t *msg_) +int zmq::curve_client_t::process_welcome ( + const uint8_t *msg_data, size_t msg_size) { - if (msg_->size () != 168) { - errno = EPROTO; - return -1; - } - - const uint8_t * welcome = static_cast (msg_->data ()); - if (memcmp (welcome, "\x07WELCOME", 8)) { + if (msg_size != 168) { errno = EPROTO; return -1; } @@ -279,10 +281,10 @@ int zmq::curve_client_t::process_welcome (msg_t *msg_) // Open Box [S' + cookie](C'->S) memset (welcome_box, 0, crypto_box_BOXZEROBYTES); - memcpy (welcome_box + crypto_box_BOXZEROBYTES, welcome + 24, 144); + memcpy (welcome_box + crypto_box_BOXZEROBYTES, msg_data + 24, 144); memcpy (welcome_nonce, "WELCOME-", 8); - memcpy (welcome_nonce + 8, welcome + 8, 16); + memcpy (welcome_nonce + 8, msg_data + 8, 16); int rc = crypto_box_open (welcome_plaintext, welcome_box, sizeof welcome_box, @@ -299,6 +301,8 @@ int zmq::curve_client_t::process_welcome (msg_t *msg_) rc = crypto_box_beforenm (cn_precom, cn_server, cn_secret); zmq_assert (rc == 0); + state = send_initiate; + return 0; } @@ -375,20 +379,15 @@ int zmq::curve_client_t::produce_initiate (msg_t *msg_) return 0; } -int zmq::curve_client_t::process_ready (msg_t *msg_) +int zmq::curve_client_t::process_ready ( + const uint8_t *msg_data, size_t msg_size) { - if (msg_->size () < 30) { + if (msg_size < 30) { errno = EPROTO; return -1; } - const uint8_t *ready = static_cast (msg_->data ()); - if (memcmp (ready, "\x05READY", 6)) { - errno = EPROTO; - return -1; - } - - const size_t clen = (msg_->size () - 14) + crypto_box_BOXZEROBYTES; + const size_t clen = (msg_size - 14) + crypto_box_BOXZEROBYTES; uint8_t ready_nonce [crypto_box_NONCEBYTES]; uint8_t ready_plaintext [crypto_box_ZEROBYTES + 256]; @@ -396,10 +395,10 @@ int zmq::curve_client_t::process_ready (msg_t *msg_) memset (ready_box, 0, crypto_box_BOXZEROBYTES); memcpy (ready_box + crypto_box_BOXZEROBYTES, - ready + 14, clen - crypto_box_BOXZEROBYTES); + msg_data + 14, clen - crypto_box_BOXZEROBYTES); memcpy (ready_nonce, "CurveZMQREADY---", 16); - memcpy (ready_nonce + 16, ready + 6, 8); + memcpy (ready_nonce + 16, msg_data + 6, 8); int rc = crypto_box_open_afternm (ready_plaintext, ready_box, clen, ready_nonce, cn_precom); @@ -411,7 +410,30 @@ int zmq::curve_client_t::process_ready (msg_t *msg_) rc = parse_metadata (ready_plaintext + crypto_box_ZEROBYTES, clen - crypto_box_ZEROBYTES); + if (rc == 0) + state = connected; + return rc; } +int zmq::curve_client_t::process_error ( + const uint8_t *msg_data, size_t msg_size) +{ + if (state != expect_welcome && state != expect_ready) { + errno = EPROTO; + return -1; + } + if (msg_size < 7) { + errno = EPROTO; + return -1; + } + const size_t error_reason_len = static_cast (msg_data [6]); + if (error_reason_len > msg_size - 7) { + errno = EPROTO; + return -1; + } + state = error_received; + return 0; +} + #endif diff --git a/src/curve_client.hpp b/src/curve_client.hpp index 45f47026..14769fda 100644 --- a/src/curve_client.hpp +++ b/src/curve_client.hpp @@ -69,6 +69,7 @@ namespace zmq expect_welcome, send_initiate, expect_ready, + error_received, connected }; @@ -103,9 +104,10 @@ namespace zmq uint64_t cn_nonce; int produce_hello (msg_t *msg_); - int process_welcome (msg_t *msg_); + int process_welcome (const uint8_t *cmd_data, size_t data_size); int produce_initiate (msg_t *msg_); - int process_ready (msg_t *msg_); + int process_ready (const uint8_t *cmd_data, size_t data_size); + int process_error (const uint8_t *cmd_data, size_t data_size); mutex_t sync; };