Problem: poller_t adds an abstraction layer on zmq_poller_*

Solution: extract base_poller_t from poller_t, which provides a direct mapping of zmq_poller_* to C++ only
This commit is contained in:
Simon Giesecke 2018-05-11 09:41:36 +02:00
parent cdef8bc069
commit bf47be0a0c
2 changed files with 99 additions and 56 deletions

View File

@ -204,7 +204,7 @@ TEST(poller, poll_basic)
message_received = true;
};
ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN, handler));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{-1}));
ASSERT_TRUE(message_received);
}
@ -237,13 +237,13 @@ TEST(poller, client_server)
// client sends message
ASSERT_NO_THROW(s.client.send(send_msg));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(events, ZMQ_POLLIN);
// Re-add server socket with pollout flag
ASSERT_NO_THROW(poller.remove(s.server));
ASSERT_NO_THROW(poller.add(s.server, ZMQ_POLLIN | ZMQ_POLLOUT, handler));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{-1}));
ASSERT_EQ(events, ZMQ_POLLOUT);
}
@ -335,7 +335,7 @@ TEST(poller, poll_client_server)
// Modify server socket with pollout flag
ASSERT_NO_THROW(poller.modify(s.server, ZMQ_POLLIN | ZMQ_POLLOUT));
ASSERT_NO_THROW(poller.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(s.events, ZMQ_POLLIN | ZMQ_POLLOUT);
}
@ -356,8 +356,8 @@ TEST(poller, wait_one_return)
ASSERT_NO_THROW(s.client.send("Hi"));
// wait for message and verify events
int result = poller.wait(std::chrono::milliseconds{500});
ASSERT_EQ(count, result);
ASSERT_EQ(1, poller.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(1u, count);
}
TEST(poller, wait_on_move_constructed_poller)
@ -401,14 +401,14 @@ TEST(poller, received_on_move_construced_poller)
// client sends message
ASSERT_NO_THROW(s.client.send("Hi"));
// wait for message and verify it is received
a.wait(std::chrono::milliseconds{500});
ASSERT_EQ(1, a.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(1u, count);
// Move construct poller b
zmq::poller_t b{std::move(a)};
// client sends message again
ASSERT_NO_THROW(s.client.send("Hi"));
// wait for message and verify it is received
b.wait(std::chrono::milliseconds{500});
ASSERT_EQ(1, b.wait(std::chrono::milliseconds{500}));
ASSERT_EQ(2u, count);
}
@ -424,12 +424,14 @@ TEST(poller, remove_from_handler)
// Setup poller
zmq::poller_t poller;
int count = 0;
for (auto i = 0; i < ITER_NO; ++i) {
ASSERT_NO_THROW(poller.add(setup_list[i].server, ZMQ_POLLIN, [&,i](short events) {
ASSERT_EQ(events, ZMQ_POLLIN);
poller.remove(setup_list[ITER_NO-i-1].server);
ASSERT_EQ(ITER_NO-i-1, poller.size());
}));
++count;
}
ASSERT_EQ(ITER_NO, poller.size());
// Clients send messages
@ -444,8 +446,8 @@ TEST(poller, remove_from_handler)
}
// Fire all handlers in one wait
int count = poller.wait (std::chrono::milliseconds{-1});
ASSERT_EQ(count, ITER_NO);
ASSERT_EQ(ITER_NO, poller.wait (std::chrono::milliseconds{-1}));
ASSERT_EQ(ITER_NO, count);
}
#endif

133
zmq.hpp
View File

@ -577,7 +577,6 @@ namespace zmq
class socket_t
{
friend class monitor_t;
friend class poller_t;
public:
inline socket_t(context_t& context_, int type_)
{
@ -1019,6 +1018,67 @@ namespace zmq
};
#if defined(ZMQ_BUILD_DRAFT_API) && defined(ZMQ_CPP11) && defined(ZMQ_HAVE_POLLER)
template <typename T = void>
class base_poller_t
{
public:
void add (zmq::socket_t &socket, short events, T *user_data)
{
if (0 != zmq_poller_add (poller_ptr.get (), static_cast<void*>(socket), user_data, events))
{
throw error_t ();
}
}
void remove (zmq::socket_t &socket)
{
if (0 != zmq_poller_remove (poller_ptr.get (), static_cast<void*>(socket)))
{
throw error_t ();
}
}
void modify (zmq::socket_t &socket, short events)
{
if (0 != zmq_poller_modify (poller_ptr.get (), static_cast<void*>(socket), events))
{
throw error_t ();
}
}
int wait_all (std::vector<zmq_poller_event_t> &poller_events, const std::chrono::microseconds timeout)
{
int rc = zmq_poller_wait_all (poller_ptr.get (), poller_events.data (),
static_cast<int> (poller_events.size ()),
static_cast<long>(timeout.count ()));
if (rc > 0)
return rc;
#if ZMQ_VERSION >= ZMQ_MAKE_VERSION(4, 2, 3)
if (zmq_errno () == EAGAIN)
#else
if (zmq_errno () == ETIMEDOUT)
#endif
return 0;
throw error_t ();
}
private:
std::unique_ptr<void, std::function<void(void*)>> poller_ptr
{
[]() {
auto poller_new = zmq_poller_new ();
if (poller_new)
return poller_new;
throw error_t ();
}(),
[](void *ptr) {
int rc = zmq_poller_destroy (&ptr);
ZMQ_ASSERT (rc == 0);
}
};
};
class poller_t
{
public:
@ -1035,33 +1095,35 @@ namespace zmq
void add (zmq::socket_t &socket, short events, handler_t handler)
{
auto it = std::end (handlers);
auto inserted = false;
std::tie(it, inserted) = handlers.emplace (socket.ptr, std::make_shared<handler_t> (std::move (handler)));
if (0 == zmq_poller_add (poller_ptr.get (), socket.ptr, inserted && *(it->second) ? it->second.get() : nullptr, events)) {
need_rebuild = true;
return;
auto it = decltype (handlers)::iterator {};
auto inserted = bool {};
std::tie(it, inserted) = handlers.emplace (static_cast<void*>(socket), std::make_shared<handler_t> (std::move (handler)));
try
{
base_poller.add (socket, events, inserted && *(it->second) ? it->second.get() : nullptr);
need_rebuild |= inserted;
}
catch (const zmq::error_t&)
{
// rollback
if (inserted)
{
handlers.erase (static_cast<void*>(socket));
}
throw;
}
// rollback
if (inserted)
handlers.erase (socket.ptr);
throw error_t ();
}
void remove (zmq::socket_t &socket)
{
if (0 == zmq_poller_remove (poller_ptr.get (), socket.ptr)) {
handlers.erase (socket.ptr);
need_rebuild = true;
return;
}
throw error_t ();
base_poller.remove (socket);
handlers.erase (static_cast<void*>(socket));
need_rebuild = true;
}
void modify (zmq::socket_t &socket, short events)
{
if (0 != zmq_poller_modify (poller_ptr.get (), socket.ptr, events))
throw error_t ();
base_poller.modify (socket, events);
}
int wait (std::chrono::milliseconds timeout)
@ -1077,25 +1139,15 @@ namespace zmq
}
need_rebuild = false;
}
int rc = zmq_poller_wait_all (poller_ptr.get (), poller_events.data (),
static_cast<int> (poller_events.size ()),
static_cast<long>(timeout.count ()));
if (rc > 0) {
std::for_each (poller_events.begin (), poller_events.begin () + rc,
const int count = base_poller.wait_all (poller_events, timeout);
if (count != 0) {
std::for_each (poller_events.begin (), poller_events.begin () + count,
[](zmq_poller_event_t& event) {
if (event.user_data != NULL)
(*reinterpret_cast<handler_t*> (event.user_data)) (event.events);
});
return rc;
}
#if ZMQ_VERSION >= ZMQ_MAKE_VERSION(4, 2, 3)
if (zmq_errno () == EAGAIN)
#else
if (zmq_errno () == ETIMEDOUT)
#endif
return 0;
throw error_t ();
return count;
}
bool empty () const
@ -1109,20 +1161,9 @@ namespace zmq
}
private:
std::unique_ptr<void, std::function<void(void*)>> poller_ptr
{
[]() {
auto poller_new = zmq_poller_new ();
if (poller_new)
return poller_new;
throw error_t ();
}(),
[](void *ptr) {
int rc = zmq_poller_destroy (&ptr);
ZMQ_ASSERT (rc == 0);
}
};
bool need_rebuild {false};
base_poller_t<handler_t> base_poller {};
std::unordered_map<void*, std::shared_ptr<handler_t>> handlers {};
std::vector<zmq_poller_event_t> poller_events {};
std::vector<std::shared_ptr<handler_t>> poller_handlers {};