CURVE: Implement client-side ERROR handling

This commit is contained in:
Martin Hurton 2014-05-15 06:38:17 +02:00
parent 0975be6ed7
commit 0750303bfe
2 changed files with 64 additions and 40 deletions

View File

@ -82,30 +82,31 @@ int zmq::curve_client_t::next_handshake_command (msg_t *msg_)
int zmq::curve_client_t::process_handshake_command (msg_t *msg_)
{
int rc = 0;
const unsigned char *msg_data =
static_cast <unsigned char *> (msg_->data ());
const size_t msg_size = msg_->size ();
switch (state) {
case expect_welcome:
rc = process_welcome (msg_);
if (rc == 0)
state = send_initiate;
break;
case expect_ready:
rc = process_ready (msg_);
if (rc == 0)
state = connected;
break;
default:
int rc = 0;
if (msg_size >= 8 && !memcmp (msg_data, "\7WELCOME", 8))
rc = process_welcome (msg_data, msg_size);
else
if (msg_size >= 6 && !memcmp (msg_data, "\5READY", 6))
rc = process_ready (msg_data, msg_size);
else
if (msg_size >= 6 && !memcmp (msg_data, "\5ERROR", 6))
rc = process_error (msg_data, msg_size);
else {
errno = EPROTO;
rc = -1;
break;
}
if (rc == 0) {
rc = msg_->close ();
errno_assert (rc == 0);
rc = msg_->init ();
errno_assert (rc == 0);
}
return rc;
}
@ -218,7 +219,13 @@ int zmq::curve_client_t::decode (msg_t *msg_)
zmq::mechanism_t::status_t zmq::curve_client_t::status () const
{
return state == connected? mechanism_t::ready: mechanism_t::handshaking;
if (state == connected)
return mechanism_t::ready;
else
if (state == error_received)
return mechanism_t::error;
else
return mechanism_t::handshaking;
}
int zmq::curve_client_t::produce_hello (msg_t *msg_)
@ -260,15 +267,10 @@ int zmq::curve_client_t::produce_hello (msg_t *msg_)
return 0;
}
int zmq::curve_client_t::process_welcome (msg_t *msg_)
int zmq::curve_client_t::process_welcome (
const uint8_t *msg_data, size_t msg_size)
{
if (msg_->size () != 168) {
errno = EPROTO;
return -1;
}
const uint8_t * welcome = static_cast <uint8_t *> (msg_->data ());
if (memcmp (welcome, "\x07WELCOME", 8)) {
if (msg_size != 168) {
errno = EPROTO;
return -1;
}
@ -279,10 +281,10 @@ int zmq::curve_client_t::process_welcome (msg_t *msg_)
// Open Box [S' + cookie](C'->S)
memset (welcome_box, 0, crypto_box_BOXZEROBYTES);
memcpy (welcome_box + crypto_box_BOXZEROBYTES, welcome + 24, 144);
memcpy (welcome_box + crypto_box_BOXZEROBYTES, msg_data + 24, 144);
memcpy (welcome_nonce, "WELCOME-", 8);
memcpy (welcome_nonce + 8, welcome + 8, 16);
memcpy (welcome_nonce + 8, msg_data + 8, 16);
int rc = crypto_box_open (welcome_plaintext, welcome_box,
sizeof welcome_box,
@ -299,6 +301,8 @@ int zmq::curve_client_t::process_welcome (msg_t *msg_)
rc = crypto_box_beforenm (cn_precom, cn_server, cn_secret);
zmq_assert (rc == 0);
state = send_initiate;
return 0;
}
@ -375,20 +379,15 @@ int zmq::curve_client_t::produce_initiate (msg_t *msg_)
return 0;
}
int zmq::curve_client_t::process_ready (msg_t *msg_)
int zmq::curve_client_t::process_ready (
const uint8_t *msg_data, size_t msg_size)
{
if (msg_->size () < 30) {
if (msg_size < 30) {
errno = EPROTO;
return -1;
}
const uint8_t *ready = static_cast <uint8_t *> (msg_->data ());
if (memcmp (ready, "\x05READY", 6)) {
errno = EPROTO;
return -1;
}
const size_t clen = (msg_->size () - 14) + crypto_box_BOXZEROBYTES;
const size_t clen = (msg_size - 14) + crypto_box_BOXZEROBYTES;
uint8_t ready_nonce [crypto_box_NONCEBYTES];
uint8_t ready_plaintext [crypto_box_ZEROBYTES + 256];
@ -396,10 +395,10 @@ int zmq::curve_client_t::process_ready (msg_t *msg_)
memset (ready_box, 0, crypto_box_BOXZEROBYTES);
memcpy (ready_box + crypto_box_BOXZEROBYTES,
ready + 14, clen - crypto_box_BOXZEROBYTES);
msg_data + 14, clen - crypto_box_BOXZEROBYTES);
memcpy (ready_nonce, "CurveZMQREADY---", 16);
memcpy (ready_nonce + 16, ready + 6, 8);
memcpy (ready_nonce + 16, msg_data + 6, 8);
int rc = crypto_box_open_afternm (ready_plaintext, ready_box,
clen, ready_nonce, cn_precom);
@ -411,7 +410,30 @@ int zmq::curve_client_t::process_ready (msg_t *msg_)
rc = parse_metadata (ready_plaintext + crypto_box_ZEROBYTES,
clen - crypto_box_ZEROBYTES);
if (rc == 0)
state = connected;
return rc;
}
int zmq::curve_client_t::process_error (
const uint8_t *msg_data, size_t msg_size)
{
if (state != expect_welcome && state != expect_ready) {
errno = EPROTO;
return -1;
}
if (msg_size < 7) {
errno = EPROTO;
return -1;
}
const size_t error_reason_len = static_cast <size_t> (msg_data [6]);
if (error_reason_len > msg_size - 7) {
errno = EPROTO;
return -1;
}
state = error_received;
return 0;
}
#endif

View File

@ -69,6 +69,7 @@ namespace zmq
expect_welcome,
send_initiate,
expect_ready,
error_received,
connected
};
@ -103,9 +104,10 @@ namespace zmq
uint64_t cn_nonce;
int produce_hello (msg_t *msg_);
int process_welcome (msg_t *msg_);
int process_welcome (const uint8_t *cmd_data, size_t data_size);
int produce_initiate (msg_t *msg_);
int process_ready (msg_t *msg_);
int process_ready (const uint8_t *cmd_data, size_t data_size);
int process_error (const uint8_t *cmd_data, size_t data_size);
mutex_t sync;
};