GH #1485: add TCPConnectionFilter

This commit is contained in:
Guenter Obiltschnig 2016-11-11 14:13:04 +01:00
parent 1164c4b03b
commit fe59f959ef
4 changed files with 121 additions and 11 deletions

View File

@ -24,6 +24,8 @@
#include "Poco/Net/ServerSocket.h"
#include "Poco/Net/TCPServerConnectionFactory.h"
#include "Poco/Net/TCPServerParams.h"
#include "Poco/RefCountedObject.h"
#include "Poco/AutoPtr.h"
#include "Poco/Runnable.h"
#include "Poco/Thread.h"
#include "Poco/ThreadPool.h"
@ -36,6 +38,26 @@ namespace Net {
class TCPServerDispatcher;
class Net_API TCPConnectionFilter: public Poco::RefCountedObject
/// A TCPConnectionFilter can be used to reject incoming connections
/// before a thread is started to handle them.
///
/// An example use case is white-list or black-list IP address filtering.
///
/// Subclasses must override the accept() method.
{
public:
typedef Poco::AutoPtr<TCPConnectionFilter> Ptr;
virtual bool accept(const Poco::Net::SocketAddress& clientAddress) = 0;
/// Returns true if a connection from the given clientAddress should
/// be accepted, otherwise false.
protected:
virtual ~TCPConnectionFilter();
};
class Net_API TCPServer: public Poco::Runnable
/// This class implements a multithreaded TCP server.
///
@ -166,6 +188,19 @@ public:
Poco::UInt16 port() const;
/// Returns the port the server socket listens on.
void setConnectionFilter(const TCPConnectionFilter::Ptr& pFilter);
/// Sets a TCPConnectionFilter. Can also be used to remove
/// a filter by passing a null pointer.
///
/// To avoid a potential race condition, the filter must
/// be set before the TCPServer is started. Trying to set
/// the filter after start() has been called will trigger
/// an assertion.
TCPConnectionFilter::Ptr getConnectionFilter() const;
/// Returns the TCPConnectionFilter set with setConnectionFilter(),
/// or null pointer if no filter has been set.
protected:
void run();
/// Runs the server. The server will run until
@ -183,11 +218,15 @@ private:
ServerSocket _socket;
TCPServerDispatcher* _pDispatcher;
TCPConnectionFilter::Ptr _pConnectionFilter;
Poco::Thread _thread;
bool _stopped;
};
//
// inlines
//
inline const ServerSocket& TCPServer::socket() const
{
return _socket;
@ -200,6 +239,12 @@ inline Poco::UInt16 TCPServer::port() const
}
inline TCPConnectionFilter::Ptr TCPServer::getConnectionFilter() const
{
return _pConnectionFilter;
}
} } // namespace Poco::Net

View File

@ -30,6 +30,21 @@ namespace Poco {
namespace Net {
//
// TCPConnectionFilter
//
TCPConnectionFilter::~TCPConnectionFilter()
{
}
//
// TCPServer
//
TCPServer::TCPServer(TCPServerConnectionFactory::Ptr pFactory, Poco::UInt16 portNumber, TCPServerParams::Ptr pParams):
_socket(ServerSocket(portNumber)),
_thread(threadName(_socket)),
@ -122,6 +137,9 @@ void TCPServer::run()
try
{
StreamSocket ss = _socket.acceptConnection();
if (!_pConnectionFilter || _pConnectionFilter->accept(ss.peerAddress()))
{
// enable nodelay per default: OSX really needs that
#if defined(POCO_OS_FAMILY_UNIX)
if (ss.address().family() != AddressFamily::UNIX_LOCAL)
@ -131,6 +149,7 @@ void TCPServer::run()
}
_pDispatcher->enqueue(ss);
}
}
catch (Poco::Exception& exc)
{
ErrorHandler::handle(exc);
@ -161,6 +180,7 @@ int TCPServer::currentThreads() const
return _pDispatcher->currentThreads();
}
int TCPServer::maxThreads() const
{
return _pDispatcher->maxThreads();
@ -197,6 +217,14 @@ int TCPServer::refusedConnections() const
}
void TCPServer::setConnectionFilter(const TCPConnectionFilter::Ptr& pConnectionFilter)
{
poco_assert (_stopped);
_pConnectionFilter = pConnectionFilter;
}
std::string TCPServer::threadName(const ServerSocket& socket)
{
#if _WIN32_WCE == 0x0800

View File

@ -24,6 +24,7 @@
using Poco::Net::TCPServer;
using Poco::Net::TCPConnectionFilter;
using Poco::Net::TCPServerConnection;
using Poco::Net::TCPServerConnectionFactory;
using Poco::Net::TCPServerConnectionFactoryImpl;
@ -62,6 +63,15 @@ namespace
}
}
};
class RejectFilter: public TCPConnectionFilter
{
public:
bool accept(const SocketAddress&)
{
return false;
}
};
}
@ -233,7 +243,9 @@ void TCPServerTest::testMultiConnections()
assert (srv.currentConnections() == 0);
}
void TCPServerTest::testThreadCapacity(){
void TCPServerTest::testThreadCapacity()
{
ServerSocket svs(0);
TCPServerParams* pParams = new TCPServerParams;
pParams->setMaxThreads(64);
@ -243,6 +255,29 @@ void TCPServerTest::testThreadCapacity(){
}
void TCPServerTest::testFilter()
{
TCPServer srv(new TCPServerConnectionFactoryImpl<EchoConnection>());
srv.setConnectionFilter(new RejectFilter);
srv.start();
assert (srv.currentConnections() == 0);
assert (srv.currentThreads() == 0);
assert (srv.queuedConnections() == 0);
assert (srv.totalConnections() == 0);
SocketAddress sa("127.0.0.1", srv.socket().address().port());
StreamSocket ss(sa);
char buffer[256];
int n = ss.receiveBytes(buffer, sizeof(buffer));
assert (n == 0);
assert (srv.currentConnections() == 0);
assert (srv.currentThreads() == 0);
assert (srv.queuedConnections() == 0);
assert (srv.totalConnections() == 0);
}
void TCPServerTest::setUp()
{
@ -262,6 +297,7 @@ CppUnit::Test* TCPServerTest::suite()
CppUnit_addTest(pSuite, TCPServerTest, testTwoConnections);
CppUnit_addTest(pSuite, TCPServerTest, testMultiConnections);
CppUnit_addTest(pSuite, TCPServerTest, testThreadCapacity);
CppUnit_addTest(pSuite, TCPServerTest, testFilter);
return pSuite;
}

View File

@ -30,6 +30,7 @@ public:
void testTwoConnections();
void testMultiConnections();
void testThreadCapacity();
void testFilter();
void setUp();
void tearDown();