diff --git a/src/stream_engine_base.hpp b/src/stream_engine_base.hpp index 359bd04a..a063ece1 100644 --- a/src/stream_engine_base.hpp +++ b/src/stream_engine_base.hpp @@ -87,7 +87,8 @@ class stream_engine_base_t : public io_object_t, public i_engine int push_msg_to_session (msg_t *msg_); int pull_and_encode (msg_t *msg_); - int decode_and_push (msg_t *msg_); + virtual int decode_and_push (msg_t *msg_); + int push_one_then_decode_and_push (msg_t *msg_); void set_handshake_timer (); @@ -165,7 +166,6 @@ class stream_engine_base_t : public io_object_t, public i_engine void unplug (); int write_credential (msg_t *msg_); - int push_one_then_decode_and_push (msg_t *msg_); void mechanism_ready (); diff --git a/src/ws_decoder.cpp b/src/ws_decoder.cpp index 91837263..89bb7d66 100644 --- a/src/ws_decoder.cpp +++ b/src/ws_decoder.cpp @@ -80,10 +80,10 @@ int zmq::ws_decoder_t::opcode_ready (unsigned char const *) _msg_flags = msg_t::command; // TODO: set the command name to CLOSE break; case zmq::ws_protocol_t::opcode_ping: - _msg_flags = msg_t::ping; + _msg_flags = msg_t::ping | msg_t::command; break; case zmq::ws_protocol_t::opcode_pong: - _msg_flags = msg_t::pong; + _msg_flags = msg_t::pong | msg_t::command; break; default: return -1; diff --git a/src/ws_encoder.cpp b/src/ws_encoder.cpp index 5a6f8852..c8564d06 100644 --- a/src/ws_encoder.cpp +++ b/src/ws_encoder.cpp @@ -55,8 +55,13 @@ void zmq::ws_encoder_t::message_ready () { int offset = 0; - // TODO: it might be close/ping/pong, which should be different op code - _tmp_buf[offset++] = 0x82; // Final | binary + if (in_progress ()->is_ping ()) + _tmp_buf[offset++] = 0x80 | zmq::ws_protocol_t::opcode_ping; + else if (in_progress ()->is_pong ()) + _tmp_buf[offset++] = 0x80 | zmq::ws_protocol_t::opcode_pong; + else + _tmp_buf[offset++] = 0x82; // Final | binary + _tmp_buf[offset] = _must_mask ? 0x80 : 0x00; size_t size = in_progress ()->size (); diff --git a/src/ws_engine.cpp b/src/ws_engine.cpp index a8e6e074..884d7217 100644 --- a/src/ws_engine.cpp +++ b/src/ws_engine.cpp @@ -122,7 +122,8 @@ zmq::ws_engine_t::ws_engine_t (fd_t fd_, _header_name_position (0), _header_value_position (0), _header_upgrade_websocket (false), - _header_connection_upgrade (false) + _header_connection_upgrade (false), + _heartbeat_timeout (0) { memset (_websocket_key, 0, MAX_HEADER_VALUE_LENGTH + 1); memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1); @@ -130,6 +131,12 @@ zmq::ws_engine_t::ws_engine_t (fd_t fd_, _next_msg = &ws_engine_t::next_handshake_command; _process_msg = &ws_engine_t::process_handshake_command; + + if (_options.heartbeat_interval > 0) { + _heartbeat_timeout = _options.heartbeat_timeout; + if (_heartbeat_timeout == -1) + _heartbeat_timeout = _options.heartbeat_interval; + } } zmq::ws_engine_t::~ws_engine_t () @@ -227,6 +234,13 @@ bool zmq::ws_engine_t::select_protocol (char *protocol_) &ws_engine_t::routing_id_msg); _process_msg = static_cast ( &ws_engine_t::process_routing_id_msg); + + // No mechanism in place, enabling heartbeat + if (_options.heartbeat_interval > 0 && !_has_heartbeat_timer) { + add_timer (_options.heartbeat_interval, heartbeat_ivl_timer_id); + _has_heartbeat_timer = true; + } + return true; } if (_options.mechanism == ZMQ_NULL @@ -902,6 +916,75 @@ bool zmq::ws_engine_t::client_handshake () return false; } +int zmq::ws_engine_t::decode_and_push (msg_t *msg_) +{ + zmq_assert (_mechanism != NULL); + + // with WS engine, ping and pong commands are control messages and should not go through any mechanism + if (msg_->is_ping () || msg_->is_pong ()) { + if (process_command_message (msg_) == -1) + return -1; + } else if (_mechanism->decode (msg_) == -1) + return -1; + + if (_has_timeout_timer) { + _has_timeout_timer = false; + cancel_timer (heartbeat_timeout_timer_id); + } + + if (msg_->flags () & msg_t::command && !msg_->is_ping () + && !msg_->is_pong ()) + process_command_message (msg_); + + if (_metadata) + msg_->set_metadata (_metadata); + if (session ()->push_msg (msg_) == -1) { + if (errno == EAGAIN) + _process_msg = &ws_engine_t::push_one_then_decode_and_push; + return -1; + } + return 0; +} + + +int zmq::ws_engine_t::produce_ping_message (msg_t *msg_) +{ + int rc = msg_->init (); + errno_assert (rc == 0); + msg_->set_flags (msg_t::command | msg_t::ping); + + _next_msg = &ws_engine_t::pull_and_encode; + if (!_has_timeout_timer && _heartbeat_timeout > 0) { + add_timer (_heartbeat_timeout, heartbeat_timeout_timer_id); + _has_timeout_timer = true; + } + + return rc; +} + + +int zmq::ws_engine_t::produce_pong_message (msg_t *msg_) +{ + int rc = msg_->init (); + errno_assert (rc == 0); + msg_->set_flags (msg_t::command | msg_t::pong); + + _next_msg = &ws_engine_t::pull_and_encode; + return rc; +} + + +int zmq::ws_engine_t::process_command_message (msg_t *msg_) +{ + if (msg_->is_ping ()) { + _next_msg = static_cast ( + &ws_engine_t::produce_pong_message); + out_event (); + } + + return 0; +} + static int encode_base64 (const unsigned char *in_, int in_len_, char *out_, int out_len_) { diff --git a/src/ws_engine.hpp b/src/ws_engine.hpp index 12d5e612..538faf9e 100644 --- a/src/ws_engine.hpp +++ b/src/ws_engine.hpp @@ -135,6 +135,10 @@ class ws_engine_t : public stream_engine_base_t ~ws_engine_t (); protected: + int decode_and_push (msg_t *msg_); + int process_command_message (msg_t *msg_); + int produce_pong_message (msg_t *msg_); + int produce_ping_message (msg_t *msg_); bool handshake (); void plug_internal (); void start_ws_handshake (); @@ -166,6 +170,8 @@ class ws_engine_t : public stream_engine_base_t char _websocket_protocol[256]; char _websocket_key[MAX_HEADER_VALUE_LENGTH + 1]; char _websocket_accept[MAX_HEADER_VALUE_LENGTH + 1]; + + int _heartbeat_timeout; }; } diff --git a/tests/test_ws_transport.cpp b/tests/test_ws_transport.cpp index 30c285a7..f1bc0452 100644 --- a/tests/test_ws_transport.cpp +++ b/tests/test_ws_transport.cpp @@ -52,6 +52,40 @@ void test_roundtrip () test_context_socket_close (sb); } +void test_heartbeat () +{ + char connect_address[MAX_SOCKET_STRING + strlen ("/heartbeat")]; + size_t addr_length = sizeof (connect_address); + void *sb = test_context_socket (ZMQ_REP); + TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (sb, "ws://*:*/heartbeat")); + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_getsockopt (sb, ZMQ_LAST_ENDPOINT, connect_address, &addr_length)); + strcat (connect_address, "/heartbeat"); + + void *sc = test_context_socket (ZMQ_REQ); + + // Setting heartbeat settings + int ivl = 10; + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_setsockopt (sc, ZMQ_HEARTBEAT_IVL, &ivl, sizeof (ivl))); + + // Disable reconnect, to make sure the ping-pong actually work + ivl = -1; + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_setsockopt (sc, ZMQ_RECONNECT_IVL, &ivl, sizeof (ivl))); + + // Connect to server + TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (sc, connect_address)); + + // Make sure some ping and pong going through + msleep (100); + + bounce (sb, sc); + + test_context_socket_close (sc); + test_context_socket_close (sb); +} + void test_short_message () { char connect_address[MAX_SOCKET_STRING + strlen ("/short")]; @@ -170,6 +204,7 @@ int main () RUN_TEST (test_roundtrip); RUN_TEST (test_short_message); RUN_TEST (test_large_message); + RUN_TEST (test_heartbeat); if (zmq_has ("curve")) RUN_TEST (test_curve);