diff --git a/tests/poller.cpp b/tests/poller.cpp index 6012620..983fb62 100644 --- a/tests/poller.cpp +++ b/tests/poller.cpp @@ -113,7 +113,6 @@ TEST(poller, remove_unregistered_throws) ASSERT_THROW(poller.remove(socket), zmq::error_t); } -/// \todo this should lead to an exception instead TEST(poller, remove_registered_empty) { zmq::context_t context; @@ -156,19 +155,37 @@ private: std::string endpoint_; }; +struct server_client_setup +{ + server_client_setup () + { + init (); + } + + void init() + { + endpoint = loopback_ip4_binder {server}.endpoint (); + ASSERT_NO_THROW (client.connect (endpoint)); + } + + zmq::poller_t::handler_t handler = [&](short e) { + events = e; + }; + + zmq::context_t context; + zmq::socket_t server {context, zmq::socket_type::router}; + zmq::socket_t client {context, zmq::socket_type::dealer}; + std::string endpoint; + short events = 0; +}; + } //namespace TEST(poller, poll_basic) { - zmq::context_t context; + server_client_setup s; - zmq::socket_t vent{context, zmq::socket_type::push}; - auto endpoint = loopback_ip4_binder(vent).endpoint(); - - zmq::socket_t sink{context, zmq::socket_type::pull}; - ASSERT_NO_THROW(sink.connect(endpoint)); - - ASSERT_NO_THROW(vent.send("Hi")); + ASSERT_NO_THROW(s.client.send("Hi")); zmq::poller_t poller; bool message_received = false; @@ -176,7 +193,7 @@ TEST(poller, poll_basic) ASSERT_TRUE(0 != (events & ZMQ_POLLIN)); message_received = true; }; - ASSERT_NO_THROW(poller.add(sink, ZMQ_POLLIN, handler)); + ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, handler)); ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1})); ASSERT_TRUE(message_received); } @@ -184,48 +201,41 @@ TEST(poller, poll_basic) /// \todo this contains multiple test cases that should be split up TEST(poller, client_server) { - zmq::context_t context; const std::string send_msg = "Hi"; - // Setup server - zmq::socket_t server{context, zmq::socket_type::router}; - auto endpoint = loopback_ip4_binder(server).endpoint(); + // Setup server and client + server_client_setup s; // Setup poller zmq::poller_t poller; - bool got_pollin = false; - bool got_pollout = false; - zmq::poller_t::handler_t handler = [&](short events) { - if (0 != (events & ZMQ_POLLIN)) { + short events; + zmq::poller_t::handler_t handler = [&](short e) { + if (0 != (e & ZMQ_POLLIN)) { zmq::message_t zmq_msg; - ASSERT_NO_THROW(server.recv(&zmq_msg)); // skip msg id - ASSERT_NO_THROW(server.recv(&zmq_msg)); // get message + ASSERT_NO_THROW(s.server.recv(&zmq_msg)); // skip msg id + ASSERT_NO_THROW(s.server.recv(&zmq_msg)); // get message std::string recv_msg(zmq_msg.data(), zmq_msg.size()); ASSERT_EQ(send_msg, recv_msg); - got_pollin = true; - } else if (0 != (events & ZMQ_POLLOUT)) { - got_pollout = true; - } else { + } else if (0 != (e & ~ZMQ_POLLOUT)) { ASSERT_TRUE(false) << "Unexpected event type " << events; } + events = e; }; - ASSERT_NO_THROW(poller.add(server, ZMQ_POLLIN, handler)); - // Setup client and send message - zmq::socket_t client{context, zmq::socket_type::dealer}; - ASSERT_NO_THROW(client.connect(endpoint)); - ASSERT_NO_THROW(client.send(send_msg)); + ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, handler)); + + // client sends message + ASSERT_NO_THROW(s.client.send(send_msg)); ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1})); - ASSERT_TRUE(got_pollin); - ASSERT_FALSE(got_pollout); + ASSERT_EQ(events, ZMQ_POLLIN); // Re-add server socket with pollout flag - ASSERT_NO_THROW(poller.remove(server)); - ASSERT_NO_THROW(poller.add(server, ZMQ_POLLIN | ZMQ_POLLOUT, handler)); + ASSERT_NO_THROW(poller.remove(s.server)); + ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN | ZMQ_POLLOUT, handler)); ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1})); - ASSERT_TRUE(got_pollout); + ASSERT_EQ(events, ZMQ_POLLOUT); } TEST(poller, poller_add_invalid_socket_throws) @@ -254,17 +264,62 @@ TEST(poller, poller_remove_invalid_socket_throws) TEST(poller, wait_on_added_empty_handler) { - zmq::context_t context; - zmq::socket_t vent{context, zmq::socket_type::push}; - auto endpoint = loopback_ip4_binder(vent).endpoint(); - - zmq::socket_t sink{context, zmq::socket_type::pull}; - ASSERT_NO_THROW(sink.connect(endpoint)); - ASSERT_NO_THROW(vent.send("Hi")); - + server_client_setup s; + ASSERT_NO_THROW(s.client.send("Hi")); zmq::poller_t poller; std::function handler; - ASSERT_NO_THROW(poller.add(sink, ZMQ_POLLIN, handler)); + ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, handler)); ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1})); } + +TEST(poller, modify_empty_throws) +{ + zmq::context_t context; + zmq::socket_t socket {context, zmq::socket_type::push}; + zmq::poller_t poller; + ASSERT_THROW (poller.modify (socket, ZMQ_POLLIN), zmq::error_t); +} + +TEST(poller, modify_invalid_socket_throws) +{ + zmq::context_t context; + zmq::socket_t a {context, zmq::socket_type::push}; + zmq::socket_t b {std::move (a)}; + zmq::poller_t poller; + ASSERT_THROW (poller.modify (a, ZMQ_POLLIN), zmq::error_t); +} + +TEST(poller, modified_not_added_throws) +{ + zmq::context_t context; + zmq::socket_t a {context, zmq::socket_type::push}; + zmq::socket_t b {context, zmq::socket_type::push}; + zmq::poller_t poller; + ASSERT_NO_THROW (poller.add (a, ZMQ_POLLIN, zmq::poller_t::handler_t {})); + ASSERT_THROW (poller.modify (b, ZMQ_POLLIN), zmq::error_t); +} + +TEST(poller, poll_client_server) +{ + // Setup server and client + server_client_setup s; + + // Setup poller + zmq::poller_t poller; + ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, s.handler)); + + // client sends message + ASSERT_NO_THROW(s.client.send("Hi")); + + // wait for message and verify events + ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{500})); + ASSERT_TRUE(s.events == ZMQ_POLLIN); + ASSERT_EQ(s.events, ZMQ_POLLIN); + + // Modify server socket with pollout flag + ASSERT_NO_THROW(poller.modify(s.server, ZMQ_POLLIN | ZMQ_POLLOUT)); + ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{500})); + ASSERT_EQ(s.events, ZMQ_POLLIN | ZMQ_POLLOUT); +} + #endif diff --git a/zmq.hpp b/zmq.hpp index ecbca26..4a49699 100644 --- a/zmq.hpp +++ b/zmq.hpp @@ -1083,6 +1083,12 @@ namespace zmq throw error_t (); } + void modify (zmq::socket_t &socket, short events) + { + if (0 != zmq_poller_modify (poller_ptr, socket.ptr, events)) + throw error_t (); + } + bool wait (std::chrono::milliseconds timeout) { int rc = zmq_poller_wait_all (poller_ptr, poller_events.data (), static_cast (poller_events.size ()), static_cast(timeout.count ()));