This commit is contained in:
Alex Fabijanic 2022-07-05 20:49:46 +02:00
parent 840044f69b
commit 81696487a0
6 changed files with 211 additions and 185 deletions

View File

@ -39,9 +39,9 @@ class Net_API PollSet
public:
enum Mode
{
POLL_READ = 0x01,
POLL_WRITE = 0x02,
POLL_ERROR = 0x04
POLL_READ = Socket::SELECT_READ,
POLL_WRITE = Socket::SELECT_WRITE,
POLL_ERROR = Socket::SELECT_ERROR
};
using SocketModeMap = std::map<Poco::Net::Socket, int>;
@ -56,12 +56,36 @@ public:
/// Adds the given socket to the set, for polling with
/// the given mode, which can be an OR'd combination of
/// POLL_READ, POLL_WRITE and POLL_ERROR.
/// Subsequent socket additions to the PollSet are mode-cumulative,
/// so the following code:
///
/// StreamSocket ss;
/// PollSet ps;
/// ps.add(ss, PollSet::POLL_READ);
/// ps.add(ss, PollSet::POLL_WRITE);
///
/// shall result in the socket being monitored for read and write,
/// equivalent to this:
///
/// ps.update(ss, PollSet::POLL_READ | PollSet::POLL_WRITE);
void remove(const Poco::Net::Socket& socket);
/// Removes the given socket from the set.
void update(const Poco::Net::Socket& socket, int mode);
/// Updates the mode of the given socket.
/// Updates the mode of the given socket. If socket does
/// not exist in the PollSet, it is silently added. For
/// an existing socket, any prior mode is overwritten.
/// Updating socket is non-mode-cumulative.
///
/// The following code:
///
/// StreamSocket ss;
/// PollSet ps;
/// ps.update(ss, PollSet::POLL_READ);
/// ps.update(ss, PollSet::POLL_WRITE);
///
/// shall result in the socket being monitored for write only.
bool has(const Socket& socket) const;
/// Returns true if socket is registered for polling.
@ -79,7 +103,7 @@ public:
/// their state changed.
int count() const;
/// Returns the numberof sockets monitored.
/// Returns the number of sockets monitored.
void wakeUp();
/// Wakes up a waiting PollSet.

View File

@ -44,12 +44,14 @@ class PollSetImpl
public:
using Mutex = Poco::SpinlockMutex;
using ScopedLock = Mutex::ScopedLock;
using SocketMode = std::pair<Socket, int>;
using SocketMap = std::map<void*, SocketMode>;
PollSetImpl(): _epollfd(epoll_create(1)),
_events(1024),
_eventfd(eventfd(0, 0))
{
int err = addImpl(_eventfd, PollSet::POLL_READ, 0);
int err = addFD(_eventfd, PollSet::POLL_READ, EPOLL_CTL_ADD);
if ((err) || (_epollfd < 0))
{
SocketImpl::error();
@ -59,25 +61,24 @@ public:
~PollSetImpl()
{
::close(_eventfd.exchange(0));
ScopedLock l(_mutex);
if (_epollfd >= 0) ::close(_epollfd);
}
void add(const Socket& socket, int mode)
{
SocketImpl* sockImpl = socket.impl();
int err = addImpl(sockImpl->sockfd(), mode, sockImpl);
int newMode = getNewMode(socket.impl(), mode);
int err = addImpl(socket, newMode);
if (err)
{
if (errno == EEXIST) update(socket, mode);
if (errno == EEXIST) update(socket, newMode);
else SocketImpl::error();
}
}
ScopedLock lock(_mutex);
if (_socketMap.find(sockImpl) == _socketMap.end())
_socketMap[sockImpl] = socket;
void update(const Socket& socket, int mode)
{
int err = updateImpl(socket, mode);
if (err) SocketImpl::error();
}
void remove(const Socket& socket)
@ -87,9 +88,9 @@ public:
ev.events = 0;
ev.data.ptr = 0;
ScopedLock lock(_mutex);
int err = epoll_ctl(_epollfd, EPOLL_CTL_DEL, fd, &ev);
if (err) SocketImpl::error();
ScopedLock lock(_mutex);
_socketMap.erase(socket.impl());
}
@ -107,40 +108,19 @@ public:
return _socketMap.empty();
}
void update(const Socket& socket, int mode)
{
poco_socket_t fd = socket.impl()->sockfd();
struct epoll_event ev;
ev.events = 0;
if (mode & PollSet::POLL_READ)
ev.events |= EPOLLIN;
if (mode & PollSet::POLL_WRITE)
ev.events |= EPOLLOUT;
if (mode & PollSet::POLL_ERROR)
ev.events |= EPOLLERR;
ev.data.ptr = socket.impl();
int err = 0;
{
ScopedLock lock(_mutex);
err = epoll_ctl(_epollfd, EPOLL_CTL_MOD, fd, &ev);
}
if (err) SocketImpl::error();
}
void clear()
{
{
ScopedLock lock(_mutex);
::close(_epollfd);
close(_epollfd);
_socketMap.clear();
_epollfd = epoll_create(1);
if (_epollfd < 0) SocketImpl::error();
}
::close(_eventfd.exchange(0));
close(_eventfd.exchange(0));
_eventfd = eventfd(0, 0);
addImpl(_eventfd, PollSet::POLL_READ, 0);
addFD(_eventfd, PollSet::POLL_READ, EPOLL_CTL_ADD);
}
PollSet::SocketModeMap poll(const Poco::Timespan& timeout)
@ -154,6 +134,11 @@ public:
Poco::Timestamp start;
rc = epoll_wait(_epollfd, &_events[0], _events.size(), remainingTime.totalMilliseconds());
if (rc == 0) return result;
// if we are hitting the events limit, resize it; even without resizing, the subseqent
// calls would round-robin through the remaining ready sockets, but it's better to give
// the call enough room once we start hitting the boundary
if (rc >= _events.size()) _events.resize(_events.size()*2);
if (rc < 0 && SocketImpl::lastError() == POCO_EINTR)
{
Poco::Timestamp end;
@ -171,15 +156,15 @@ public:
{
if (_events[i].data.ptr) // skip eventfd
{
std::map<void *, Socket>::iterator it = _socketMap.find(_events[i].data.ptr);
SocketMap::iterator it = _socketMap.find(_events[i].data.ptr);
if (it != _socketMap.end())
{
if (_events[i].events & EPOLLIN)
result[it->second] |= PollSet::POLL_READ;
result[it->second.first] |= PollSet::POLL_READ;
if (_events[i].events & EPOLLOUT)
result[it->second] |= PollSet::POLL_WRITE;
result[it->second.first] |= PollSet::POLL_WRITE;
if (_events[i].events & EPOLLERR)
result[it->second] |= PollSet::POLL_ERROR;
result[it->second.first] |= PollSet::POLL_ERROR;
}
}
}
@ -193,7 +178,7 @@ public:
// This is guaranteed to write into a valid fd,
// or 0 (meaning PollSet is being destroyed).
// Errors are ignored.
::write(_eventfd, &val, sizeof(val));
write(_eventfd, &val, sizeof(val));
}
int count() const
@ -203,9 +188,42 @@ public:
}
private:
int addImpl(int fd, int mode, void* ptr)
int getNewMode(SocketImpl* sockImpl, int mode)
{
struct epoll_event ev;
ScopedLock lock(_mutex);
auto it = _socketMap.find(sockImpl);
if (it != _socketMap.end())
mode |= it->second.second;
return mode;
}
void socketMapUpdate(const Socket& socket, int mode)
{
SocketImpl* sockImpl = socket.impl();
ScopedLock lock(_mutex);
_socketMap[sockImpl] = {socket, mode};
}
int updateImpl(const Socket& socket, int mode)
{
SocketImpl* sockImpl = socket.impl();
int ret = addFD(sockImpl->sockfd(), mode, EPOLL_CTL_MOD, sockImpl);
if (ret == 0) socketMapUpdate(socket, mode);
return ret;
}
int addImpl(const Socket& socket, int mode)
{
SocketImpl* sockImpl = socket.impl();
int newMode = getNewMode(sockImpl, mode);
int ret = addFD(sockImpl->sockfd(), newMode, EPOLL_CTL_ADD, sockImpl);
if (ret == 0) socketMapUpdate(socket, newMode);
return ret;
}
int addFD(int fd, int mode, int op, void* ptr = 0)
{
struct epoll_event ev{};
ev.events = 0;
if (mode & PollSet::POLL_READ)
ev.events |= EPOLLIN;
@ -214,13 +232,12 @@ private:
if (mode & PollSet::POLL_ERROR)
ev.events |= EPOLLERR;
ev.data.ptr = ptr;
ScopedLock lock(_mutex);
return epoll_ctl(_epollfd, EPOLL_CTL_ADD, fd, &ev);
return epoll_ctl(_epollfd, op, fd, &ev);
}
mutable Mutex _mutex;
int _epollfd;
std::map<void*, Socket> _socketMap;
std::atomic<int> _epollfd;
SocketMap _socketMap;
std::vector<struct epoll_event> _events;
std::atomic<int> _eventfd;
};

View File

@ -14,6 +14,7 @@
#include "Poco/Net/Socket.h"
#include "Poco/Net/StreamSocketImpl.h"
#include "Poco/Net/PollSet.h"
#include "Poco/Timestamp.h"
#include "Poco/Error.h"
#include <algorithm>
@ -102,135 +103,22 @@ int Socket::select(SocketList& readList, SocketList& writeList, SocketList& exce
int epollSize = readList.size() + writeList.size() + exceptList.size();
if (epollSize == 0) return 0;
int epollfd = -1;
PollSet ps;
for (const auto& s : readList) ps.add(s, PollSet::POLL_READ);
for (const auto& s : writeList) ps.add(s, PollSet::POLL_WRITE);
readList.clear();
writeList.clear();
exceptList.clear();
PollSet::SocketModeMap sm = ps.poll(timeout);
for (const auto& s : sm)
{
struct epoll_event eventsIn[epollSize];
memset(eventsIn, 0, sizeof(eventsIn));
struct epoll_event* eventLast = eventsIn;
for (SocketList::iterator it = readList.begin(); it != readList.end(); ++it)
{
poco_socket_t sockfd = it->sockfd();
if (sockfd != POCO_INVALID_SOCKET)
{
struct epoll_event* e = eventsIn;
for (; e != eventLast; ++e)
{
if (reinterpret_cast<Socket*>(e->data.ptr)->sockfd() == sockfd)
break;
}
if (e == eventLast)
{
e->data.ptr = &(*it);
++eventLast;
}
e->events |= EPOLLIN;
}
}
for (SocketList::iterator it = writeList.begin(); it != writeList.end(); ++it)
{
poco_socket_t sockfd = it->sockfd();
if (sockfd != POCO_INVALID_SOCKET)
{
struct epoll_event* e = eventsIn;
for (; e != eventLast; ++e)
{
if (reinterpret_cast<Socket*>(e->data.ptr)->sockfd() == sockfd)
break;
}
if (e == eventLast)
{
e->data.ptr = &(*it);
++eventLast;
}
e->events |= EPOLLOUT;
}
}
for (SocketList::iterator it = exceptList.begin(); it != exceptList.end(); ++it)
{
poco_socket_t sockfd = it->sockfd();
if (sockfd != POCO_INVALID_SOCKET)
{
struct epoll_event* e = eventsIn;
for (; e != eventLast; ++e)
{
if (reinterpret_cast<Socket*>(e->data.ptr)->sockfd() == sockfd)
break;
}
if (e == eventLast)
{
e->data.ptr = &(*it);
++eventLast;
}
e->events |= EPOLLERR;
}
}
epollSize = eventLast - eventsIn;
if (epollSize == 0) return 0;
epollfd = epoll_create(1);
if (epollfd < 0)
{
SocketImpl::error("Can't create epoll queue");
}
for (struct epoll_event* e = eventsIn; e != eventLast; ++e)
{
poco_socket_t sockfd = reinterpret_cast<Socket*>(e->data.ptr)->sockfd();
if (sockfd != POCO_INVALID_SOCKET)
{
if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sockfd, e) < 0)
{
::close(epollfd);
SocketImpl::error("Can't insert socket to epoll queue");
}
}
}
if (s.second & PollSet::POLL_READ) readList.push_back(s.first);
if (s.second & PollSet::POLL_WRITE) writeList.push_back(s.first);
if (s.second & PollSet::POLL_ERROR) exceptList.push_back(s.first);
}
struct epoll_event eventsOut[epollSize];
memset(eventsOut, 0, sizeof(eventsOut));
Poco::Timespan remainingTime(timeout);
int rc;
do
{
Poco::Timestamp start;
rc = epoll_wait(epollfd, eventsOut, epollSize, remainingTime.totalMilliseconds());
if (rc < 0 && SocketImpl::lastError() == POCO_EINTR)
{
Poco::Timestamp end;
Poco::Timespan waited = end - start;
if (waited < remainingTime)
remainingTime -= waited;
else
remainingTime = 0;
}
}
while (rc < 0 && SocketImpl::lastError() == POCO_EINTR);
::close(epollfd);
if (rc < 0) SocketImpl::error();
SocketList readyReadList;
SocketList readyWriteList;
SocketList readyExceptList;
for (int n = 0; n < rc; ++n)
{
Socket sock = *reinterpret_cast<Socket*>(eventsOut[n].data.ptr);
if ((eventsOut[n].events & EPOLLERR) &&
(std::end(exceptList) != std::find(exceptList.begin(), exceptList.end(), sock)))
readyExceptList.push_back(sock);
if (eventsOut[n].events & EPOLLIN)
readyReadList.push_back(sock);
if (eventsOut[n].events & EPOLLOUT)
readyWriteList.push_back(sock);
}
std::swap(readList, readyReadList);
std::swap(writeList, readyWriteList);
std::swap(exceptList, readyExceptList);
return readList.size() + writeList.size() + exceptList.size();
#elif defined(POCO_HAVE_FD_POLL)

View File

@ -75,6 +75,95 @@ PollSetTest::~PollSetTest()
}
void PollSetTest::testAddUpdate()
{
EchoServer echoServer1;
EchoServer echoServer2;
StreamSocket ss1;
StreamSocket ss2;
ss1.connect(SocketAddress("127.0.0.1", echoServer1.port()));
ss2.connect(SocketAddress("127.0.0.1", echoServer2.port()));
PollSet ps;
assertTrue(ps.empty());
ps.add(ss1, PollSet::POLL_READ);
assertTrue(!ps.empty());
assertTrue(ps.has(ss1));
assertTrue(!ps.has(ss2));
// nothing readable
Stopwatch sw;
sw.start();
Timespan timeout(1000000);
assertTrue (ps.poll(timeout).empty());
assertTrue (sw.elapsed() >= 900000);
sw.restart();
ps.add(ss2, PollSet::POLL_READ);
assertTrue(!ps.empty());
assertTrue(ps.has(ss1));
assertTrue(ps.has(ss2));
// ss1 must be writable, if polled for
ps.add(ss1, PollSet::POLL_WRITE);
PollSet::SocketModeMap sm = ps.poll(timeout);
assertTrue (sm.find(ss1) != sm.end());
assertTrue (sm.find(ss2) == sm.end());
assertTrue (sm.find(ss1)->second & PollSet::POLL_WRITE);
assertTrue (sw.elapsed() < 1100000);
ps.add(ss1, PollSet::POLL_READ);
ss1.setBlocking(true);
ss1.sendBytes("hello", 5);
while (!ss1.poll(Timespan(0, 10000), Socket::SELECT_READ))
Poco::Thread::sleep(10);
sw.restart();
sm = ps.poll(timeout);
assertTrue (sm.find(ss1) != sm.end());
assertTrue (sm.find(ss2) == sm.end());
assertTrue (sm.find(ss1)->second & PollSet::POLL_READ);
assertTrue (sw.elapsed() < 1100000);
char buffer[256];
int n = ss1.receiveBytes(buffer, sizeof(buffer));
assertTrue (n == 5);
assertTrue (std::string(buffer, n) == "hello");
ss2.setBlocking(true);
ss2.sendBytes("HELLO", 5);
sw.restart();
ps.remove(ss1);
sm = ps.poll(timeout);
assertTrue (sm.find(ss1) == sm.end());
assertTrue (sm.find(ss2) != sm.end());
assertTrue (sm.find(ss2)->second & PollSet::POLL_READ);
assertTrue (sw.elapsed() < 1100000);
n = ss2.receiveBytes(buffer, sizeof(buffer));
assertTrue (n == 5);
assertTrue (std::string(buffer, n) == "HELLO");
ps.remove(ss2);
assertTrue(ps.empty());
assertTrue(!ps.has(ss1));
assertTrue(!ps.has(ss2));
ss2.sendBytes("HELLO", 5);
sw.restart();
sm = ps.poll(timeout);
assertTrue (sm.empty());
n = ss2.receiveBytes(buffer, sizeof(buffer));
assertTrue (n == 5);
assertTrue (std::string(buffer, n) == "HELLO");
ss1.close();
ss2.close();
}
void PollSetTest::testTimeout()
{
EchoServer echoServer;
@ -122,15 +211,16 @@ void PollSetTest::testPollNB()
PollSet::SocketModeMap sm;
while (sm.empty()) sm = ps.poll(timeout);
assertTrue(sm.find(ss1) != sm.end());
assertTrue(sm.find(ss1)->second | PollSet::POLL_WRITE);
assertTrue(sm.find(ss1)->second & PollSet::POLL_WRITE);
ss1.setBlocking(true);
ss1.sendBytes("hello", 5);
while (!ss1.poll(Timespan(0, 10000), Socket::SELECT_READ))
Poco::Thread::sleep(10);
char buffer[256];
sm = ps.poll(timeout);
assertTrue(sm.find(ss1) != sm.end());
assertTrue(sm.find(ss1)->second | PollSet::POLL_READ);
assertTrue(sm.find(ss1)->second & PollSet::POLL_READ);
int n = ss1.receiveBytes(buffer, sizeof(buffer));
assertTrue(n == 5);
@ -180,12 +270,14 @@ void PollSetTest::testPoll()
ss1.setBlocking(true);
ss1.sendBytes("hello", 5);
while (!ss1.poll(Timespan(0, 10000), Socket::SELECT_READ))
Poco::Thread::sleep(10);
char buffer[256];
sw.restart();
sm = ps.poll(timeout);
assertTrue (sm.find(ss1) != sm.end());
assertTrue (sm.find(ss2) == sm.end());
assertTrue (sm.find(ss1)->second | PollSet::POLL_READ);
assertTrue (sm.find(ss1)->second & PollSet::POLL_READ);
assertTrue (sw.elapsed() < 1100000);
int n = ss1.receiveBytes(buffer, sizeof(buffer));
@ -194,11 +286,13 @@ void PollSetTest::testPoll()
ss2.setBlocking(true);
ss2.sendBytes("HELLO", 5);
while (!ss2.poll(Timespan(0, 10000), Socket::SELECT_READ))
Poco::Thread::sleep(10);
sw.restart();
sm = ps.poll(timeout);
assertTrue (sm.find(ss1) == sm.end());
assertTrue (sm.find(ss2) != sm.end());
assertTrue (sm.find(ss2)->second | PollSet::POLL_READ);
assertTrue (sm.find(ss2)->second & PollSet::POLL_READ);
assertTrue (sw.elapsed() < 1100000);
n = ss2.receiveBytes(buffer, sizeof(buffer));
@ -206,6 +300,7 @@ void PollSetTest::testPoll()
assertTrue (std::string(buffer, n) == "HELLO");
ps.remove(ss2);
ps.update(ss1, Socket::SELECT_READ);
assertTrue(!ps.empty());
assertTrue(ps.has(ss1));
assertTrue(!ps.has(ss2));
@ -246,7 +341,7 @@ void PollSetTest::testPollNoServer()
} while (sm.size() < 2);
assertTrue(sm.size() == 2);
for (auto s : sm)
assertTrue(0 != (s.second | PollSet::POLL_ERROR));
assertTrue(0 != (s.second & PollSet::POLL_ERROR));
}
@ -351,6 +446,7 @@ CppUnit::Test* PollSetTest::suite()
{
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("PollSetTest");
CppUnit_addTest(pSuite, PollSetTest, testAddUpdate);
CppUnit_addTest(pSuite, PollSetTest, testTimeout);
CppUnit_addTest(pSuite, PollSetTest, testPollNB);
CppUnit_addTest(pSuite, PollSetTest, testPoll);

View File

@ -24,6 +24,7 @@ public:
PollSetTest(const std::string& name);
~PollSetTest();
void testAddUpdate();
void testTimeout();
void testPollNB();
void testPoll();

View File

@ -514,8 +514,8 @@ void SocketTest::testSelect2()
assertTrue (Socket::select(readList, writeList, exceptList, timeout) == 2);
assertTrue (readList.empty());
assertTrue (writeList.size() == 2);
assertTrue (writeList[0] == ss1);
assertTrue (writeList[1] == ss2);
assertTrue (writeList[0] == ss1 || writeList[1] == ss1);
assertTrue (writeList[0] == ss2 || writeList[1] == ss2);
assertTrue (exceptList.empty());
ss1.close();