From faf6671d38273f7fcf8b8a4ac6e5a997980f862a Mon Sep 17 00:00:00 2001 From: Pawel Kurdybacha Date: Thu, 3 May 2018 21:10:05 +0100 Subject: [PATCH] Problem: poller can segfault when modified from registered handler. (#219) * Problem: poller can segfault when modified from registred handler. It is possible that a user would like to add/remove sockets from handlers. As handlers and poll items might be removed while not being processed yet - we have a segfault situation. Provided unit test `remove_from_handler` demonstrates the problem. Solution: Modify internal poll item data structure only after processing of events is finished. Please not that events processing path performance remains the same when there are no modification (add/remove) to the poller (no rebuild) - main real use case. As an effect of changes `size()` method has been removed as it does not represent any meaningful information anymore. There are active and pending (waiting for rebuild) poll items so two different sizes. User can easily track on their side number of registered sockets if original size information is needed. `wait` method returns number of processed sockets now. It might be useful information to a user for no extra cost. --- tests/poller.cpp | 66 +++++++++++++++++++++++++++++++++++++++--------- zmq.hpp | 50 +++++++++++++++++++++--------------- 2 files changed, 84 insertions(+), 32 deletions(-) diff --git a/tests/poller.cpp b/tests/poller.cpp index a24f53e..dfb0776 100644 --- a/tests/poller.cpp +++ b/tests/poller.cpp @@ -9,7 +9,6 @@ TEST(poller, create_destroy) { zmq::poller_t poller; - ASSERT_EQ(0u, poller.size ()); } static_assert(!std::is_copy_constructible::value, "poller_t should not be copy-constructible"); @@ -20,9 +19,7 @@ TEST(poller, move_construct_empty) std::unique_ptr a{new zmq::poller_t}; zmq::poller_t b = std::move(*a); - ASSERT_EQ(0u, a->size ()); a.reset (); - ASSERT_EQ(0u, b.size ()); } TEST(poller, move_assign_empty) @@ -32,9 +29,7 @@ TEST(poller, move_assign_empty) b = std::move(*a); - ASSERT_EQ(0u, a->size ()); a.reset (); - ASSERT_EQ(0u, b.size ()); } TEST(poller, move_construct_non_empty) @@ -46,9 +41,7 @@ TEST(poller, move_construct_non_empty) a->add(socket, ZMQ_POLLIN, [](short) {}); zmq::poller_t b = std::move(*a); - ASSERT_EQ(0u, a->size ()); a.reset (); - ASSERT_EQ(1u, b.size ()); } TEST(poller, move_assign_non_empty) @@ -62,9 +55,7 @@ TEST(poller, move_assign_non_empty) b = std::move(*a); - ASSERT_EQ(0u, a->size ()); a.reset (); - ASSERT_EQ(1u, b.size ()); } TEST(poller, add_handler) @@ -246,7 +237,6 @@ TEST(poller, poller_add_invalid_socket_throws) zmq::socket_t b {std::move (a)}; ASSERT_THROW (poller.add (a, ZMQ_POLLIN, zmq::poller_t::handler_t {}), zmq::error_t); - ASSERT_EQ (0u, poller.size ()); } TEST(poller, poller_remove_invalid_socket_throws) @@ -255,11 +245,9 @@ TEST(poller, poller_remove_invalid_socket_throws) zmq::socket_t socket {context, zmq::socket_type::router}; zmq::poller_t poller; ASSERT_NO_THROW (poller.add (socket, ZMQ_POLLIN, zmq::poller_t::handler_t {})); - ASSERT_EQ (1u, poller.size ()); std::vector sockets; sockets.emplace_back (std::move (socket)); ASSERT_THROW (poller.remove (socket), zmq::error_t); - ASSERT_EQ (1u, poller.size ()); } TEST(poller, wait_on_added_empty_handler) @@ -331,4 +319,58 @@ TEST(poller, poll_client_server) ASSERT_EQ(s.events, ZMQ_POLLIN | ZMQ_POLLOUT); } +TEST(poller, wait_one_return) +{ + // Setup server and client + server_client_setup s; + + int count = 0; + + // Setup poller + zmq::poller_t poller; + ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, [&count](short) { + ++count; + })); + + // client sends message + ASSERT_NO_THROW(s.client.send("Hi")); + + // wait for message and verify events + int result = poller.wait(std::chrono::milliseconds{500}); + ASSERT_EQ(count, result); +} + +TEST(poller, remove_from_handler) +{ + constexpr auto ITER_NO = 10; + + // Setup servers and clients + std::vector setup_list; + for (auto i = 0; i < ITER_NO; ++i) + setup_list.emplace_back (server_client_setup{}); + + // Setup poller + zmq::poller_t poller; + for (auto i = 0; i < ITER_NO; ++i) { + ASSERT_NO_THROW(poller.add(setup_list[i].server, ZMQ_POLLIN, [&,i](short events) { + ASSERT_EQ(events, ZMQ_POLLIN); + poller.remove(setup_list[ITER_NO - i -1].server); + })); + } + // Clients send messages + for (auto & s : setup_list) { + ASSERT_NO_THROW(s.client.send("Hi")); + } + + // Wait for all servers to receive a message + for (auto & s : setup_list) { + zmq::pollitem_t items [] = { { s.server, 0, ZMQ_POLLIN, 0 } }; + zmq::poll (&items [0], 1); + } + + // Fire all handlers in one wait + int count = poller.wait (std::chrono::milliseconds{-1}); + ASSERT_EQ(count, ITER_NO); +} + #endif diff --git a/zmq.hpp b/zmq.hpp index 8361222..f2efacd 100644 --- a/zmq.hpp +++ b/zmq.hpp @@ -74,6 +74,7 @@ #include #include #include + #include #endif // Detect whether the compiler supports C++11 rvalue references. @@ -1054,10 +1055,9 @@ namespace zmq { auto it = std::end (handlers); auto inserted = false; - if (handler) - std::tie(it, inserted) = handlers.emplace (socket.ptr, std::move (handler)); - if (0 == zmq_poller_add (poller_ptr, socket.ptr, inserted ? &(it->second) : nullptr, events)) { - poller_events.emplace_back (zmq_poller_event_t ()); + std::tie(it, inserted) = handlers.emplace (socket.ptr, std::make_shared (std::move (handler))); + if (0 == zmq_poller_add (poller_ptr, socket.ptr, inserted && *(it->second) ? it->second.get() : nullptr, events)) { + need_rebuild = true; return; } // rollback @@ -1070,7 +1070,7 @@ namespace zmq { if (0 == zmq_poller_remove (poller_ptr, socket.ptr)) { handlers.erase (socket.ptr); - poller_events.pop_back (); + need_rebuild = true; return; } throw error_t (); @@ -1082,36 +1082,46 @@ namespace zmq throw error_t (); } - bool wait (std::chrono::milliseconds timeout) + int wait (std::chrono::milliseconds timeout) { - int rc = zmq_poller_wait_all (poller_ptr, poller_events.data (), static_cast (poller_events.size ()), static_cast(timeout.count ())); - if (rc >= 0) { - std::for_each (poller_events.begin (), poller_events.begin () + rc, [](zmq_poller_event_t& event) { + if (need_rebuild) { + poller_events.clear (); + poller_handlers.clear (); + poller_events.reserve (handlers.size ()); + poller_handlers.reserve (handlers.size ()); + for (const auto &handler : handlers) { + poller_events.emplace_back (zmq_poller_event_t {}); + poller_handlers.push_back (handler.second); + } + need_rebuild = false; + } + int rc = zmq_poller_wait_all (poller_ptr, poller_events.data (), + static_cast (poller_events.size ()), + static_cast(timeout.count ())); + if (rc > 0) { + std::for_each (poller_events.begin (), poller_events.begin () + rc, + [](zmq_poller_event_t& event) { if (event.user_data != NULL) (*reinterpret_cast (event.user_data)) (event.events); }); - return true; + return rc; } - #if ZMQ_VERSION >= ZMQ_MAKE_VERSION(4, 2, 3) if (zmq_errno () == EAGAIN) #else if (zmq_errno () == ETIMEDOUT) #endif - return false; + return 0; throw error_t (); } - size_t size () - { - return poller_events.size(); - } - private: - void *poller_ptr; - std::vector poller_events; - std::unordered_map handlers; + void *poller_ptr {nullptr}; + bool need_rebuild {false}; + std::unordered_map> handlers {}; + std::vector poller_events {}; + std::vector> poller_handlers {}; }; // class poller_t #endif // defined(ZMQ_BUILD_DRAFT_API) && defined(ZMQ_CPP11) && defined(ZMQ_HAVE_POLLER)