Problem: ZMQ_PUB broken on ZMQ_WS

Solution: encode subscribe/cancel messages until there are appropriate
opcodes.
Regression introduced by 253e9dd27b

Fixes https://github.com/zeromq/libzmq/issues/4101
This commit is contained in:
Luca Boccassi 2020-12-23 13:02:14 +00:00
parent a49aa0d294
commit 41c4ce1817
2 changed files with 46 additions and 2 deletions

View File

@ -73,6 +73,9 @@ void zmq::ws_encoder_t::message_ready ()
size_t size = in_progress ()->size ();
if (_is_binary)
size++;
// TODO: create an opcode for subscribe/cancel
if (in_progress ()->is_subscribe () || in_progress ()->is_cancel ())
size++;
if (size <= 125)
_tmp_buf[offset++] |= static_cast<unsigned char> (size & 127);
@ -93,6 +96,7 @@ void zmq::ws_encoder_t::message_ready ()
offset += 4;
}
int mask_index = 0;
if (_is_binary) {
// Encode flags.
unsigned char protocol_flags = 0;
@ -102,9 +106,16 @@ void zmq::ws_encoder_t::message_ready ()
protocol_flags |= ws_protocol_t::command_flag;
_tmp_buf[offset++] =
_must_mask ? protocol_flags ^ _mask[0] : protocol_flags;
_must_mask ? protocol_flags ^ _mask[mask_index++] : protocol_flags;
}
// Encode the subscribe/cancel byte.
// TODO: remove once there is an opcode for subscribe/cancel
if (in_progress ()->is_subscribe ())
_tmp_buf[offset++] = _must_mask ? 1 ^ _mask[mask_index++] : 1;
else if (in_progress ()->is_cancel ())
_tmp_buf[offset++] = _must_mask ? 0 ^ _mask[mask_index++] : 0;
next_step (_tmp_buf, offset, &ws_encoder_t::size_ready, false);
}
@ -126,7 +137,12 @@ void zmq::ws_encoder_t::size_ready ()
dest = static_cast<unsigned char *> (_masked_msg.data ());
}
int mask_index = _is_binary ? 1 : 0;
int mask_index = 0;
if (_is_binary)
++mask_index;
// TODO: remove once there is an opcode for subscribe/cancel
if (in_progress ()->is_subscribe () || in_progress ()->is_cancel ())
++mask_index;
for (size_t i = 0; i < size; ++i, mask_index++)
dest[i] = src[i] ^ _mask[mask_index % 4];

View File

@ -271,6 +271,33 @@ void test_mask_shared_msg ()
test_context_socket_close (sb);
}
void test_pub_sub ()
{
char connect_address[MAX_SOCKET_STRING];
size_t addr_length = sizeof (connect_address);
void *sb = test_context_socket (ZMQ_XPUB);
TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (sb, "ws://127.0.0.1:*"));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_getsockopt (sb, ZMQ_LAST_ENDPOINT, connect_address, &addr_length));
void *sc = test_context_socket (ZMQ_SUB);
TEST_ASSERT_SUCCESS_ERRNO (zmq_setsockopt (sc, ZMQ_SUBSCRIBE, "A", 1));
TEST_ASSERT_SUCCESS_ERRNO (zmq_setsockopt (sc, ZMQ_SUBSCRIBE, "B", 1));
TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (sc, connect_address));
recv_string_expect_success (sb, "\1A", 0);
recv_string_expect_success (sb, "\1B", 0);
send_string_expect_success (sb, "A", 0);
send_string_expect_success (sb, "B", 0);
recv_string_expect_success (sc, "A", 0);
recv_string_expect_success (sc, "B", 0);
test_context_socket_close (sc);
test_context_socket_close (sb);
}
int main ()
{
@ -283,6 +310,7 @@ int main ()
RUN_TEST (test_large_message);
RUN_TEST (test_heartbeat);
RUN_TEST (test_mask_shared_msg);
RUN_TEST (test_pub_sub);
if (zmq_has ("curve"))
RUN_TEST (test_curve);