Fix/poll set race (#3630)

* fix(PollSet): PollSet data race #3628

* fix(SocketConnector): SocketConnector test #2875

* fix(PollSet): optimize the amount of locked code; fix and simplify wakeUp logic

* fix(SocketConnectorTest): fix test memleak (data not flowing, handlers not deleted) #2875

* fix(PollSet): clear() and tests
This commit is contained in:
Aleksandar Fabijanic 2022-07-05 10:43:19 +02:00 committed by GitHub
parent 93e58b4468
commit 840044f69b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 277 additions and 28 deletions

View File

@ -83,6 +83,7 @@ public:
void wakeUp(); void wakeUp();
/// Wakes up a waiting PollSet. /// Wakes up a waiting PollSet.
/// Any errors that occur during this call are ignored.
/// On platforms/implementations where this functionality /// On platforms/implementations where this functionality
/// is not available, it does nothing. /// is not available, it does nothing.
private: private:

View File

@ -42,6 +42,9 @@ namespace Net {
class PollSetImpl class PollSetImpl
{ {
public: public:
using Mutex = Poco::SpinlockMutex;
using ScopedLock = Mutex::ScopedLock;
PollSetImpl(): _epollfd(epoll_create(1)), PollSetImpl(): _epollfd(epoll_create(1)),
_events(1024), _events(1024),
_eventfd(eventfd(0, 0)) _eventfd(eventfd(0, 0))
@ -55,14 +58,13 @@ public:
~PollSetImpl() ~PollSetImpl()
{ {
::close(_eventfd.exchange(0));
ScopedLock l(_mutex);
if (_epollfd >= 0) ::close(_epollfd); if (_epollfd >= 0) ::close(_epollfd);
if (_eventfd >= 0) ::close(_eventfd);
} }
void add(const Socket& socket, int mode) void add(const Socket& socket, int mode)
{ {
Poco::FastMutex::ScopedLock lock(_mutex);
SocketImpl* sockImpl = socket.impl(); SocketImpl* sockImpl = socket.impl();
int err = addImpl(sockImpl->sockfd(), mode, sockImpl); int err = addImpl(sockImpl->sockfd(), mode, sockImpl);
@ -73,35 +75,35 @@ public:
else SocketImpl::error(); else SocketImpl::error();
} }
ScopedLock lock(_mutex);
if (_socketMap.find(sockImpl) == _socketMap.end()) if (_socketMap.find(sockImpl) == _socketMap.end())
_socketMap[sockImpl] = socket; _socketMap[sockImpl] = socket;
} }
void remove(const Socket& socket) void remove(const Socket& socket)
{ {
Poco::FastMutex::ScopedLock lock(_mutex);
poco_socket_t fd = socket.impl()->sockfd(); poco_socket_t fd = socket.impl()->sockfd();
struct epoll_event ev; struct epoll_event ev;
ev.events = 0; ev.events = 0;
ev.data.ptr = 0; ev.data.ptr = 0;
ScopedLock lock(_mutex);
int err = epoll_ctl(_epollfd, EPOLL_CTL_DEL, fd, &ev); int err = epoll_ctl(_epollfd, EPOLL_CTL_DEL, fd, &ev);
if (err) SocketImpl::error(); if (err) SocketImpl::error();
_socketMap.erase(socket.impl()); _socketMap.erase(socket.impl());
} }
bool has(const Socket& socket) const bool has(const Socket& socket) const
{ {
Poco::FastMutex::ScopedLock lock(_mutex);
SocketImpl* sockImpl = socket.impl(); SocketImpl* sockImpl = socket.impl();
ScopedLock lock(_mutex);
return sockImpl && return sockImpl &&
(_socketMap.find(sockImpl) != _socketMap.end()); (_socketMap.find(sockImpl) != _socketMap.end());
} }
bool empty() const bool empty() const
{ {
Poco::FastMutex::ScopedLock lock(_mutex); ScopedLock lock(_mutex);
return _socketMap.empty(); return _socketMap.empty();
} }
@ -117,24 +119,28 @@ public:
if (mode & PollSet::POLL_ERROR) if (mode & PollSet::POLL_ERROR)
ev.events |= EPOLLERR; ev.events |= EPOLLERR;
ev.data.ptr = socket.impl(); ev.data.ptr = socket.impl();
int err = epoll_ctl(_epollfd, EPOLL_CTL_MOD, fd, &ev);
if (err) int err = 0;
{ {
SocketImpl::error(); ScopedLock lock(_mutex);
err = epoll_ctl(_epollfd, EPOLL_CTL_MOD, fd, &ev);
} }
if (err) SocketImpl::error();
} }
void clear() void clear()
{ {
Poco::FastMutex::ScopedLock lock(_mutex);
::close(_epollfd);
_socketMap.clear();
_epollfd = epoll_create(1);
if (_epollfd < 0)
{ {
SocketImpl::error(); ScopedLock lock(_mutex);
::close(_epollfd);
_socketMap.clear();
_epollfd = epoll_create(1);
if (_epollfd < 0) SocketImpl::error();
} }
::close(_eventfd.exchange(0));
_eventfd = eventfd(0, 0);
addImpl(_eventfd, PollSet::POLL_READ, 0);
} }
PollSet::SocketModeMap poll(const Poco::Timespan& timeout) PollSet::SocketModeMap poll(const Poco::Timespan& timeout)
@ -142,6 +148,7 @@ public:
PollSet::SocketModeMap result; PollSet::SocketModeMap result;
Poco::Timespan remainingTime(timeout); Poco::Timespan remainingTime(timeout);
int rc; int rc;
ScopedLock lock(_mutex);
do do
{ {
Poco::Timestamp start; Poco::Timestamp start;
@ -160,8 +167,6 @@ public:
while (rc < 0 && SocketImpl::lastError() == POCO_EINTR); while (rc < 0 && SocketImpl::lastError() == POCO_EINTR);
if (rc < 0) SocketImpl::error(); if (rc < 0) SocketImpl::error();
Poco::FastMutex::ScopedLock lock(_mutex);
for (int i = 0; i < rc; i++) for (int i = 0; i < rc; i++)
{ {
if (_events[i].data.ptr) // skip eventfd if (_events[i].data.ptr) // skip eventfd
@ -185,13 +190,15 @@ public:
void wakeUp() void wakeUp()
{ {
uint64_t val = 1; uint64_t val = 1;
int n = ::write(_eventfd, &val, sizeof(val)); // This is guaranteed to write into a valid fd,
if (n < 0) Socket::error(); // or 0 (meaning PollSet is being destroyed).
// Errors are ignored.
::write(_eventfd, &val, sizeof(val));
} }
int count() const int count() const
{ {
Poco::FastMutex::ScopedLock lock(_mutex); ScopedLock lock(_mutex);
return static_cast<int>(_socketMap.size()); return static_cast<int>(_socketMap.size());
} }
@ -207,14 +214,15 @@ private:
if (mode & PollSet::POLL_ERROR) if (mode & PollSet::POLL_ERROR)
ev.events |= EPOLLERR; ev.events |= EPOLLERR;
ev.data.ptr = ptr; ev.data.ptr = ptr;
ScopedLock lock(_mutex);
return epoll_ctl(_epollfd, EPOLL_CTL_ADD, fd, &ev); return epoll_ctl(_epollfd, EPOLL_CTL_ADD, fd, &ev);
} }
mutable Poco::FastMutex _mutex; mutable Mutex _mutex;
int _epollfd; int _epollfd;
std::map<void*, Socket> _socketMap; std::map<void*, Socket> _socketMap;
std::vector<struct epoll_event> _events; std::vector<struct epoll_event> _events;
int _eventfd; std::atomic<int> _eventfd;
}; };

View File

@ -19,7 +19,8 @@ objects = \
MediaTypeTest QuotedPrintableTest DialogSocketTest \ MediaTypeTest QuotedPrintableTest DialogSocketTest \
HTTPClientTestSuite FTPClientTestSuite FTPClientSessionTest \ HTTPClientTestSuite FTPClientTestSuite FTPClientSessionTest \
FTPStreamFactoryTest DialogServer \ FTPStreamFactoryTest DialogServer \
SocketReactorTest ReactorTestSuite SocketProactorTest \ SocketReactorTest SocketConnectorTest ReactorTestSuite \
SocketProactorTest \
MailTestSuite MailMessageTest MailStreamTest \ MailTestSuite MailMessageTest MailStreamTest \
SMTPClientSessionTest POP3ClientSessionTest \ SMTPClientSessionTest POP3ClientSessionTest \
RawSocketTest ICMPClientTest ICMPSocketTest ICMPClientTestSuite \ RawSocketTest ICMPClientTest ICMPSocketTest ICMPClientTestSuite \

View File

@ -308,6 +308,35 @@ void PollSetTest::testPollSetWakeUp()
} }
void PollSetTest::testClear()
{
EchoServer echoServer;
StreamSocket ss;
ss.connect(SocketAddress("127.0.0.1", echoServer.port()));
PollSet ps;
ps.add(ss, PollSet::POLL_READ);
PollSet::SocketModeMap sm = ps.poll(0);
assertTrue(sm.empty());
ss.sendBytes("hello", 5);
assertTrue(ps.poll(100000).size() == 1);
char buffer[5];
ss.receiveBytes(buffer, sizeof(buffer));
ps.clear();
ps.add(ss, PollSet::POLL_READ);
sm = ps.poll(0);
assertTrue(sm.empty());
ss.sendBytes(buffer, 5);
assertTrue(ps.poll(100000).size() == 1);
ss.receiveBytes(buffer, sizeof(buffer));
}
void PollSetTest::setUp() void PollSetTest::setUp()
{ {
} }
@ -328,6 +357,7 @@ CppUnit::Test* PollSetTest::suite()
CppUnit_addTest(pSuite, PollSetTest, testPollNoServer); CppUnit_addTest(pSuite, PollSetTest, testPollNoServer);
CppUnit_addTest(pSuite, PollSetTest, testPollClosedServer); CppUnit_addTest(pSuite, PollSetTest, testPollClosedServer);
CppUnit_addTest(pSuite, PollSetTest, testPollSetWakeUp); CppUnit_addTest(pSuite, PollSetTest, testPollSetWakeUp);
CppUnit_addTest(pSuite, PollSetTest, testClear);
return pSuite; return pSuite;
} }

View File

@ -30,6 +30,7 @@ public:
void testPollNoServer(); void testPollNoServer();
void testPollClosedServer(); void testPollClosedServer();
void testPollSetWakeUp(); void testPollSetWakeUp();
void testClear();
void setUp(); void setUp();
void tearDown(); void tearDown();

View File

@ -10,6 +10,7 @@
#include "ReactorTestSuite.h" #include "ReactorTestSuite.h"
#include "SocketReactorTest.h" #include "SocketReactorTest.h"
#include "SocketConnectorTest.h"
CppUnit::Test* ReactorTestSuite::suite() CppUnit::Test* ReactorTestSuite::suite()
@ -17,6 +18,7 @@ CppUnit::Test* ReactorTestSuite::suite()
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("ReactorTestSuite"); CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("ReactorTestSuite");
pSuite->addTest(SocketReactorTest::suite()); pSuite->addTest(SocketReactorTest::suite());
pSuite->addTest(SocketConnectorTest::suite());
return pSuite; return pSuite;
} }

View File

@ -0,0 +1,170 @@
//
// SocketConnectorTest.cpp
//
// Copyright (c) 2005-2019, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#include "SocketConnectorTest.h"
#include "CppUnit/TestCaller.h"
#include "CppUnit/TestSuite.h"
#include "Poco/Net/SocketReactor.h"
#include "Poco/Net/SocketNotification.h"
#include "Poco/Net/SocketConnector.h"
#include "Poco/Net/SocketAcceptor.h"
#include "Poco/Net/StreamSocket.h"
#include "Poco/Net/ServerSocket.h"
#include "Poco/Net/SocketAddress.h"
#include "Poco/Observer.h"
#include <iostream>
using Poco::Net::SocketReactor;
using Poco::Net::SocketConnector;
using Poco::Net::SocketAcceptor;
using Poco::Net::StreamSocket;
using Poco::Net::ServerSocket;
using Poco::Net::SocketAddress;
using Poco::Net::SocketNotification;
using Poco::Net::ReadableNotification;
using Poco::Net::WritableNotification;
using Poco::Net::TimeoutNotification;
using Poco::Net::ShutdownNotification;
using Poco::Observer;
namespace
{
class EchoServiceHandler
{
public:
EchoServiceHandler(StreamSocket& socket, SocketReactor& reactor):
_socket(socket),
_reactor(reactor)
{
_reactor.addEventHandler(_socket, Observer<EchoServiceHandler, ReadableNotification>(*this, &EchoServiceHandler::onReadable));
}
~EchoServiceHandler()
{
_reactor.removeEventHandler(_socket, Observer<EchoServiceHandler, ReadableNotification>(*this, &EchoServiceHandler::onReadable));
}
void onReadable(ReadableNotification* pNf)
{
pNf->release();
char buffer[8];
int n = _socket.receiveBytes(buffer, sizeof(buffer));
if (n > 0) _socket.sendBytes(buffer, n);
else delete this;
}
private:
StreamSocket _socket;
SocketReactor& _reactor;
};
class ClientServiceHandler
{
public:
ClientServiceHandler(StreamSocket& socket, SocketReactor& reactor):
_socket(socket),
_reactor(reactor),
_or(*this, &ClientServiceHandler::onReadable),
_ow(*this, &ClientServiceHandler::onWritable)
{
_reactor.addEventHandler(_socket, _or);
_reactor.addEventHandler(_socket, _ow);
doSomething();
}
~ClientServiceHandler()
{
}
void doSomething()
{
Thread::sleep(100);
}
void onReadable(ReadableNotification* pNf)
{
pNf->release();
char buffer[32];
int n = _socket.receiveBytes(buffer, sizeof(buffer));
if (n <= 0)
{
_reactor.removeEventHandler(_socket, _or);
delete this;
}
}
void onWritable(WritableNotification* pNf)
{
pNf->release();
_reactor.removeEventHandler(_socket, _ow);
std::string data(5, 'x');
_socket.sendBytes(data.data(), (int) data.length());
_socket.shutdownSend();
}
StreamSocket _socket;
SocketReactor& _reactor;
Observer<ClientServiceHandler, ReadableNotification> _or;
Observer<ClientServiceHandler, WritableNotification> _ow;
};
}
SocketConnectorTest::SocketConnectorTest(const std::string& name): CppUnit::TestCase(name)
{
}
SocketConnectorTest::~SocketConnectorTest()
{
}
void SocketConnectorTest::testUnregisterConnector()
{
SocketAddress ssa;
ServerSocket ss(ssa);
SocketReactor reactor1;
SocketReactor reactor2;
SocketAcceptor<EchoServiceHandler> acceptor(ss, reactor1);
Poco::Thread th;
th.start(reactor1);
SocketAddress sa("127.0.0.1", ss.address().port());
SocketConnector<ClientServiceHandler> connector(sa, reactor2);
Poco::Thread th2;
th2.start(reactor2);
Thread::sleep(200);
reactor1.stop();
reactor2.stop();
th.join();
th2.join();
}
void SocketConnectorTest::setUp()
{
}
void SocketConnectorTest::tearDown()
{
}
CppUnit::Test* SocketConnectorTest::suite()
{
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("SocketConnectorTest");
CppUnit_addTest(pSuite, SocketConnectorTest, testUnregisterConnector);
return pSuite;
}

View File

@ -0,0 +1,36 @@
//
// SocketConnectorTest.h
//
// Definition of the SocketConnectorTest class.
//
// Copyright (c) 2005-2019, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#ifndef SocketConnectorTest_INCLUDED
#define SocketConnectorTest_INCLUDED
#include "Poco/Net/Net.h"
#include "CppUnit/TestCase.h"
class SocketConnectorTest: public CppUnit::TestCase
{
public:
SocketConnectorTest(const std::string& name);
~SocketConnectorTest();
void testUnregisterConnector();
void setUp();
void tearDown();
static CppUnit::Test* suite();
};
#endif // SocketConnectorTest_INCLUDED