diff --git a/src/dish.cpp b/src/dish.cpp index fe8ac42d..acfe2988 100644 --- a/src/dish.cpp +++ b/src/dish.cpp @@ -29,6 +29,7 @@ #include +#include "../include/zmq.h" #include "macros.hpp" #include "dish.hpp" #include "err.hpp" @@ -244,3 +245,66 @@ void zmq::dish_t::send_subscriptions (pipe_t *pipe_) pipe_->flush (); } + +zmq::dish_session_t::dish_session_t (io_thread_t *io_thread_, bool connect_, + socket_base_t *socket_, const options_t &options_, + address_t *addr_) : + session_base_t (io_thread_, connect_, socket_, options_, addr_), + state (group) +{ +} + +zmq::dish_session_t::~dish_session_t () +{ +} + +int zmq::dish_session_t::push_msg (msg_t *msg_) +{ + if (state == group) { + if ((msg_->flags() & msg_t::more) != msg_t::more) { + errno = EFAULT; + return -1; + } + + if (msg_->size() > ZMQ_GROUP_MAX_LENGTH) { + errno = EFAULT; + return -1; + } + + group_msg = *msg_; + state = body; + + int rc = msg_->init (); + errno_assert (rc == 0); + return 0; + } + else { + // Set the message group + int rc = msg_->set_group ((char*)group_msg.data (), group_msg. size()); + errno_assert (rc == 0); + + // We set the group, so we don't need the group_msg anymore + rc = group_msg.close (); + errno_assert (rc == 0); + + // Thread safe socket doesn't support multipart messages + if ((msg_->flags() & msg_t::more) == msg_t::more) { + errno = EFAULT; + return -1; + } + + // Push message to dish socket + rc = session_base_t::push_msg (msg_); + + if (rc == 0) + state = group; + + return rc; + } +} + +void zmq::dish_session_t::reset () +{ + session_base_t::reset (); + state = group; +} diff --git a/src/dish.hpp b/src/dish.hpp index d0f42d70..414083ee 100644 --- a/src/dish.hpp +++ b/src/dish.hpp @@ -71,7 +71,7 @@ namespace zmq int xleave (const char *group_); private: - // Send subscriptions to a pipe + // Send subscriptions to a pipe void send_subscriptions (pipe_t *pipe_); // Fair queueing object for inbound pipes. @@ -93,6 +93,32 @@ namespace zmq const dish_t &operator = (const dish_t&); }; + class dish_session_t : public session_base_t + { + public: + + dish_session_t (zmq::io_thread_t *io_thread_, bool connect_, + zmq::socket_base_t *socket_, const options_t &options_, + address_t *addr_); + ~dish_session_t (); + + // Overrides of the functions from session_base_t. + int push_msg (msg_t *msg_); + void reset (); + + private: + + enum { + group, + body + } state; + + msg_t group_msg; + + dish_session_t (const dish_session_t&); + const dish_session_t &operator = (const dish_session_t&); + }; + } #endif diff --git a/src/msg.cpp b/src/msg.cpp index 013d30b3..263c8e93 100644 --- a/src/msg.cpp +++ b/src/msg.cpp @@ -533,13 +533,19 @@ const char * zmq::msg_t::group () int zmq::msg_t::set_group (const char * group_) { - if (strlen (group_) > ZMQ_GROUP_MAX_LENGTH) + return set_group (group_, strlen (group_)); +} + +int zmq::msg_t::set_group (const char * group_, size_t length_) +{ + if (length_> ZMQ_GROUP_MAX_LENGTH) { errno = EINVAL; return -1; } - strcpy (u.base.group, group_); + strncpy (u.base.group, group_, length_); + u.base.group[length_] = '\0'; return 0; } diff --git a/src/msg.hpp b/src/msg.hpp index 3cad5ac0..330c9984 100644 --- a/src/msg.hpp +++ b/src/msg.hpp @@ -119,6 +119,7 @@ namespace zmq int reset_routing_id (); const char * group (); int set_group (const char* group_); + int set_group (const char*, size_t length); // After calling this function you can copy the message in POD-style // refs_ times. No need to call copy. diff --git a/src/radio.cpp b/src/radio.cpp index ee9e0697..03c6ba41 100644 --- a/src/radio.cpp +++ b/src/radio.cpp @@ -120,9 +120,8 @@ int zmq::radio_t::xsend (msg_t *msg_) std::pair range = subscriptions.equal_range (std::string(msg_->group ())); - for (subscriptions_t::iterator it = range.first; it != range.second; ++it) { + for (subscriptions_t::iterator it = range.first; it != range.second; ++it) dist.match (it-> second); - } int rc = dist.send_to_matching (msg_); @@ -145,3 +144,47 @@ bool zmq::radio_t::xhas_in () { return false; } + +zmq::radio_session_t::radio_session_t (io_thread_t *io_thread_, bool connect_, + socket_base_t *socket_, const options_t &options_, + address_t *addr_) : + session_base_t (io_thread_, connect_, socket_, options_, addr_), + state (group) +{ +} + +zmq::radio_session_t::~radio_session_t () +{ +} + +int zmq::radio_session_t::pull_msg (msg_t *msg_) +{ + if (state == group) { + int rc = session_base_t::pull_msg (&pending_msg); + if (rc != 0) + return rc; + + const char *group = pending_msg.group (); + int length = strlen (group); + + // First frame is the group + msg_->init_size (length); + msg_->set_flags (msg_t::more); + memcpy (msg_->data (), group, length); + + // Next status is the body + state = body; + return 0; + } + else { + *msg_ = pending_msg; + state = group; + return 0; + } +} + +void zmq::radio_session_t::reset () +{ + session_base_t::reset (); + state = group; +} diff --git a/src/radio.hpp b/src/radio.hpp index 52157660..596a8cd4 100644 --- a/src/radio.hpp +++ b/src/radio.hpp @@ -77,6 +77,30 @@ namespace zmq const radio_t &operator = (const radio_t&); }; + class radio_session_t : public session_base_t + { + public: + + radio_session_t (zmq::io_thread_t *io_thread_, bool connect_, + zmq::socket_base_t *socket_, const options_t &options_, + address_t *addr_); + ~radio_session_t (); + + // Overrides of the functions from session_base_t. + int pull_msg (msg_t *msg_); + void reset (); + private: + + enum { + group, + body + } state; + + msg_t pending_msg; + + radio_session_t (const radio_session_t&); + const radio_session_t &operator = (const radio_session_t&); + }; } #endif diff --git a/src/session_base.cpp b/src/session_base.cpp index 6d6ea5bd..515f2947 100644 --- a/src/session_base.cpp +++ b/src/session_base.cpp @@ -45,6 +45,8 @@ #include "ctx.hpp" #include "req.hpp" +#include "radio.hpp" +#include "dish.hpp" zmq::session_base_t *zmq::session_base_t::create (class io_thread_t *io_thread_, bool active_, class socket_base_t *socket_, const options_t &options_, @@ -56,6 +58,14 @@ zmq::session_base_t *zmq::session_base_t::create (class io_thread_t *io_thread_, s = new (std::nothrow) req_session_t (io_thread_, active_, socket_, options_, addr_); break; + case ZMQ_RADIO: + s = new (std::nothrow) radio_session_t (io_thread_, active_, + socket_, options_, addr_); + break; + case ZMQ_DISH: + s = new (std::nothrow) dish_session_t (io_thread_, active_, + socket_, options_, addr_); + break; case ZMQ_DEALER: case ZMQ_REP: case ZMQ_ROUTER: @@ -69,8 +79,6 @@ zmq::session_base_t *zmq::session_base_t::create (class io_thread_t *io_thread_, case ZMQ_STREAM: case ZMQ_SERVER: case ZMQ_CLIENT: - case ZMQ_RADIO: - case ZMQ_DISH: s = new (std::nothrow) session_base_t (io_thread_, active_, socket_, options_, addr_); break; diff --git a/tests/test_radio_dish.cpp b/tests/test_radio_dish.cpp index 43618093..5cc93c33 100644 --- a/tests/test_radio_dish.cpp +++ b/tests/test_radio_dish.cpp @@ -89,7 +89,7 @@ int main (void) void *radio = zmq_socket (ctx, ZMQ_RADIO); void *dish = zmq_socket (ctx, ZMQ_DISH); - int rc = zmq_bind (radio, "inproc://test-radio-dish"); + int rc = zmq_bind (radio, "tcp://127.0.0.1:5556"); assert (rc == 0); // Leaving a group which we didn't join @@ -113,7 +113,7 @@ int main (void) assert (rc == -1); // Connecting - rc = zmq_connect (dish, "inproc://test-radio-dish"); + rc = zmq_connect (dish, "tcp://127.0.0.1:5556"); assert (rc == 0); zmq_sleep (1);