diff --git a/src/ws_encoder.cpp b/src/ws_encoder.cpp index dfccb8bb..bc0b38c7 100644 --- a/src/ws_encoder.cpp +++ b/src/ws_encoder.cpp @@ -73,6 +73,9 @@ void zmq::ws_encoder_t::message_ready () size_t size = in_progress ()->size (); if (_is_binary) size++; + // TODO: create an opcode for subscribe/cancel + if (in_progress ()->is_subscribe () || in_progress ()->is_cancel ()) + size++; if (size <= 125) _tmp_buf[offset++] |= static_cast (size & 127); @@ -93,6 +96,7 @@ void zmq::ws_encoder_t::message_ready () offset += 4; } + int mask_index = 0; if (_is_binary) { // Encode flags. unsigned char protocol_flags = 0; @@ -102,9 +106,16 @@ void zmq::ws_encoder_t::message_ready () protocol_flags |= ws_protocol_t::command_flag; _tmp_buf[offset++] = - _must_mask ? protocol_flags ^ _mask[0] : protocol_flags; + _must_mask ? protocol_flags ^ _mask[mask_index++] : protocol_flags; } + // Encode the subscribe/cancel byte. + // TODO: remove once there is an opcode for subscribe/cancel + if (in_progress ()->is_subscribe ()) + _tmp_buf[offset++] = _must_mask ? 1 ^ _mask[mask_index++] : 1; + else if (in_progress ()->is_cancel ()) + _tmp_buf[offset++] = _must_mask ? 0 ^ _mask[mask_index++] : 0; + next_step (_tmp_buf, offset, &ws_encoder_t::size_ready, false); } @@ -126,7 +137,12 @@ void zmq::ws_encoder_t::size_ready () dest = static_cast (_masked_msg.data ()); } - int mask_index = _is_binary ? 1 : 0; + int mask_index = 0; + if (_is_binary) + ++mask_index; + // TODO: remove once there is an opcode for subscribe/cancel + if (in_progress ()->is_subscribe () || in_progress ()->is_cancel ()) + ++mask_index; for (size_t i = 0; i < size; ++i, mask_index++) dest[i] = src[i] ^ _mask[mask_index % 4]; diff --git a/tests/test_ws_transport.cpp b/tests/test_ws_transport.cpp index dbb6aed9..49b810fe 100644 --- a/tests/test_ws_transport.cpp +++ b/tests/test_ws_transport.cpp @@ -271,6 +271,33 @@ void test_mask_shared_msg () test_context_socket_close (sb); } +void test_pub_sub () +{ + char connect_address[MAX_SOCKET_STRING]; + size_t addr_length = sizeof (connect_address); + void *sb = test_context_socket (ZMQ_XPUB); + TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (sb, "ws://127.0.0.1:*")); + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_getsockopt (sb, ZMQ_LAST_ENDPOINT, connect_address, &addr_length)); + + void *sc = test_context_socket (ZMQ_SUB); + TEST_ASSERT_SUCCESS_ERRNO (zmq_setsockopt (sc, ZMQ_SUBSCRIBE, "A", 1)); + TEST_ASSERT_SUCCESS_ERRNO (zmq_setsockopt (sc, ZMQ_SUBSCRIBE, "B", 1)); + TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (sc, connect_address)); + + recv_string_expect_success (sb, "\1A", 0); + recv_string_expect_success (sb, "\1B", 0); + + send_string_expect_success (sb, "A", 0); + send_string_expect_success (sb, "B", 0); + + recv_string_expect_success (sc, "A", 0); + recv_string_expect_success (sc, "B", 0); + + test_context_socket_close (sc); + test_context_socket_close (sb); +} + int main () { @@ -283,6 +310,7 @@ int main () RUN_TEST (test_large_message); RUN_TEST (test_heartbeat); RUN_TEST (test_mask_shared_msg); + RUN_TEST (test_pub_sub); if (zmq_has ("curve")) RUN_TEST (test_curve);