mirror of
https://github.com/pocoproject/poco.git
synced 2025-04-01 09:24:55 +02:00
GH #1485: add TCPConnectionFilter
This commit is contained in:
parent
1164c4b03b
commit
fe59f959ef
@ -24,6 +24,8 @@
|
||||
#include "Poco/Net/ServerSocket.h"
|
||||
#include "Poco/Net/TCPServerConnectionFactory.h"
|
||||
#include "Poco/Net/TCPServerParams.h"
|
||||
#include "Poco/RefCountedObject.h"
|
||||
#include "Poco/AutoPtr.h"
|
||||
#include "Poco/Runnable.h"
|
||||
#include "Poco/Thread.h"
|
||||
#include "Poco/ThreadPool.h"
|
||||
@ -36,6 +38,26 @@ namespace Net {
|
||||
class TCPServerDispatcher;
|
||||
|
||||
|
||||
class Net_API TCPConnectionFilter: public Poco::RefCountedObject
|
||||
/// A TCPConnectionFilter can be used to reject incoming connections
|
||||
/// before a thread is started to handle them.
|
||||
///
|
||||
/// An example use case is white-list or black-list IP address filtering.
|
||||
///
|
||||
/// Subclasses must override the accept() method.
|
||||
{
|
||||
public:
|
||||
typedef Poco::AutoPtr<TCPConnectionFilter> Ptr;
|
||||
|
||||
virtual bool accept(const Poco::Net::SocketAddress& clientAddress) = 0;
|
||||
/// Returns true if a connection from the given clientAddress should
|
||||
/// be accepted, otherwise false.
|
||||
|
||||
protected:
|
||||
virtual ~TCPConnectionFilter();
|
||||
};
|
||||
|
||||
|
||||
class Net_API TCPServer: public Poco::Runnable
|
||||
/// This class implements a multithreaded TCP server.
|
||||
///
|
||||
@ -165,6 +187,19 @@ public:
|
||||
|
||||
Poco::UInt16 port() const;
|
||||
/// Returns the port the server socket listens on.
|
||||
|
||||
void setConnectionFilter(const TCPConnectionFilter::Ptr& pFilter);
|
||||
/// Sets a TCPConnectionFilter. Can also be used to remove
|
||||
/// a filter by passing a null pointer.
|
||||
///
|
||||
/// To avoid a potential race condition, the filter must
|
||||
/// be set before the TCPServer is started. Trying to set
|
||||
/// the filter after start() has been called will trigger
|
||||
/// an assertion.
|
||||
|
||||
TCPConnectionFilter::Ptr getConnectionFilter() const;
|
||||
/// Returns the TCPConnectionFilter set with setConnectionFilter(),
|
||||
/// or null pointer if no filter has been set.
|
||||
|
||||
protected:
|
||||
void run();
|
||||
@ -181,13 +216,17 @@ private:
|
||||
TCPServer(const TCPServer&);
|
||||
TCPServer& operator = (const TCPServer&);
|
||||
|
||||
ServerSocket _socket;
|
||||
TCPServerDispatcher* _pDispatcher;
|
||||
Poco::Thread _thread;
|
||||
bool _stopped;
|
||||
ServerSocket _socket;
|
||||
TCPServerDispatcher* _pDispatcher;
|
||||
TCPConnectionFilter::Ptr _pConnectionFilter;
|
||||
Poco::Thread _thread;
|
||||
bool _stopped;
|
||||
};
|
||||
|
||||
|
||||
//
|
||||
// inlines
|
||||
//
|
||||
inline const ServerSocket& TCPServer::socket() const
|
||||
{
|
||||
return _socket;
|
||||
@ -200,6 +239,12 @@ inline Poco::UInt16 TCPServer::port() const
|
||||
}
|
||||
|
||||
|
||||
inline TCPConnectionFilter::Ptr TCPServer::getConnectionFilter() const
|
||||
{
|
||||
return _pConnectionFilter;
|
||||
}
|
||||
|
||||
|
||||
} } // namespace Poco::Net
|
||||
|
||||
|
||||
|
@ -30,6 +30,21 @@ namespace Poco {
|
||||
namespace Net {
|
||||
|
||||
|
||||
//
|
||||
// TCPConnectionFilter
|
||||
//
|
||||
|
||||
|
||||
TCPConnectionFilter::~TCPConnectionFilter()
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// TCPServer
|
||||
//
|
||||
|
||||
|
||||
TCPServer::TCPServer(TCPServerConnectionFactory::Ptr pFactory, Poco::UInt16 portNumber, TCPServerParams::Ptr pParams):
|
||||
_socket(ServerSocket(portNumber)),
|
||||
_thread(threadName(_socket)),
|
||||
@ -122,14 +137,18 @@ void TCPServer::run()
|
||||
try
|
||||
{
|
||||
StreamSocket ss = _socket.acceptConnection();
|
||||
// enable nodelay per default: OSX really needs that
|
||||
#if defined(POCO_OS_FAMILY_UNIX)
|
||||
if (ss.address().family() != AddressFamily::UNIX_LOCAL)
|
||||
#endif
|
||||
|
||||
if (!_pConnectionFilter || _pConnectionFilter->accept(ss.peerAddress()))
|
||||
{
|
||||
ss.setNoDelay(true);
|
||||
// enable nodelay per default: OSX really needs that
|
||||
#if defined(POCO_OS_FAMILY_UNIX)
|
||||
if (ss.address().family() != AddressFamily::UNIX_LOCAL)
|
||||
#endif
|
||||
{
|
||||
ss.setNoDelay(true);
|
||||
}
|
||||
_pDispatcher->enqueue(ss);
|
||||
}
|
||||
_pDispatcher->enqueue(ss);
|
||||
}
|
||||
catch (Poco::Exception& exc)
|
||||
{
|
||||
@ -161,6 +180,7 @@ int TCPServer::currentThreads() const
|
||||
return _pDispatcher->currentThreads();
|
||||
}
|
||||
|
||||
|
||||
int TCPServer::maxThreads() const
|
||||
{
|
||||
return _pDispatcher->maxThreads();
|
||||
@ -197,6 +217,14 @@ int TCPServer::refusedConnections() const
|
||||
}
|
||||
|
||||
|
||||
void TCPServer::setConnectionFilter(const TCPConnectionFilter::Ptr& pConnectionFilter)
|
||||
{
|
||||
poco_assert (_stopped);
|
||||
|
||||
_pConnectionFilter = pConnectionFilter;
|
||||
}
|
||||
|
||||
|
||||
std::string TCPServer::threadName(const ServerSocket& socket)
|
||||
{
|
||||
#if _WIN32_WCE == 0x0800
|
||||
|
@ -24,6 +24,7 @@
|
||||
|
||||
|
||||
using Poco::Net::TCPServer;
|
||||
using Poco::Net::TCPConnectionFilter;
|
||||
using Poco::Net::TCPServerConnection;
|
||||
using Poco::Net::TCPServerConnectionFactory;
|
||||
using Poco::Net::TCPServerConnectionFactoryImpl;
|
||||
@ -62,6 +63,15 @@ namespace
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class RejectFilter: public TCPConnectionFilter
|
||||
{
|
||||
public:
|
||||
bool accept(const SocketAddress&)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -233,7 +243,9 @@ void TCPServerTest::testMultiConnections()
|
||||
assert (srv.currentConnections() == 0);
|
||||
}
|
||||
|
||||
void TCPServerTest::testThreadCapacity(){
|
||||
|
||||
void TCPServerTest::testThreadCapacity()
|
||||
{
|
||||
ServerSocket svs(0);
|
||||
TCPServerParams* pParams = new TCPServerParams;
|
||||
pParams->setMaxThreads(64);
|
||||
@ -243,6 +255,29 @@ void TCPServerTest::testThreadCapacity(){
|
||||
}
|
||||
|
||||
|
||||
void TCPServerTest::testFilter()
|
||||
{
|
||||
TCPServer srv(new TCPServerConnectionFactoryImpl<EchoConnection>());
|
||||
srv.setConnectionFilter(new RejectFilter);
|
||||
srv.start();
|
||||
assert (srv.currentConnections() == 0);
|
||||
assert (srv.currentThreads() == 0);
|
||||
assert (srv.queuedConnections() == 0);
|
||||
assert (srv.totalConnections() == 0);
|
||||
|
||||
SocketAddress sa("127.0.0.1", srv.socket().address().port());
|
||||
StreamSocket ss(sa);
|
||||
|
||||
char buffer[256];
|
||||
int n = ss.receiveBytes(buffer, sizeof(buffer));
|
||||
|
||||
assert (n == 0);
|
||||
assert (srv.currentConnections() == 0);
|
||||
assert (srv.currentThreads() == 0);
|
||||
assert (srv.queuedConnections() == 0);
|
||||
assert (srv.totalConnections() == 0);
|
||||
}
|
||||
|
||||
|
||||
void TCPServerTest::setUp()
|
||||
{
|
||||
@ -262,6 +297,7 @@ CppUnit::Test* TCPServerTest::suite()
|
||||
CppUnit_addTest(pSuite, TCPServerTest, testTwoConnections);
|
||||
CppUnit_addTest(pSuite, TCPServerTest, testMultiConnections);
|
||||
CppUnit_addTest(pSuite, TCPServerTest, testThreadCapacity);
|
||||
CppUnit_addTest(pSuite, TCPServerTest, testFilter);
|
||||
|
||||
return pSuite;
|
||||
}
|
||||
|
@ -30,6 +30,7 @@ public:
|
||||
void testTwoConnections();
|
||||
void testMultiConnections();
|
||||
void testThreadCapacity();
|
||||
void testFilter();
|
||||
|
||||
void setUp();
|
||||
void tearDown();
|
||||
|
Loading…
x
Reference in New Issue
Block a user