From fe59f959ef920fb09c29701b4fa65044a16dd14c Mon Sep 17 00:00:00 2001 From: Guenter Obiltschnig Date: Fri, 11 Nov 2016 14:13:04 +0100 Subject: [PATCH] GH #1485: add TCPConnectionFilter --- Net/include/Poco/Net/TCPServer.h | 53 ++++++++++++++++++++++++++--- Net/src/TCPServer.cpp | 40 ++++++++++++++++++---- Net/testsuite/src/TCPServerTest.cpp | 38 ++++++++++++++++++++- Net/testsuite/src/TCPServerTest.h | 1 + 4 files changed, 121 insertions(+), 11 deletions(-) diff --git a/Net/include/Poco/Net/TCPServer.h b/Net/include/Poco/Net/TCPServer.h index 45ce8edd2..9c2ab525b 100644 --- a/Net/include/Poco/Net/TCPServer.h +++ b/Net/include/Poco/Net/TCPServer.h @@ -24,6 +24,8 @@ #include "Poco/Net/ServerSocket.h" #include "Poco/Net/TCPServerConnectionFactory.h" #include "Poco/Net/TCPServerParams.h" +#include "Poco/RefCountedObject.h" +#include "Poco/AutoPtr.h" #include "Poco/Runnable.h" #include "Poco/Thread.h" #include "Poco/ThreadPool.h" @@ -36,6 +38,26 @@ namespace Net { class TCPServerDispatcher; +class Net_API TCPConnectionFilter: public Poco::RefCountedObject + /// A TCPConnectionFilter can be used to reject incoming connections + /// before a thread is started to handle them. + /// + /// An example use case is white-list or black-list IP address filtering. + /// + /// Subclasses must override the accept() method. +{ +public: + typedef Poco::AutoPtr Ptr; + + virtual bool accept(const Poco::Net::SocketAddress& clientAddress) = 0; + /// Returns true if a connection from the given clientAddress should + /// be accepted, otherwise false. + +protected: + virtual ~TCPConnectionFilter(); +}; + + class Net_API TCPServer: public Poco::Runnable /// This class implements a multithreaded TCP server. /// @@ -165,6 +187,19 @@ public: Poco::UInt16 port() const; /// Returns the port the server socket listens on. + + void setConnectionFilter(const TCPConnectionFilter::Ptr& pFilter); + /// Sets a TCPConnectionFilter. Can also be used to remove + /// a filter by passing a null pointer. + /// + /// To avoid a potential race condition, the filter must + /// be set before the TCPServer is started. Trying to set + /// the filter after start() has been called will trigger + /// an assertion. + + TCPConnectionFilter::Ptr getConnectionFilter() const; + /// Returns the TCPConnectionFilter set with setConnectionFilter(), + /// or null pointer if no filter has been set. protected: void run(); @@ -181,13 +216,17 @@ private: TCPServer(const TCPServer&); TCPServer& operator = (const TCPServer&); - ServerSocket _socket; - TCPServerDispatcher* _pDispatcher; - Poco::Thread _thread; - bool _stopped; + ServerSocket _socket; + TCPServerDispatcher* _pDispatcher; + TCPConnectionFilter::Ptr _pConnectionFilter; + Poco::Thread _thread; + bool _stopped; }; +// +// inlines +// inline const ServerSocket& TCPServer::socket() const { return _socket; @@ -200,6 +239,12 @@ inline Poco::UInt16 TCPServer::port() const } +inline TCPConnectionFilter::Ptr TCPServer::getConnectionFilter() const +{ + return _pConnectionFilter; +} + + } } // namespace Poco::Net diff --git a/Net/src/TCPServer.cpp b/Net/src/TCPServer.cpp index 7b06d694c..34236c71d 100644 --- a/Net/src/TCPServer.cpp +++ b/Net/src/TCPServer.cpp @@ -30,6 +30,21 @@ namespace Poco { namespace Net { +// +// TCPConnectionFilter +// + + +TCPConnectionFilter::~TCPConnectionFilter() +{ +} + + +// +// TCPServer +// + + TCPServer::TCPServer(TCPServerConnectionFactory::Ptr pFactory, Poco::UInt16 portNumber, TCPServerParams::Ptr pParams): _socket(ServerSocket(portNumber)), _thread(threadName(_socket)), @@ -122,14 +137,18 @@ void TCPServer::run() try { StreamSocket ss = _socket.acceptConnection(); - // enable nodelay per default: OSX really needs that -#if defined(POCO_OS_FAMILY_UNIX) - if (ss.address().family() != AddressFamily::UNIX_LOCAL) -#endif + + if (!_pConnectionFilter || _pConnectionFilter->accept(ss.peerAddress())) { - ss.setNoDelay(true); + // enable nodelay per default: OSX really needs that +#if defined(POCO_OS_FAMILY_UNIX) + if (ss.address().family() != AddressFamily::UNIX_LOCAL) +#endif + { + ss.setNoDelay(true); + } + _pDispatcher->enqueue(ss); } - _pDispatcher->enqueue(ss); } catch (Poco::Exception& exc) { @@ -161,6 +180,7 @@ int TCPServer::currentThreads() const return _pDispatcher->currentThreads(); } + int TCPServer::maxThreads() const { return _pDispatcher->maxThreads(); @@ -197,6 +217,14 @@ int TCPServer::refusedConnections() const } +void TCPServer::setConnectionFilter(const TCPConnectionFilter::Ptr& pConnectionFilter) +{ + poco_assert (_stopped); + + _pConnectionFilter = pConnectionFilter; +} + + std::string TCPServer::threadName(const ServerSocket& socket) { #if _WIN32_WCE == 0x0800 diff --git a/Net/testsuite/src/TCPServerTest.cpp b/Net/testsuite/src/TCPServerTest.cpp index 86375779e..11240244c 100644 --- a/Net/testsuite/src/TCPServerTest.cpp +++ b/Net/testsuite/src/TCPServerTest.cpp @@ -24,6 +24,7 @@ using Poco::Net::TCPServer; +using Poco::Net::TCPConnectionFilter; using Poco::Net::TCPServerConnection; using Poco::Net::TCPServerConnectionFactory; using Poco::Net::TCPServerConnectionFactoryImpl; @@ -62,6 +63,15 @@ namespace } } }; + + class RejectFilter: public TCPConnectionFilter + { + public: + bool accept(const SocketAddress&) + { + return false; + } + }; } @@ -233,7 +243,9 @@ void TCPServerTest::testMultiConnections() assert (srv.currentConnections() == 0); } -void TCPServerTest::testThreadCapacity(){ + +void TCPServerTest::testThreadCapacity() +{ ServerSocket svs(0); TCPServerParams* pParams = new TCPServerParams; pParams->setMaxThreads(64); @@ -243,6 +255,29 @@ void TCPServerTest::testThreadCapacity(){ } +void TCPServerTest::testFilter() +{ + TCPServer srv(new TCPServerConnectionFactoryImpl()); + srv.setConnectionFilter(new RejectFilter); + srv.start(); + assert (srv.currentConnections() == 0); + assert (srv.currentThreads() == 0); + assert (srv.queuedConnections() == 0); + assert (srv.totalConnections() == 0); + + SocketAddress sa("127.0.0.1", srv.socket().address().port()); + StreamSocket ss(sa); + + char buffer[256]; + int n = ss.receiveBytes(buffer, sizeof(buffer)); + + assert (n == 0); + assert (srv.currentConnections() == 0); + assert (srv.currentThreads() == 0); + assert (srv.queuedConnections() == 0); + assert (srv.totalConnections() == 0); +} + void TCPServerTest::setUp() { @@ -262,6 +297,7 @@ CppUnit::Test* TCPServerTest::suite() CppUnit_addTest(pSuite, TCPServerTest, testTwoConnections); CppUnit_addTest(pSuite, TCPServerTest, testMultiConnections); CppUnit_addTest(pSuite, TCPServerTest, testThreadCapacity); + CppUnit_addTest(pSuite, TCPServerTest, testFilter); return pSuite; } diff --git a/Net/testsuite/src/TCPServerTest.h b/Net/testsuite/src/TCPServerTest.h index 138a970ff..c39c87cf0 100644 --- a/Net/testsuite/src/TCPServerTest.h +++ b/Net/testsuite/src/TCPServerTest.h @@ -30,6 +30,7 @@ public: void testTwoConnections(); void testMultiConnections(); void testThreadCapacity(); + void testFilter(); void setUp(); void tearDown();