diff --git a/include/zmq.h b/include/zmq.h index 89e5c587..78d25d00 100644 --- a/include/zmq.h +++ b/include/zmq.h @@ -440,19 +440,23 @@ typedef struct zmq_poller_event_t int fd; #endif void *user_data; + short events; } zmq_poller_event_t; ZMQ_EXPORT void *zmq_poller_new (); ZMQ_EXPORT int zmq_poller_close (void *poller); -ZMQ_EXPORT int zmq_poller_add_socket (void *poller, void *socket, void *user_data); +ZMQ_EXPORT int zmq_poller_add_socket (void *poller, void *socket, void *user_data, short events); +ZMQ_EXPORT int zmq_poller_modify_socket (void *poller, void *socket, short events); ZMQ_EXPORT int zmq_poller_remove_socket (void *poller, void *socket); ZMQ_EXPORT int zmq_poller_wait (void *poller, zmq_poller_event_t *event, long timeout); #if defined _WIN32 -ZMQ_EXPORT int zmq_poller_add_fd (void *poller, SOCKET fd, void *user_data); +ZMQ_EXPORT int zmq_poller_add_fd (void *poller, SOCKET fd, void *user_data, short events); +ZMQ_EXPORT int zmq_poller_modify_fd (void *poller, SOCKET fd, short events); ZMQ_EXPORT int zmq_poller_remove_fd (void *poller, SOCKET fd); #else -ZMQ_EXPORT int zmq_poller_add_fd (void *poller, int fd, void *user_data); +ZMQ_EXPORT int zmq_poller_add_fd (void *poller, int fd, void *user_data, short events); +ZMQ_EXPORT int zmq_poller_modify (void *poller, int fd, short events); ZMQ_EXPORT int zmq_poller_remove_fd (void *poller, int fd); #endif diff --git a/src/socket_poller.cpp b/src/socket_poller.cpp index 18a48632..b198b34b 100644 --- a/src/socket_poller.cpp +++ b/src/socket_poller.cpp @@ -61,7 +61,7 @@ bool zmq::socket_poller_t::check_tag () return tag == 0xCAFEBABE; } -int zmq::socket_poller_t::add_socket (void *socket_, void* user_data_) +int zmq::socket_poller_t::add_socket (void *socket_, void* user_data_, short events_) { for (events_t::iterator it = events.begin (); it != events.end (); ++it) { if (it->socket == socket_) { @@ -81,7 +81,7 @@ int zmq::socket_poller_t::add_socket (void *socket_, void* user_data_) return -1; } - event_t event = {socket_, 0, user_data_}; + event_t event = {socket_, 0, user_data_, events_}; events.push_back (event); need_rebuild = true; @@ -89,9 +89,9 @@ int zmq::socket_poller_t::add_socket (void *socket_, void* user_data_) } #if defined _WIN32 -int zmq::socket_poller_t::add_fd (SOCKET fd_, void *user_data_) +int zmq::socket_poller_t::add_fd (SOCKET fd_, void *user_data_, short events_) #else -int zmq::socket_poller_t::add_fd (int fd_, void *user_data_) +int zmq::socket_poller_t::add_fd (int fd_, void *user_data_, short events_) #endif { for (events_t::iterator it = events.begin (); it != events.end (); ++it) { @@ -101,13 +101,59 @@ int zmq::socket_poller_t::add_fd (int fd_, void *user_data_) } } - event_t event = {NULL, fd_, user_data_}; + event_t event = {NULL, fd_, user_data_, events_}; events.push_back (event); need_rebuild = true; return 0; } +int zmq::socket_poller_t::modify_socket (void *socket_, short events_) +{ + events_t::iterator it; + + for (it = events.begin (); it != events.end (); ++it) { + if (it->socket == socket_) + break; + } + + if (it == events.end()) { + errno = EINVAL; + return -1; + } + + it->events = events_; + need_rebuild = true; + + return 0; +} + + +#if defined _WIN32 +int zmq::socket_poller_t::modify_fd (SOCKET fd_, short events_) +#else +int zmq::socket_poller_t::modify_fd (int fd_, short events_) +#endif +{ + events_t::iterator it; + + for (it = events.begin (); it != events.end (); ++it) { + if (!it->socket && it->fd == fd_) + break; + } + + if (it == events.end()) { + errno = EINVAL; + return -1; + } + + it->events = events_; + need_rebuild = true; + + return 0; +} + + int zmq::socket_poller_t::remove_socket (void* socket_) { events_t::iterator it; @@ -180,8 +226,9 @@ int zmq::socket_poller_t::wait (zmq::socket_poller_t::event_t *event_, long time } for (int i = 0; i < poll_size; i++) { - if (poll_set [i].revents & ZMQ_POLLIN) { - *event_ = poll_events[i]; + if ((poll_set [i].revents & poll_events [i].events) != 0) { + *event_ = poll_events[i]; + event_->events = poll_set [i].revents & poll_events [i].events; break; } @@ -216,7 +263,7 @@ void zmq::socket_poller_t::rebuild () if (!it->socket) poll_set [event_nbr].fd = it->fd; - poll_set [event_nbr].events = ZMQ_POLLIN; + poll_set [event_nbr].events = it->events; poll_events [event_nbr] = *it; } diff --git a/src/socket_poller.hpp b/src/socket_poller.hpp index 5bdc4353..09637bbb 100644 --- a/src/socket_poller.hpp +++ b/src/socket_poller.hpp @@ -53,15 +53,19 @@ namespace zmq int fd; #endif void *user_data; + short events; } event_t; - int add_socket (void *socket, void *user_data); + int add_socket (void *socket, void *user_data, short events); + int modify_socket (void *socket, short events); int remove_socket (void *socket); #if defined _WIN32 - int add_fd (SOCKET fd, void *user_data); + int add_fd (SOCKET fd, void *user_data, short events); + int mofify_fd (SOCKET fd, short events); int remove_fd (SOCKET fd); #else - int add_fd (int fd, void *user_data); + int add_fd (int fd, void *user_data, short events); + int modify_fd (int fd, short events); int remove_fd (int fd); #endif diff --git a/src/zmq.cpp b/src/zmq.cpp index 6664c77a..ad777099 100644 --- a/src/zmq.cpp +++ b/src/zmq.cpp @@ -1579,20 +1579,20 @@ int zmq_poller_close (void *poller_) return 0; } -int zmq_poller_add_socket (void *poller_, void *socket_, void *user_data_) +int zmq_poller_add_socket (void *poller_, void *socket_, void *user_data_, short events_) { if (!poller_ || !((zmq::socket_poller_t*)poller_)->check_tag ()) { errno = EFAULT; return -1; } - return ((zmq::socket_poller_t*)poller_)->add_socket (socket_, user_data_); + return ((zmq::socket_poller_t*)poller_)->add_socket (socket_, user_data_, events_); } #if defined _WIN32 -int zmq_poller_add_fd (void *poller_, SOCKET fd_, void *user_data_) +int zmq_poller_add_fd (void *poller_, SOCKET fd_, void *user_data_, short events_) #else -int zmq_poller_add_fd (void *poller_, int fd_, void *user_data_) +int zmq_poller_add_fd (void *poller_, int fd_, void *user_data_, short events_) #endif { if (!poller_ || !((zmq::socket_poller_t*)poller_)->check_tag ()) { @@ -1600,9 +1600,36 @@ int zmq_poller_add_fd (void *poller_, int fd_, void *user_data_) return -1; } - return ((zmq::socket_poller_t*)poller_)->add_fd (fd_, user_data_); + return ((zmq::socket_poller_t*)poller_)->add_fd (fd_, user_data_, events_); } + +int zmq_poller_modify_socket (void *poller_, void *socket_, short events_) +{ + if (!poller_ || !((zmq::socket_poller_t*)poller_)->check_tag ()) { + errno = EFAULT; + return -1; + } + + return ((zmq::socket_poller_t*)poller_)->modify_socket (socket_, events_); +} + + +#if defined _WIN32 +int zmq_poller_modify_fd (void *poller_, SOCKET fd_, short events_) +#else +int zmq_poller_modify_fd (void *poller_, int fd_, short events_) +#endif +{ + if (!poller_ || !((zmq::socket_poller_t*)poller_)->check_tag ()) { + errno = EFAULT; + return -1; + } + + return ((zmq::socket_poller_t*)poller_)->modify_fd (fd_, events_); +} + + int zmq_poller_remove_socket (void *poller_, void *socket) { if (!poller_ || !((zmq::socket_poller_t*)poller_)->check_tag ()) { @@ -1642,6 +1669,7 @@ int zmq_poller_wait (void *poller_, zmq_poller_event_t *event, long timeout_) event->socket = e.socket; event->fd = e.fd; event->user_data = e.user_data; + event->events = e.events; return rc; } diff --git a/tests/test_poller.cpp b/tests/test_poller.cpp index 9a9f9415..fcfeefac 100644 --- a/tests/test_poller.cpp +++ b/tests/test_poller.cpp @@ -61,7 +61,7 @@ int main (void) // Set up poller void* poller = zmq_poller_new (); - rc = zmq_poller_add_socket (poller, sink, sink); + rc = zmq_poller_add_socket (poller, sink, sink, ZMQ_POLLIN); assert (rc == 0); // Send a message @@ -96,7 +96,7 @@ int main (void) rc = zmq_getsockopt (bowl, ZMQ_FD, &fd, &fd_size); assert (rc == 0); - rc = zmq_poller_add_fd (poller, fd, bowl); + rc = zmq_poller_add_fd (poller, fd, bowl, ZMQ_POLLIN); assert (rc == 0); rc = zmq_poller_wait (poller, &event, 500); assert (rc == 0); @@ -106,7 +106,8 @@ int main (void) zmq_poller_remove_fd (poller, fd); // Polling on thread safe sockets - zmq_poller_add_socket (poller, server, NULL); + rc = zmq_poller_add_socket (poller, server, NULL, ZMQ_POLLIN); + assert (rc == 0); rc = zmq_connect (client, "tcp://127.0.0.1:55557"); assert (rc == 0); rc = zmq_send_const (client, data, 1, 0); @@ -116,8 +117,17 @@ int main (void) assert (event.socket == server); assert (event.user_data == NULL); rc = zmq_recv (server, data, 1, 0); - assert (rc == 1); + assert (rc == 1); + // Polling on pollout + rc = zmq_poller_modify_socket (poller, server, ZMQ_POLLOUT | ZMQ_POLLIN); + assert (rc == 0); + rc = zmq_poller_wait (poller, &event, 0); + assert (rc == 0); + assert (event.socket == server); + assert (event.user_data == NULL); + assert (event.events == ZMQ_POLLOUT); + // Destory poller, sockets and ctx rc = zmq_poller_close (poller); assert (rc == 0);