diff --git a/include/zmq.h b/include/zmq.h index 8c39fd07..c3c2d0c8 100644 --- a/include/zmq.h +++ b/include/zmq.h @@ -397,7 +397,7 @@ ZMQ_EXPORT const char *zmq_msg_gets (const zmq_msg_t *msg_, #define ZMQ_GSSAPI 3 /* RADIO-DISH protocol */ -#define ZMQ_GROUP_MAX_LENGTH 15 +#define ZMQ_GROUP_MAX_LENGTH 255 /* Deprecated options and aliases */ #define ZMQ_IDENTITY ZMQ_ROUTING_ID diff --git a/src/dish.cpp b/src/dish.cpp index f3aca49c..d8a4befc 100644 --- a/src/dish.cpp +++ b/src/dish.cpp @@ -233,7 +233,6 @@ void zmq::dish_t::send_subscriptions (pipe_t *pipe_) // Send it to the pipe. pipe_->write (&msg); - msg.close (); } pipe_->flush (); diff --git a/src/dist.cpp b/src/dist.cpp index 6ac27045..7795bb36 100644 --- a/src/dist.cpp +++ b/src/dist.cpp @@ -178,9 +178,7 @@ void zmq::dist_t::distribute (msg_t *msg_) ++i; } } - int rc = msg_->close (); - errno_assert (rc == 0); - rc = msg_->init (); + int rc = msg_->init (); errno_assert (rc == 0); return; } diff --git a/src/msg.cpp b/src/msg.cpp index 612a18b2..0871fbd2 100644 --- a/src/msg.cpp +++ b/src/msg.cpp @@ -79,7 +79,8 @@ int zmq::msg_t::init () _u.vsm.type = type_vsm; _u.vsm.flags = 0; _u.vsm.size = 0; - _u.vsm.group[0] = '\0'; + _u.vsm.group.sgroup.group[0] = '\0'; + _u.vsm.group.type = group_type_short; _u.vsm.routing_id = 0; return 0; } @@ -91,13 +92,15 @@ int zmq::msg_t::init_size (size_t size_) _u.vsm.type = type_vsm; _u.vsm.flags = 0; _u.vsm.size = static_cast (size_); - _u.vsm.group[0] = '\0'; + _u.vsm.group.sgroup.group[0] = '\0'; + _u.vsm.group.type = group_type_short; _u.vsm.routing_id = 0; } else { _u.lmsg.metadata = NULL; _u.lmsg.type = type_lmsg; _u.lmsg.flags = 0; - _u.lmsg.group[0] = '\0'; + _u.lmsg.group.sgroup.group[0] = '\0'; + _u.lmsg.group.type = group_type_short; _u.lmsg.routing_id = 0; _u.lmsg.content = NULL; if (sizeof (content_t) + size_ > size_) @@ -129,7 +132,8 @@ int zmq::msg_t::init_external_storage (content_t *content_, _u.zclmsg.metadata = NULL; _u.zclmsg.type = type_zclmsg; _u.zclmsg.flags = 0; - _u.zclmsg.group[0] = '\0'; + _u.zclmsg.group.sgroup.group[0] = '\0'; + _u.zclmsg.group.type = group_type_short; _u.zclmsg.routing_id = 0; _u.zclmsg.content = content_; @@ -158,13 +162,15 @@ int zmq::msg_t::init_data (void *data_, _u.cmsg.flags = 0; _u.cmsg.data = data_; _u.cmsg.size = size_; - _u.cmsg.group[0] = '\0'; + _u.cmsg.group.sgroup.group[0] = '\0'; + _u.cmsg.group.type = group_type_short; _u.cmsg.routing_id = 0; } else { _u.lmsg.metadata = NULL; _u.lmsg.type = type_lmsg; _u.lmsg.flags = 0; - _u.lmsg.group[0] = '\0'; + _u.lmsg.group.sgroup.group[0] = '\0'; + _u.lmsg.group.type = group_type_short; _u.lmsg.routing_id = 0; _u.lmsg.content = static_cast (malloc (sizeof (content_t))); @@ -187,7 +193,8 @@ int zmq::msg_t::init_delimiter () _u.delimiter.metadata = NULL; _u.delimiter.type = type_delimiter; _u.delimiter.flags = 0; - _u.delimiter.group[0] = '\0'; + _u.delimiter.group.sgroup.group[0] = '\0'; + _u.delimiter.group.type = group_type_short; _u.delimiter.routing_id = 0; return 0; } @@ -197,7 +204,8 @@ int zmq::msg_t::init_join () _u.base.metadata = NULL; _u.base.type = type_join; _u.base.flags = 0; - _u.base.group[0] = '\0'; + _u.base.group.sgroup.group[0] = '\0'; + _u.base.group.type = group_type_short; _u.base.routing_id = 0; return 0; } @@ -207,7 +215,8 @@ int zmq::msg_t::init_leave () _u.base.metadata = NULL; _u.base.type = type_leave; _u.base.flags = 0; - _u.base.group[0] = '\0'; + _u.base.group.sgroup.group[0] = '\0'; + _u.base.group.type = group_type_short; _u.base.routing_id = 0; return 0; } @@ -259,6 +268,16 @@ int zmq::msg_t::close () _u.base.metadata = NULL; } + if (_u.base.group.type == group_type_long) { + if (!_u.base.group.lgroup.content->refcnt.sub (1)) { + // We used "placement new" operator to initialize the reference + // counter so we call the destructor explicitly now. + _u.base.group.lgroup.content->refcnt.~atomic_counter_t (); + + free (_u.base.group.lgroup.content); + } + } + // Make the message invalid. _u.base.type = 0; @@ -316,6 +335,9 @@ int zmq::msg_t::copy (msg_t &src_) if (src_._u.base.metadata != NULL) src_._u.base.metadata->add_ref (); + if (src_._u.base.group.type == group_type_long) + src_._u.base.group.lgroup.content->refcnt.add (1); + *this = src_; return 0; @@ -600,12 +622,16 @@ int zmq::msg_t::reset_routing_id () const char *zmq::msg_t::group () const { - return _u.base.group; + if (_u.base.group.type == group_type_long) + return _u.base.group.lgroup.content->group; + return _u.base.group.sgroup.group; } int zmq::msg_t::set_group (const char *group_) { - return set_group (group_, ZMQ_GROUP_MAX_LENGTH); + size_t length = strnlen (group_, ZMQ_GROUP_MAX_LENGTH); + + return set_group (group_, length); } int zmq::msg_t::set_group (const char *group_, size_t length_) @@ -615,8 +641,19 @@ int zmq::msg_t::set_group (const char *group_, size_t length_) return -1; } - strncpy (_u.base.group, group_, length_); - _u.base.group[length_] = '\0'; + if (length_ > 14) { + _u.base.group.lgroup.type = group_type_long; + _u.base.group.lgroup.content = + (long_group_t *) malloc (sizeof (long_group_t)); + assert (_u.base.group.lgroup.content); + new (&_u.base.group.lgroup.content->refcnt) zmq::atomic_counter_t (); + _u.base.group.lgroup.content->refcnt.set (1); + strncpy (_u.base.group.lgroup.content->group, group_, length_); + _u.base.group.lgroup.content->group[length_] = '\0'; + } else { + strncpy (_u.base.group.sgroup.group, group_, length_); + _u.base.group.sgroup.group[length_] = '\0'; + } return 0; } diff --git a/src/msg.hpp b/src/msg.hpp index c9c8ec56..46447315 100644 --- a/src/msg.hpp +++ b/src/msg.hpp @@ -46,7 +46,7 @@ // Note that it has to be declared as "C" so that it is the same as // zmq_free_fn defined in zmq.h. extern "C" { -typedef void(msg_free_fn) (void *data_, void *hint_); +typedef void (msg_free_fn) (void *data_, void *hint_); } namespace zmq @@ -210,6 +210,33 @@ class msg_t type_max = 107 }; + enum group_type_t + { + group_type_short, + group_type_long + }; + + struct long_group_t + { + char group[ZMQ_GROUP_MAX_LENGTH + 1]; + atomic_counter_t refcnt; + }; + + union group_t + { + unsigned char type; + struct + { + unsigned char type; + char group[15]; + } sgroup; + struct + { + unsigned char type; + long_group_t *content; + } lgroup; + }; + // Note that fields shared between different message types are not // moved to the parent class (msg_t). This way we get tighter packing // of the data. Shared fields can be accessed via 'base' member of @@ -219,13 +246,13 @@ class msg_t struct { metadata_t *metadata; - unsigned char - unused[msg_t_size - - (sizeof (metadata_t *) + 2 + 16 + sizeof (uint32_t))]; + unsigned char unused[msg_t_size + - (sizeof (metadata_t *) + 2 + + sizeof (uint32_t) + sizeof (group_t))]; unsigned char type; unsigned char flags; - char group[16]; uint32_t routing_id; + group_t group; } base; struct { @@ -234,57 +261,59 @@ class msg_t unsigned char size; unsigned char type; unsigned char flags; - char group[16]; uint32_t routing_id; + group_t group; } vsm; struct { metadata_t *metadata; content_t *content; - unsigned char unused[msg_t_size - - (sizeof (metadata_t *) + sizeof (content_t *) - + 2 + 16 + sizeof (uint32_t))]; + unsigned char + unused[msg_t_size + - (sizeof (metadata_t *) + sizeof (content_t *) + 2 + + sizeof (uint32_t) + sizeof (group_t))]; unsigned char type; unsigned char flags; - char group[16]; uint32_t routing_id; + group_t group; } lmsg; struct { metadata_t *metadata; content_t *content; - unsigned char unused[msg_t_size - - (sizeof (metadata_t *) + sizeof (content_t *) - + 2 + 16 + sizeof (uint32_t))]; + unsigned char + unused[msg_t_size + - (sizeof (metadata_t *) + sizeof (content_t *) + 2 + + sizeof (uint32_t) + sizeof (group_t))]; unsigned char type; unsigned char flags; - char group[16]; uint32_t routing_id; + group_t group; } zclmsg; struct { metadata_t *metadata; void *data; size_t size; - unsigned char - unused[msg_t_size - - (sizeof (metadata_t *) + sizeof (void *) - + sizeof (size_t) + 2 + 16 + sizeof (uint32_t))]; + unsigned char unused[msg_t_size + - (sizeof (metadata_t *) + sizeof (void *) + + sizeof (size_t) + 2 + sizeof (uint32_t) + + sizeof (group_t))]; unsigned char type; unsigned char flags; - char group[16]; uint32_t routing_id; + group_t group; } cmsg; struct { metadata_t *metadata; - unsigned char - unused[msg_t_size - - (sizeof (metadata_t *) + 2 + 16 + sizeof (uint32_t))]; + unsigned char unused[msg_t_size + - (sizeof (metadata_t *) + 2 + + sizeof (uint32_t) + sizeof (group_t))]; unsigned char type; unsigned char flags; - char group[16]; uint32_t routing_id; + group_t group; } delimiter; } _u; }; diff --git a/tests/test_radio_dish.cpp b/tests/test_radio_dish.cpp index 54cb0494..3c3fb18e 100644 --- a/tests/test_radio_dish.cpp +++ b/tests/test_radio_dish.cpp @@ -96,6 +96,35 @@ void test_leave_unjoined_fails () test_context_socket_close (dish); } +void test_long_group () +{ + size_t len = MAX_SOCKET_STRING; + char my_endpoint[MAX_SOCKET_STRING]; + + void *radio = test_context_socket (ZMQ_RADIO); + bind_loopback (radio, false, my_endpoint, len); + + void *dish = test_context_socket (ZMQ_DISH); + + // Joining to a long group, over 14 chars + char group[19] = "0123456789ABCDEFGH"; + TEST_ASSERT_SUCCESS_ERRNO (zmq_join (dish, group)); + + // Connecting + TEST_ASSERT_SUCCESS_ERRNO (zmq_connect (dish, my_endpoint)); + + msleep (SETTLE_TIME); + + // This is going to be sent to the dish + msg_send_expect_success (radio, group, "HELLO"); + + // Check the correct message arrived + msg_recv_cmp (dish, group, "HELLO"); + + test_context_socket_close (dish); + test_context_socket_close (radio); +} + void test_join_too_long_fails () { void *dish = test_context_socket (ZMQ_DISH); @@ -505,6 +534,7 @@ int main (void) UNITY_BEGIN (); RUN_TEST (test_leave_unjoined_fails); RUN_TEST (test_join_too_long_fails); + RUN_TEST (test_long_group); RUN_TEST (test_join_twice_fails); RUN_TEST (test_radio_bind_fails_ipv4); RUN_TEST (test_radio_bind_fails_ipv6);