diff --git a/Net/include/Poco/Net/PollSet.h b/Net/include/Poco/Net/PollSet.h index 02071611c..b412bd800 100644 --- a/Net/include/Poco/Net/PollSet.h +++ b/Net/include/Poco/Net/PollSet.h @@ -39,9 +39,9 @@ class Net_API PollSet public: enum Mode { - POLL_READ = 0x01, - POLL_WRITE = 0x02, - POLL_ERROR = 0x04 + POLL_READ = Socket::SELECT_READ, + POLL_WRITE = Socket::SELECT_WRITE, + POLL_ERROR = Socket::SELECT_ERROR }; using SocketModeMap = std::map; @@ -56,12 +56,36 @@ public: /// Adds the given socket to the set, for polling with /// the given mode, which can be an OR'd combination of /// POLL_READ, POLL_WRITE and POLL_ERROR. + /// Subsequent socket additions to the PollSet are mode-cumulative, + /// so the following code: + /// + /// StreamSocket ss; + /// PollSet ps; + /// ps.add(ss, PollSet::POLL_READ); + /// ps.add(ss, PollSet::POLL_WRITE); + /// + /// shall result in the socket being monitored for read and write, + /// equivalent to this: + /// + /// ps.update(ss, PollSet::POLL_READ | PollSet::POLL_WRITE); void remove(const Poco::Net::Socket& socket); /// Removes the given socket from the set. void update(const Poco::Net::Socket& socket, int mode); - /// Updates the mode of the given socket. + /// Updates the mode of the given socket. If socket does + /// not exist in the PollSet, it is silently added. For + /// an existing socket, any prior mode is overwritten. + /// Updating socket is non-mode-cumulative. + /// + /// The following code: + /// + /// StreamSocket ss; + /// PollSet ps; + /// ps.update(ss, PollSet::POLL_READ); + /// ps.update(ss, PollSet::POLL_WRITE); + /// + /// shall result in the socket being monitored for write only. bool has(const Socket& socket) const; /// Returns true if socket is registered for polling. @@ -79,7 +103,7 @@ public: /// their state changed. int count() const; - /// Returns the numberof sockets monitored. + /// Returns the number of sockets monitored. void wakeUp(); /// Wakes up a waiting PollSet. diff --git a/Net/src/PollSet.cpp b/Net/src/PollSet.cpp index 181fa309a..242e37503 100644 --- a/Net/src/PollSet.cpp +++ b/Net/src/PollSet.cpp @@ -44,12 +44,14 @@ class PollSetImpl public: using Mutex = Poco::SpinlockMutex; using ScopedLock = Mutex::ScopedLock; + using SocketMode = std::pair; + using SocketMap = std::map; PollSetImpl(): _epollfd(epoll_create(1)), _events(1024), _eventfd(eventfd(0, 0)) { - int err = addImpl(_eventfd, PollSet::POLL_READ, 0); + int err = addFD(_eventfd, PollSet::POLL_READ, EPOLL_CTL_ADD); if ((err) || (_epollfd < 0)) { SocketImpl::error(); @@ -59,25 +61,24 @@ public: ~PollSetImpl() { ::close(_eventfd.exchange(0)); - ScopedLock l(_mutex); if (_epollfd >= 0) ::close(_epollfd); } void add(const Socket& socket, int mode) { - SocketImpl* sockImpl = socket.impl(); - - int err = addImpl(sockImpl->sockfd(), mode, sockImpl); - + int newMode = getNewMode(socket.impl(), mode); + int err = addImpl(socket, newMode); if (err) { - if (errno == EEXIST) update(socket, mode); + if (errno == EEXIST) update(socket, newMode); else SocketImpl::error(); } + } - ScopedLock lock(_mutex); - if (_socketMap.find(sockImpl) == _socketMap.end()) - _socketMap[sockImpl] = socket; + void update(const Socket& socket, int mode) + { + int err = updateImpl(socket, mode); + if (err) SocketImpl::error(); } void remove(const Socket& socket) @@ -87,9 +88,9 @@ public: ev.events = 0; ev.data.ptr = 0; - ScopedLock lock(_mutex); int err = epoll_ctl(_epollfd, EPOLL_CTL_DEL, fd, &ev); if (err) SocketImpl::error(); + ScopedLock lock(_mutex); _socketMap.erase(socket.impl()); } @@ -107,40 +108,19 @@ public: return _socketMap.empty(); } - void update(const Socket& socket, int mode) - { - poco_socket_t fd = socket.impl()->sockfd(); - struct epoll_event ev; - ev.events = 0; - if (mode & PollSet::POLL_READ) - ev.events |= EPOLLIN; - if (mode & PollSet::POLL_WRITE) - ev.events |= EPOLLOUT; - if (mode & PollSet::POLL_ERROR) - ev.events |= EPOLLERR; - ev.data.ptr = socket.impl(); - - int err = 0; - { - ScopedLock lock(_mutex); - err = epoll_ctl(_epollfd, EPOLL_CTL_MOD, fd, &ev); - } - if (err) SocketImpl::error(); - } - void clear() { { ScopedLock lock(_mutex); - ::close(_epollfd); + close(_epollfd); _socketMap.clear(); _epollfd = epoll_create(1); if (_epollfd < 0) SocketImpl::error(); } - ::close(_eventfd.exchange(0)); + close(_eventfd.exchange(0)); _eventfd = eventfd(0, 0); - addImpl(_eventfd, PollSet::POLL_READ, 0); + addFD(_eventfd, PollSet::POLL_READ, EPOLL_CTL_ADD); } PollSet::SocketModeMap poll(const Poco::Timespan& timeout) @@ -154,6 +134,11 @@ public: Poco::Timestamp start; rc = epoll_wait(_epollfd, &_events[0], _events.size(), remainingTime.totalMilliseconds()); if (rc == 0) return result; + + // if we are hitting the events limit, resize it; even without resizing, the subseqent + // calls would round-robin through the remaining ready sockets, but it's better to give + // the call enough room once we start hitting the boundary + if (rc >= _events.size()) _events.resize(_events.size()*2); if (rc < 0 && SocketImpl::lastError() == POCO_EINTR) { Poco::Timestamp end; @@ -171,15 +156,15 @@ public: { if (_events[i].data.ptr) // skip eventfd { - std::map::iterator it = _socketMap.find(_events[i].data.ptr); + SocketMap::iterator it = _socketMap.find(_events[i].data.ptr); if (it != _socketMap.end()) { if (_events[i].events & EPOLLIN) - result[it->second] |= PollSet::POLL_READ; + result[it->second.first] |= PollSet::POLL_READ; if (_events[i].events & EPOLLOUT) - result[it->second] |= PollSet::POLL_WRITE; + result[it->second.first] |= PollSet::POLL_WRITE; if (_events[i].events & EPOLLERR) - result[it->second] |= PollSet::POLL_ERROR; + result[it->second.first] |= PollSet::POLL_ERROR; } } } @@ -193,7 +178,7 @@ public: // This is guaranteed to write into a valid fd, // or 0 (meaning PollSet is being destroyed). // Errors are ignored. - ::write(_eventfd, &val, sizeof(val)); + write(_eventfd, &val, sizeof(val)); } int count() const @@ -203,9 +188,42 @@ public: } private: - int addImpl(int fd, int mode, void* ptr) + int getNewMode(SocketImpl* sockImpl, int mode) { - struct epoll_event ev; + ScopedLock lock(_mutex); + auto it = _socketMap.find(sockImpl); + if (it != _socketMap.end()) + mode |= it->second.second; + return mode; + } + + void socketMapUpdate(const Socket& socket, int mode) + { + SocketImpl* sockImpl = socket.impl(); + ScopedLock lock(_mutex); + _socketMap[sockImpl] = {socket, mode}; + } + + int updateImpl(const Socket& socket, int mode) + { + SocketImpl* sockImpl = socket.impl(); + int ret = addFD(sockImpl->sockfd(), mode, EPOLL_CTL_MOD, sockImpl); + if (ret == 0) socketMapUpdate(socket, mode); + return ret; + } + + int addImpl(const Socket& socket, int mode) + { + SocketImpl* sockImpl = socket.impl(); + int newMode = getNewMode(sockImpl, mode); + int ret = addFD(sockImpl->sockfd(), newMode, EPOLL_CTL_ADD, sockImpl); + if (ret == 0) socketMapUpdate(socket, newMode); + return ret; + } + + int addFD(int fd, int mode, int op, void* ptr = 0) + { + struct epoll_event ev{}; ev.events = 0; if (mode & PollSet::POLL_READ) ev.events |= EPOLLIN; @@ -214,13 +232,12 @@ private: if (mode & PollSet::POLL_ERROR) ev.events |= EPOLLERR; ev.data.ptr = ptr; - ScopedLock lock(_mutex); - return epoll_ctl(_epollfd, EPOLL_CTL_ADD, fd, &ev); + return epoll_ctl(_epollfd, op, fd, &ev); } mutable Mutex _mutex; - int _epollfd; - std::map _socketMap; + std::atomic _epollfd; + SocketMap _socketMap; std::vector _events; std::atomic _eventfd; }; diff --git a/Net/src/Socket.cpp b/Net/src/Socket.cpp index 9fc8371e5..dcff65d3b 100644 --- a/Net/src/Socket.cpp +++ b/Net/src/Socket.cpp @@ -14,6 +14,7 @@ #include "Poco/Net/Socket.h" #include "Poco/Net/StreamSocketImpl.h" +#include "Poco/Net/PollSet.h" #include "Poco/Timestamp.h" #include "Poco/Error.h" #include @@ -102,135 +103,22 @@ int Socket::select(SocketList& readList, SocketList& writeList, SocketList& exce int epollSize = readList.size() + writeList.size() + exceptList.size(); if (epollSize == 0) return 0; - int epollfd = -1; + PollSet ps; + for (const auto& s : readList) ps.add(s, PollSet::POLL_READ); + for (const auto& s : writeList) ps.add(s, PollSet::POLL_WRITE); + + readList.clear(); + writeList.clear(); + exceptList.clear(); + + PollSet::SocketModeMap sm = ps.poll(timeout); + for (const auto& s : sm) { - struct epoll_event eventsIn[epollSize]; - memset(eventsIn, 0, sizeof(eventsIn)); - struct epoll_event* eventLast = eventsIn; - for (SocketList::iterator it = readList.begin(); it != readList.end(); ++it) - { - poco_socket_t sockfd = it->sockfd(); - if (sockfd != POCO_INVALID_SOCKET) - { - struct epoll_event* e = eventsIn; - for (; e != eventLast; ++e) - { - if (reinterpret_cast(e->data.ptr)->sockfd() == sockfd) - break; - } - if (e == eventLast) - { - e->data.ptr = &(*it); - ++eventLast; - } - e->events |= EPOLLIN; - } - } - - for (SocketList::iterator it = writeList.begin(); it != writeList.end(); ++it) - { - poco_socket_t sockfd = it->sockfd(); - if (sockfd != POCO_INVALID_SOCKET) - { - struct epoll_event* e = eventsIn; - for (; e != eventLast; ++e) - { - if (reinterpret_cast(e->data.ptr)->sockfd() == sockfd) - break; - } - if (e == eventLast) - { - e->data.ptr = &(*it); - ++eventLast; - } - e->events |= EPOLLOUT; - } - } - - for (SocketList::iterator it = exceptList.begin(); it != exceptList.end(); ++it) - { - poco_socket_t sockfd = it->sockfd(); - if (sockfd != POCO_INVALID_SOCKET) - { - struct epoll_event* e = eventsIn; - for (; e != eventLast; ++e) - { - if (reinterpret_cast(e->data.ptr)->sockfd() == sockfd) - break; - } - if (e == eventLast) - { - e->data.ptr = &(*it); - ++eventLast; - } - e->events |= EPOLLERR; - } - } - - epollSize = eventLast - eventsIn; - if (epollSize == 0) return 0; - - epollfd = epoll_create(1); - if (epollfd < 0) - { - SocketImpl::error("Can't create epoll queue"); - } - - for (struct epoll_event* e = eventsIn; e != eventLast; ++e) - { - poco_socket_t sockfd = reinterpret_cast(e->data.ptr)->sockfd(); - if (sockfd != POCO_INVALID_SOCKET) - { - if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, e) < 0) - { - ::close(epollfd); - SocketImpl::error("Can't insert socket to epoll queue"); - } - } - } + if (s.second & PollSet::POLL_READ) readList.push_back(s.first); + if (s.second & PollSet::POLL_WRITE) writeList.push_back(s.first); + if (s.second & PollSet::POLL_ERROR) exceptList.push_back(s.first); } - struct epoll_event eventsOut[epollSize]; - memset(eventsOut, 0, sizeof(eventsOut)); - - Poco::Timespan remainingTime(timeout); - int rc; - do - { - Poco::Timestamp start; - rc = epoll_wait(epollfd, eventsOut, epollSize, remainingTime.totalMilliseconds()); - if (rc < 0 && SocketImpl::lastError() == POCO_EINTR) - { - Poco::Timestamp end; - Poco::Timespan waited = end - start; - if (waited < remainingTime) - remainingTime -= waited; - else - remainingTime = 0; - } - } - while (rc < 0 && SocketImpl::lastError() == POCO_EINTR); - - ::close(epollfd); - if (rc < 0) SocketImpl::error(); - - SocketList readyReadList; - SocketList readyWriteList; - SocketList readyExceptList; - for (int n = 0; n < rc; ++n) - { - Socket sock = *reinterpret_cast(eventsOut[n].data.ptr); - if ((eventsOut[n].events & EPOLLERR) && - (std::end(exceptList) != std::find(exceptList.begin(), exceptList.end(), sock))) - readyExceptList.push_back(sock); - if (eventsOut[n].events & EPOLLIN) - readyReadList.push_back(sock); - if (eventsOut[n].events & EPOLLOUT) - readyWriteList.push_back(sock); - } - std::swap(readList, readyReadList); - std::swap(writeList, readyWriteList); - std::swap(exceptList, readyExceptList); return readList.size() + writeList.size() + exceptList.size(); #elif defined(POCO_HAVE_FD_POLL) diff --git a/Net/testsuite/src/PollSetTest.cpp b/Net/testsuite/src/PollSetTest.cpp index a2d5144d9..76f724081 100644 --- a/Net/testsuite/src/PollSetTest.cpp +++ b/Net/testsuite/src/PollSetTest.cpp @@ -75,6 +75,95 @@ PollSetTest::~PollSetTest() } +void PollSetTest::testAddUpdate() +{ + EchoServer echoServer1; + EchoServer echoServer2; + StreamSocket ss1; + StreamSocket ss2; + + ss1.connect(SocketAddress("127.0.0.1", echoServer1.port())); + ss2.connect(SocketAddress("127.0.0.1", echoServer2.port())); + + PollSet ps; + assertTrue(ps.empty()); + ps.add(ss1, PollSet::POLL_READ); + assertTrue(!ps.empty()); + assertTrue(ps.has(ss1)); + assertTrue(!ps.has(ss2)); + + // nothing readable + Stopwatch sw; + sw.start(); + Timespan timeout(1000000); + assertTrue (ps.poll(timeout).empty()); + assertTrue (sw.elapsed() >= 900000); + sw.restart(); + + ps.add(ss2, PollSet::POLL_READ); + assertTrue(!ps.empty()); + assertTrue(ps.has(ss1)); + assertTrue(ps.has(ss2)); + + // ss1 must be writable, if polled for + ps.add(ss1, PollSet::POLL_WRITE); + PollSet::SocketModeMap sm = ps.poll(timeout); + assertTrue (sm.find(ss1) != sm.end()); + assertTrue (sm.find(ss2) == sm.end()); + assertTrue (sm.find(ss1)->second & PollSet::POLL_WRITE); + assertTrue (sw.elapsed() < 1100000); + + ps.add(ss1, PollSet::POLL_READ); + + ss1.setBlocking(true); + ss1.sendBytes("hello", 5); + while (!ss1.poll(Timespan(0, 10000), Socket::SELECT_READ)) + Poco::Thread::sleep(10); + sw.restart(); + sm = ps.poll(timeout); + assertTrue (sm.find(ss1) != sm.end()); + assertTrue (sm.find(ss2) == sm.end()); + assertTrue (sm.find(ss1)->second & PollSet::POLL_READ); + assertTrue (sw.elapsed() < 1100000); + + char buffer[256]; + int n = ss1.receiveBytes(buffer, sizeof(buffer)); + assertTrue (n == 5); + assertTrue (std::string(buffer, n) == "hello"); + + ss2.setBlocking(true); + ss2.sendBytes("HELLO", 5); + sw.restart(); + ps.remove(ss1); + sm = ps.poll(timeout); + assertTrue (sm.find(ss1) == sm.end()); + assertTrue (sm.find(ss2) != sm.end()); + assertTrue (sm.find(ss2)->second & PollSet::POLL_READ); + assertTrue (sw.elapsed() < 1100000); + + n = ss2.receiveBytes(buffer, sizeof(buffer)); + assertTrue (n == 5); + assertTrue (std::string(buffer, n) == "HELLO"); + + ps.remove(ss2); + assertTrue(ps.empty()); + assertTrue(!ps.has(ss1)); + assertTrue(!ps.has(ss2)); + + ss2.sendBytes("HELLO", 5); + sw.restart(); + sm = ps.poll(timeout); + assertTrue (sm.empty()); + + n = ss2.receiveBytes(buffer, sizeof(buffer)); + assertTrue (n == 5); + assertTrue (std::string(buffer, n) == "HELLO"); + + ss1.close(); + ss2.close(); +} + + void PollSetTest::testTimeout() { EchoServer echoServer; @@ -122,15 +211,16 @@ void PollSetTest::testPollNB() PollSet::SocketModeMap sm; while (sm.empty()) sm = ps.poll(timeout); assertTrue(sm.find(ss1) != sm.end()); - assertTrue(sm.find(ss1)->second | PollSet::POLL_WRITE); + assertTrue(sm.find(ss1)->second & PollSet::POLL_WRITE); ss1.setBlocking(true); ss1.sendBytes("hello", 5); + while (!ss1.poll(Timespan(0, 10000), Socket::SELECT_READ)) + Poco::Thread::sleep(10); char buffer[256]; - sm = ps.poll(timeout); assertTrue(sm.find(ss1) != sm.end()); - assertTrue(sm.find(ss1)->second | PollSet::POLL_READ); + assertTrue(sm.find(ss1)->second & PollSet::POLL_READ); int n = ss1.receiveBytes(buffer, sizeof(buffer)); assertTrue(n == 5); @@ -180,12 +270,14 @@ void PollSetTest::testPoll() ss1.setBlocking(true); ss1.sendBytes("hello", 5); + while (!ss1.poll(Timespan(0, 10000), Socket::SELECT_READ)) + Poco::Thread::sleep(10); char buffer[256]; sw.restart(); sm = ps.poll(timeout); assertTrue (sm.find(ss1) != sm.end()); assertTrue (sm.find(ss2) == sm.end()); - assertTrue (sm.find(ss1)->second | PollSet::POLL_READ); + assertTrue (sm.find(ss1)->second & PollSet::POLL_READ); assertTrue (sw.elapsed() < 1100000); int n = ss1.receiveBytes(buffer, sizeof(buffer)); @@ -194,11 +286,13 @@ void PollSetTest::testPoll() ss2.setBlocking(true); ss2.sendBytes("HELLO", 5); + while (!ss2.poll(Timespan(0, 10000), Socket::SELECT_READ)) + Poco::Thread::sleep(10); sw.restart(); sm = ps.poll(timeout); assertTrue (sm.find(ss1) == sm.end()); assertTrue (sm.find(ss2) != sm.end()); - assertTrue (sm.find(ss2)->second | PollSet::POLL_READ); + assertTrue (sm.find(ss2)->second & PollSet::POLL_READ); assertTrue (sw.elapsed() < 1100000); n = ss2.receiveBytes(buffer, sizeof(buffer)); @@ -206,6 +300,7 @@ void PollSetTest::testPoll() assertTrue (std::string(buffer, n) == "HELLO"); ps.remove(ss2); + ps.update(ss1, Socket::SELECT_READ); assertTrue(!ps.empty()); assertTrue(ps.has(ss1)); assertTrue(!ps.has(ss2)); @@ -246,7 +341,7 @@ void PollSetTest::testPollNoServer() } while (sm.size() < 2); assertTrue(sm.size() == 2); for (auto s : sm) - assertTrue(0 != (s.second | PollSet::POLL_ERROR)); + assertTrue(0 != (s.second & PollSet::POLL_ERROR)); } @@ -351,6 +446,7 @@ CppUnit::Test* PollSetTest::suite() { CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("PollSetTest"); + CppUnit_addTest(pSuite, PollSetTest, testAddUpdate); CppUnit_addTest(pSuite, PollSetTest, testTimeout); CppUnit_addTest(pSuite, PollSetTest, testPollNB); CppUnit_addTest(pSuite, PollSetTest, testPoll); diff --git a/Net/testsuite/src/PollSetTest.h b/Net/testsuite/src/PollSetTest.h index b93399ea1..fa074618a 100644 --- a/Net/testsuite/src/PollSetTest.h +++ b/Net/testsuite/src/PollSetTest.h @@ -24,6 +24,7 @@ public: PollSetTest(const std::string& name); ~PollSetTest(); + void testAddUpdate(); void testTimeout(); void testPollNB(); void testPoll(); diff --git a/Net/testsuite/src/SocketTest.cpp b/Net/testsuite/src/SocketTest.cpp index 2d668095d..0da876cc5 100644 --- a/Net/testsuite/src/SocketTest.cpp +++ b/Net/testsuite/src/SocketTest.cpp @@ -514,8 +514,8 @@ void SocketTest::testSelect2() assertTrue (Socket::select(readList, writeList, exceptList, timeout) == 2); assertTrue (readList.empty()); assertTrue (writeList.size() == 2); - assertTrue (writeList[0] == ss1); - assertTrue (writeList[1] == ss2); + assertTrue (writeList[0] == ss1 || writeList[1] == ss1); + assertTrue (writeList[0] == ss2 || writeList[1] == ss2); assertTrue (exceptList.empty()); ss1.close();