Problem: poller_t adds an abstraction layer on zmq_poller_*

Solution: extract base_poller_t from poller_t, which provides a direct mapping of zmq_poller_* to C++ only
This commit is contained in:
Simon Giesecke 2018-05-11 09:41:36 +02:00
parent cdef8bc069
commit bf47be0a0c
2 changed files with 99 additions and 56 deletions

View File

@ -204,7 +204,7 @@ TEST(poller, poll_basic)
message_received = true; message_received = true;
}; };
ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, handler)); ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, handler));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1})); ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{-1}));
ASSERT_TRUE(message_received); ASSERT_TRUE(message_received);
} }
@ -237,13 +237,13 @@ TEST(poller, client_server)
// client sends message // client sends message
ASSERT_NO_THROW(s.client.send(send_msg)); ASSERT_NO_THROW(s.client.send(send_msg));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1})); ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(events, ZMQ_POLLIN); ASSERT_EQ(events, ZMQ_POLLIN);
// Re-add server socket with pollout flag // Re-add server socket with pollout flag
ASSERT_NO_THROW(poller.remove(s.server)); ASSERT_NO_THROW(poller.remove(s.server));
ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN | ZMQ_POLLOUT, handler)); ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN | ZMQ_POLLOUT, handler));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1})); ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(events, ZMQ_POLLOUT); ASSERT_EQ(events, ZMQ_POLLOUT);
} }
@ -335,7 +335,7 @@ TEST(poller, poll_client_server)
// Modify server socket with pollout flag // Modify server socket with pollout flag
ASSERT_NO_THROW(poller.modify(s.server, ZMQ_POLLIN | ZMQ_POLLOUT)); ASSERT_NO_THROW(poller.modify(s.server, ZMQ_POLLIN | ZMQ_POLLOUT));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{500})); ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(s.events, ZMQ_POLLIN | ZMQ_POLLOUT); ASSERT_EQ(s.events, ZMQ_POLLIN | ZMQ_POLLOUT);
} }
@ -356,8 +356,8 @@ TEST(poller, wait_one_return)
ASSERT_NO_THROW(s.client.send("Hi")); ASSERT_NO_THROW(s.client.send("Hi"));
// wait for message and verify events // wait for message and verify events
int result = poller.wait(std::chrono::milliseconds{500}); ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(count, result); ASSERT_EQ(1u, count);
} }
TEST(poller, wait_on_move_constructed_poller) TEST(poller, wait_on_move_constructed_poller)
@ -401,14 +401,14 @@ TEST(poller, received_on_move_construced_poller)
// client sends message // client sends message
ASSERT_NO_THROW(s.client.send("Hi")); ASSERT_NO_THROW(s.client.send("Hi"));
// wait for message and verify it is received // wait for message and verify it is received
a.wait(std::chrono::milliseconds{500}); ASSERT_EQ(1, a.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(1u, count); ASSERT_EQ(1u, count);
// Move construct poller b // Move construct poller b
zmq::poller_t b{std::move(a)}; zmq::poller_t b{std::move(a)};
// client sends message again // client sends message again
ASSERT_NO_THROW(s.client.send("Hi")); ASSERT_NO_THROW(s.client.send("Hi"));
// wait for message and verify it is received // wait for message and verify it is received
b.wait(std::chrono::milliseconds{500}); ASSERT_EQ(1, b.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(2u, count); ASSERT_EQ(2u, count);
} }
@ -424,12 +424,14 @@ TEST(poller, remove_from_handler)
// Setup poller // Setup poller
zmq::poller_t poller; zmq::poller_t poller;
int count = 0;
for (auto i = 0; i < ITER_NO; ++i) { for (auto i = 0; i < ITER_NO; ++i) {
ASSERT_NO_THROW(poller.add(setup_list[i].server, ZMQ_POLLIN, [&,i](short events) { ASSERT_NO_THROW(poller.add(setup_list[i].server, ZMQ_POLLIN, [&,i](short events) {
ASSERT_EQ(events, ZMQ_POLLIN); ASSERT_EQ(events, ZMQ_POLLIN);
poller.remove(setup_list[ITER_NO-i-1].server); poller.remove(setup_list[ITER_NO-i-1].server);
ASSERT_EQ(ITER_NO-i-1, poller.size()); ASSERT_EQ(ITER_NO-i-1, poller.size());
})); }));
++count;
} }
ASSERT_EQ(ITER_NO, poller.size()); ASSERT_EQ(ITER_NO, poller.size());
// Clients send messages // Clients send messages
@ -444,8 +446,8 @@ TEST(poller, remove_from_handler)
} }
// Fire all handlers in one wait // Fire all handlers in one wait
int count = poller.wait (std::chrono::milliseconds{-1}); ASSERT_EQ(ITER_NO, poller.wait (std::chrono::milliseconds{-1}));
ASSERT_EQ(count, ITER_NO); ASSERT_EQ(ITER_NO, count);
} }
#endif #endif

127
zmq.hpp
View File

@ -577,7 +577,6 @@ namespace zmq
class socket_t class socket_t
{ {
friend class monitor_t; friend class monitor_t;
friend class poller_t;
public: public:
inline socket_t(context_t& context_, int type_) inline socket_t(context_t& context_, int type_)
{ {
@ -1019,6 +1018,67 @@ namespace zmq
}; };
#if defined(ZMQ_BUILD_DRAFT_API) && defined(ZMQ_CPP11) && defined(ZMQ_HAVE_POLLER) #if defined(ZMQ_BUILD_DRAFT_API) && defined(ZMQ_CPP11) && defined(ZMQ_HAVE_POLLER)
template <typename T = void>
class base_poller_t
{
public:
void add (zmq::socket_t &socket, short events, T *user_data)
{
if (0 != zmq_poller_add (poller_ptr.get (), static_cast<void*>(socket), user_data, events))
{
throw error_t ();
}
}
void remove (zmq::socket_t &socket)
{
if (0 != zmq_poller_remove (poller_ptr.get (), static_cast<void*>(socket)))
{
throw error_t ();
}
}
void modify (zmq::socket_t &socket, short events)
{
if (0 != zmq_poller_modify (poller_ptr.get (), static_cast<void*>(socket), events))
{
throw error_t ();
}
}
int wait_all (std::vector<zmq_poller_event_t> &poller_events, const std::chrono::microseconds timeout)
{
int rc = zmq_poller_wait_all (poller_ptr.get (), poller_events.data (),
static_cast<int> (poller_events.size ()),
static_cast<long>(timeout.count ()));
if (rc > 0)
return rc;
#if ZMQ_VERSION >= ZMQ_MAKE_VERSION(4, 2, 3)
if (zmq_errno () == EAGAIN)
#else
if (zmq_errno () == ETIMEDOUT)
#endif
return 0;
throw error_t ();
}
private:
std::unique_ptr<void, std::function<void(void*)>> poller_ptr
{
[]() {
auto poller_new = zmq_poller_new ();
if (poller_new)
return poller_new;
throw error_t ();
}(),
[](void *ptr) {
int rc = zmq_poller_destroy (&ptr);
ZMQ_ASSERT (rc == 0);
}
};
};
class poller_t class poller_t
{ {
public: public:
@ -1035,33 +1095,35 @@ namespace zmq
void add (zmq::socket_t &socket, short events, handler_t handler) void add (zmq::socket_t &socket, short events, handler_t handler)
{ {
auto it = std::end (handlers); auto it = decltype (handlers)::iterator {};
auto inserted = false; auto inserted = bool {};
std::tie(it, inserted) = handlers.emplace (socket.ptr, std::make_shared<handler_t> (std::move (handler))); std::tie(it, inserted) = handlers.emplace (static_cast<void*>(socket), std::make_shared<handler_t> (std::move (handler)));
if (0 == zmq_poller_add (poller_ptr.get (), socket.ptr, inserted && *(it->second) ? it->second.get() : nullptr, events)) { try
need_rebuild = true; {
return; base_poller.add (socket, events, inserted && *(it->second) ? it->second.get() : nullptr);
need_rebuild |= inserted;
} }
catch (const zmq::error_t&)
{
// rollback // rollback
if (inserted) if (inserted)
handlers.erase (socket.ptr); {
throw error_t (); handlers.erase (static_cast<void*>(socket));
}
throw;
}
} }
void remove (zmq::socket_t &socket) void remove (zmq::socket_t &socket)
{ {
if (0 == zmq_poller_remove (poller_ptr.get (), socket.ptr)) { base_poller.remove (socket);
handlers.erase (socket.ptr); handlers.erase (static_cast<void*>(socket));
need_rebuild = true; need_rebuild = true;
return;
}
throw error_t ();
} }
void modify (zmq::socket_t &socket, short events) void modify (zmq::socket_t &socket, short events)
{ {
if (0 != zmq_poller_modify (poller_ptr.get (), socket.ptr, events)) base_poller.modify (socket, events);
throw error_t ();
} }
int wait (std::chrono::milliseconds timeout) int wait (std::chrono::milliseconds timeout)
@ -1077,25 +1139,15 @@ namespace zmq
} }
need_rebuild = false; need_rebuild = false;
} }
int rc = zmq_poller_wait_all (poller_ptr.get (), poller_events.data (), const int count = base_poller.wait_all (poller_events, timeout);
static_cast<int> (poller_events.size ()), if (count != 0) {
static_cast<long>(timeout.count ())); std::for_each (poller_events.begin (), poller_events.begin () + count,
if (rc > 0) {
std::for_each (poller_events.begin (), poller_events.begin () + rc,
[](zmq_poller_event_t& event) { [](zmq_poller_event_t& event) {
if (event.user_data != NULL) if (event.user_data != NULL)
(*reinterpret_cast<handler_t*> (event.user_data)) (event.events); (*reinterpret_cast<handler_t*> (event.user_data)) (event.events);
}); });
return rc;
} }
#if ZMQ_VERSION >= ZMQ_MAKE_VERSION(4, 2, 3) return count;
if (zmq_errno () == EAGAIN)
#else
if (zmq_errno () == ETIMEDOUT)
#endif
return 0;
throw error_t ();
} }
bool empty () const bool empty () const
@ -1109,20 +1161,9 @@ namespace zmq
} }
private: private:
std::unique_ptr<void, std::function<void(void*)>> poller_ptr
{
[]() {
auto poller_new = zmq_poller_new ();
if (poller_new)
return poller_new;
throw error_t ();
}(),
[](void *ptr) {
int rc = zmq_poller_destroy (&ptr);
ZMQ_ASSERT (rc == 0);
}
};
bool need_rebuild {false}; bool need_rebuild {false};
base_poller_t<handler_t> base_poller {};
std::unordered_map<void*, std::shared_ptr<handler_t>> handlers {}; std::unordered_map<void*, std::shared_ptr<handler_t>> handlers {};
std::vector<zmq_poller_event_t> poller_events {}; std::vector<zmq_poller_event_t> poller_events {};
std::vector<std::shared_ptr<handler_t>> poller_handlers {}; std::vector<std::shared_ptr<handler_t>> poller_handlers {};