Problem: poller can segfault when modified from registered handler. (#219)

* Problem: poller can segfault when modified from registred handler.

It is possible that a user would like to add/remove sockets from
handlers. As handlers and poll items might be removed while not
being processed yet - we have a segfault situation.
Provided unit test `remove_from_handler` demonstrates the problem.

Solution: Modify internal poll item data structure only after processing
of events is finished.

Please not that events processing path performance remains the same when there are
no modification (add/remove) to the poller (no rebuild) - main real use case.

As an effect of changes `size()` method has been removed as it does not
represent any meaningful information anymore. There are active and pending
(waiting for rebuild) poll items so two different sizes. User can
easily track on their side number of registered sockets if original size
information is needed.

`wait` method returns number of processed sockets now. It might be
useful information to a user for no extra cost.
This commit is contained in:
Pawel Kurdybacha 2018-05-03 21:10:05 +01:00 committed by Luca Boccassi
parent ac64eba5c6
commit faf6671d38
2 changed files with 84 additions and 32 deletions

View File

@ -9,7 +9,6 @@
TEST(poller, create_destroy)
{
zmq::poller_t poller;
ASSERT_EQ(0u, poller.size ());
}
static_assert(!std::is_copy_constructible<zmq::poller_t>::value, "poller_t should not be copy-constructible");
@ -20,9 +19,7 @@ TEST(poller, move_construct_empty)
std::unique_ptr<zmq::poller_t> a{new zmq::poller_t};
zmq::poller_t b = std::move(*a);
ASSERT_EQ(0u, a->size ());
a.reset ();
ASSERT_EQ(0u, b.size ());
}
TEST(poller, move_assign_empty)
@ -32,9 +29,7 @@ TEST(poller, move_assign_empty)
b = std::move(*a);
ASSERT_EQ(0u, a->size ());
a.reset ();
ASSERT_EQ(0u, b.size ());
}
TEST(poller, move_construct_non_empty)
@ -46,9 +41,7 @@ TEST(poller, move_construct_non_empty)
a->add(socket, ZMQ_POLLIN, [](short) {});
zmq::poller_t b = std::move(*a);
ASSERT_EQ(0u, a->size ());
a.reset ();
ASSERT_EQ(1u, b.size ());
}
TEST(poller, move_assign_non_empty)
@ -62,9 +55,7 @@ TEST(poller, move_assign_non_empty)
b = std::move(*a);
ASSERT_EQ(0u, a->size ());
a.reset ();
ASSERT_EQ(1u, b.size ());
}
TEST(poller, add_handler)
@ -246,7 +237,6 @@ TEST(poller, poller_add_invalid_socket_throws)
zmq::socket_t b {std::move (a)};
ASSERT_THROW (poller.add (a, ZMQ_POLLIN, zmq::poller_t::handler_t {}),
zmq::error_t);
ASSERT_EQ (0u, poller.size ());
}
TEST(poller, poller_remove_invalid_socket_throws)
@ -255,11 +245,9 @@ TEST(poller, poller_remove_invalid_socket_throws)
zmq::socket_t socket {context, zmq::socket_type::router};
zmq::poller_t poller;
ASSERT_NO_THROW (poller.add (socket, ZMQ_POLLIN, zmq::poller_t::handler_t {}));
ASSERT_EQ (1u, poller.size ());
std::vector<zmq::socket_t> sockets;
sockets.emplace_back (std::move (socket));
ASSERT_THROW (poller.remove (socket), zmq::error_t);
ASSERT_EQ (1u, poller.size ());
}
TEST(poller, wait_on_added_empty_handler)
@ -331,4 +319,58 @@ TEST(poller, poll_client_server)
ASSERT_EQ(s.events, ZMQ_POLLIN | ZMQ_POLLOUT);
}
TEST(poller, wait_one_return)
{
// Setup server and client
server_client_setup s;
int count = 0;
// Setup poller
zmq::poller_t poller;
ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, [&count](short) {
++count;
}));
// client sends message
ASSERT_NO_THROW(s.client.send("Hi"));
// wait for message and verify events
int result = poller.wait(std::chrono::milliseconds{500});
ASSERT_EQ(count, result);
}
TEST(poller, remove_from_handler)
{
constexpr auto ITER_NO = 10;
// Setup servers and clients
std::vector<server_client_setup> setup_list;
for (auto i = 0; i < ITER_NO; ++i)
setup_list.emplace_back (server_client_setup{});
// Setup poller
zmq::poller_t poller;
for (auto i = 0; i < ITER_NO; ++i) {
ASSERT_NO_THROW(poller.add(setup_list[i].server, ZMQ_POLLIN, [&,i](short events) {
ASSERT_EQ(events, ZMQ_POLLIN);
poller.remove(setup_list[ITER_NO - i -1].server);
}));
}
// Clients send messages
for (auto & s : setup_list) {
ASSERT_NO_THROW(s.client.send("Hi"));
}
// Wait for all servers to receive a message
for (auto & s : setup_list) {
zmq::pollitem_t items [] = { { s.server, 0, ZMQ_POLLIN, 0 } };
zmq::poll (&items [0], 1);
}
// Fire all handlers in one wait
int count = poller.wait (std::chrono::milliseconds{-1});
ASSERT_EQ(count, ITER_NO);
}
#endif

50
zmq.hpp
View File

@ -74,6 +74,7 @@
#include <tuple>
#include <functional>
#include <unordered_map>
#include <memory>
#endif
// Detect whether the compiler supports C++11 rvalue references.
@ -1054,10 +1055,9 @@ namespace zmq
{
auto it = std::end (handlers);
auto inserted = false;
if (handler)
std::tie(it, inserted) = handlers.emplace (socket.ptr, std::move (handler));
if (0 == zmq_poller_add (poller_ptr, socket.ptr, inserted ? &(it->second) : nullptr, events)) {
poller_events.emplace_back (zmq_poller_event_t ());
std::tie(it, inserted) = handlers.emplace (socket.ptr, std::make_shared<handler_t> (std::move (handler)));
if (0 == zmq_poller_add (poller_ptr, socket.ptr, inserted && *(it->second) ? it->second.get() : nullptr, events)) {
need_rebuild = true;
return;
}
// rollback
@ -1070,7 +1070,7 @@ namespace zmq
{
if (0 == zmq_poller_remove (poller_ptr, socket.ptr)) {
handlers.erase (socket.ptr);
poller_events.pop_back ();
need_rebuild = true;
return;
}
throw error_t ();
@ -1082,36 +1082,46 @@ namespace zmq
throw error_t ();
}
bool wait (std::chrono::milliseconds timeout)
int wait (std::chrono::milliseconds timeout)
{
int rc = zmq_poller_wait_all (poller_ptr, poller_events.data (), static_cast<int> (poller_events.size ()), static_cast<long>(timeout.count ()));
if (rc >= 0) {
std::for_each (poller_events.begin (), poller_events.begin () + rc, [](zmq_poller_event_t& event) {
if (need_rebuild) {
poller_events.clear ();
poller_handlers.clear ();
poller_events.reserve (handlers.size ());
poller_handlers.reserve (handlers.size ());
for (const auto &handler : handlers) {
poller_events.emplace_back (zmq_poller_event_t {});
poller_handlers.push_back (handler.second);
}
need_rebuild = false;
}
int rc = zmq_poller_wait_all (poller_ptr, poller_events.data (),
static_cast<int> (poller_events.size ()),
static_cast<long>(timeout.count ()));
if (rc > 0) {
std::for_each (poller_events.begin (), poller_events.begin () + rc,
[](zmq_poller_event_t& event) {
if (event.user_data != NULL)
(*reinterpret_cast<handler_t*> (event.user_data)) (event.events);
});
return true;
return rc;
}
#if ZMQ_VERSION >= ZMQ_MAKE_VERSION(4, 2, 3)
if (zmq_errno () == EAGAIN)
#else
if (zmq_errno () == ETIMEDOUT)
#endif
return false;
return 0;
throw error_t ();
}
size_t size ()
{
return poller_events.size();
}
private:
void *poller_ptr;
std::vector<zmq_poller_event_t> poller_events;
std::unordered_map<void*, handler_t> handlers;
void *poller_ptr {nullptr};
bool need_rebuild {false};
std::unordered_map<void*, std::shared_ptr<handler_t>> handlers {};
std::vector<zmq_poller_event_t> poller_events {};
std::vector<std::shared_ptr<handler_t>> poller_handlers {};
}; // class poller_t
#endif // defined(ZMQ_BUILD_DRAFT_API) && defined(ZMQ_CPP11) && defined(ZMQ_HAVE_POLLER)