diff --git a/tests/poller.cpp b/tests/poller.cpp index a24f53e..dfb0776 100644 --- a/tests/poller.cpp +++ b/tests/poller.cpp @@ -9,7 +9,6 @@ TEST(poller, create_destroy) { zmq::poller_t poller; - ASSERT_EQ(0u, poller.size ()); } static_assert(!std::is_copy_constructible::value, "poller_t should not be copy-constructible"); @@ -20,9 +19,7 @@ TEST(poller, move_construct_empty) std::unique_ptr a{new zmq::poller_t}; zmq::poller_t b = std::move(*a); - ASSERT_EQ(0u, a->size ()); a.reset (); - ASSERT_EQ(0u, b.size ()); } TEST(poller, move_assign_empty) @@ -32,9 +29,7 @@ TEST(poller, move_assign_empty) b = std::move(*a); - ASSERT_EQ(0u, a->size ()); a.reset (); - ASSERT_EQ(0u, b.size ()); } TEST(poller, move_construct_non_empty) @@ -46,9 +41,7 @@ TEST(poller, move_construct_non_empty) a->add(socket, ZMQ_POLLIN, [](short) {}); zmq::poller_t b = std::move(*a); - ASSERT_EQ(0u, a->size ()); a.reset (); - ASSERT_EQ(1u, b.size ()); } TEST(poller, move_assign_non_empty) @@ -62,9 +55,7 @@ TEST(poller, move_assign_non_empty) b = std::move(*a); - ASSERT_EQ(0u, a->size ()); a.reset (); - ASSERT_EQ(1u, b.size ()); } TEST(poller, add_handler) @@ -246,7 +237,6 @@ TEST(poller, poller_add_invalid_socket_throws) zmq::socket_t b {std::move (a)}; ASSERT_THROW (poller.add (a, ZMQ_POLLIN, zmq::poller_t::handler_t {}), zmq::error_t); - ASSERT_EQ (0u, poller.size ()); } TEST(poller, poller_remove_invalid_socket_throws) @@ -255,11 +245,9 @@ TEST(poller, poller_remove_invalid_socket_throws) zmq::socket_t socket {context, zmq::socket_type::router}; zmq::poller_t poller; ASSERT_NO_THROW (poller.add (socket, ZMQ_POLLIN, zmq::poller_t::handler_t {})); - ASSERT_EQ (1u, poller.size ()); std::vector sockets; sockets.emplace_back (std::move (socket)); ASSERT_THROW (poller.remove (socket), zmq::error_t); - ASSERT_EQ (1u, poller.size ()); } TEST(poller, wait_on_added_empty_handler) @@ -331,4 +319,58 @@ TEST(poller, poll_client_server) ASSERT_EQ(s.events, ZMQ_POLLIN | ZMQ_POLLOUT); } +TEST(poller, wait_one_return) +{ + // Setup server and client + server_client_setup s; + + int count = 0; + + // Setup poller + zmq::poller_t poller; + ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, [&count](short) { + ++count; + })); + + // client sends message + ASSERT_NO_THROW(s.client.send("Hi")); + + // wait for message and verify events + int result = poller.wait(std::chrono::milliseconds{500}); + ASSERT_EQ(count, result); +} + +TEST(poller, remove_from_handler) +{ + constexpr auto ITER_NO = 10; + + // Setup servers and clients + std::vector setup_list; + for (auto i = 0; i < ITER_NO; ++i) + setup_list.emplace_back (server_client_setup{}); + + // Setup poller + zmq::poller_t poller; + for (auto i = 0; i < ITER_NO; ++i) { + ASSERT_NO_THROW(poller.add(setup_list[i].server, ZMQ_POLLIN, [&,i](short events) { + ASSERT_EQ(events, ZMQ_POLLIN); + poller.remove(setup_list[ITER_NO - i -1].server); + })); + } + // Clients send messages + for (auto & s : setup_list) { + ASSERT_NO_THROW(s.client.send("Hi")); + } + + // Wait for all servers to receive a message + for (auto & s : setup_list) { + zmq::pollitem_t items [] = { { s.server, 0, ZMQ_POLLIN, 0 } }; + zmq::poll (&items [0], 1); + } + + // Fire all handlers in one wait + int count = poller.wait (std::chrono::milliseconds{-1}); + ASSERT_EQ(count, ITER_NO); +} + #endif diff --git a/zmq.hpp b/zmq.hpp index 8361222..f2efacd 100644 --- a/zmq.hpp +++ b/zmq.hpp @@ -74,6 +74,7 @@ #include #include #include + #include #endif // Detect whether the compiler supports C++11 rvalue references. @@ -1054,10 +1055,9 @@ namespace zmq { auto it = std::end (handlers); auto inserted = false; - if (handler) - std::tie(it, inserted) = handlers.emplace (socket.ptr, std::move (handler)); - if (0 == zmq_poller_add (poller_ptr, socket.ptr, inserted ? &(it->second) : nullptr, events)) { - poller_events.emplace_back (zmq_poller_event_t ()); + std::tie(it, inserted) = handlers.emplace (socket.ptr, std::make_shared (std::move (handler))); + if (0 == zmq_poller_add (poller_ptr, socket.ptr, inserted && *(it->second) ? it->second.get() : nullptr, events)) { + need_rebuild = true; return; } // rollback @@ -1070,7 +1070,7 @@ namespace zmq { if (0 == zmq_poller_remove (poller_ptr, socket.ptr)) { handlers.erase (socket.ptr); - poller_events.pop_back (); + need_rebuild = true; return; } throw error_t (); @@ -1082,36 +1082,46 @@ namespace zmq throw error_t (); } - bool wait (std::chrono::milliseconds timeout) + int wait (std::chrono::milliseconds timeout) { - int rc = zmq_poller_wait_all (poller_ptr, poller_events.data (), static_cast (poller_events.size ()), static_cast(timeout.count ())); - if (rc >= 0) { - std::for_each (poller_events.begin (), poller_events.begin () + rc, [](zmq_poller_event_t& event) { + if (need_rebuild) { + poller_events.clear (); + poller_handlers.clear (); + poller_events.reserve (handlers.size ()); + poller_handlers.reserve (handlers.size ()); + for (const auto &handler : handlers) { + poller_events.emplace_back (zmq_poller_event_t {}); + poller_handlers.push_back (handler.second); + } + need_rebuild = false; + } + int rc = zmq_poller_wait_all (poller_ptr, poller_events.data (), + static_cast (poller_events.size ()), + static_cast(timeout.count ())); + if (rc > 0) { + std::for_each (poller_events.begin (), poller_events.begin () + rc, + [](zmq_poller_event_t& event) { if (event.user_data != NULL) (*reinterpret_cast (event.user_data)) (event.events); }); - return true; + return rc; } - #if ZMQ_VERSION >= ZMQ_MAKE_VERSION(4, 2, 3) if (zmq_errno () == EAGAIN) #else if (zmq_errno () == ETIMEDOUT) #endif - return false; + return 0; throw error_t (); } - size_t size () - { - return poller_events.size(); - } - private: - void *poller_ptr; - std::vector poller_events; - std::unordered_map handlers; + void *poller_ptr {nullptr}; + bool need_rebuild {false}; + std::unordered_map> handlers {}; + std::vector poller_events {}; + std::vector> poller_handlers {}; }; // class poller_t #endif // defined(ZMQ_BUILD_DRAFT_API) && defined(ZMQ_CPP11) && defined(ZMQ_HAVE_POLLER)