diff --git a/Net/include/Poco/Net/SocketImpl.h b/Net/include/Poco/Net/SocketImpl.h index f83d857c4..cd8ab31c4 100644 --- a/Net/include/Poco/Net/SocketImpl.h +++ b/Net/include/Poco/Net/SocketImpl.h @@ -295,11 +295,11 @@ public: /// If count is != 0, sends the given number of bytes, otherwise /// sends all bytes, starting from the given offset. /// - /// On Linux, macOS and FreeBSD systems, the implementation + /// On Linux, macOS and FreeBSD systems, the implementation /// uses sendfile() or sendfile64(). /// On Windows, the implementation uses TransmitFile(). /// - /// If neither sendfile() nor TransmitFile() is available, + /// If neither sendfile() nor TransmitFile() is available, /// or the socket is a secure one (secure() returne true), /// falls back to reading the file block by block and calling sendBytes(). /// @@ -307,7 +307,7 @@ public: /// as count, unless count is 0. /// /// Throws NetException (or a subclass) in case of any errors. - /// Also throws a NetException if the socket has been set to + /// Also throws a NetException if the socket has been set to /// non-blocking. virtual int available(); @@ -636,7 +636,7 @@ inline bool SocketImpl::getBlocking() const #if defined(POCO_OS_FAMILY_WINDOWS) - #pragma comment(lib, "mswsock.lib") -#endif + #pragma comment(lib, "mswsock.lib") +#endif #endif // Net_SocketImpl_INCLUDED diff --git a/Net/include/Poco/Net/TCPServerConnectionFactory.h b/Net/include/Poco/Net/TCPServerConnectionFactory.h index 380cb2dbd..cb941d0f9 100644 --- a/Net/include/Poco/Net/TCPServerConnectionFactory.h +++ b/Net/include/Poco/Net/TCPServerConnectionFactory.h @@ -35,7 +35,13 @@ class Net_API TCPServerConnectionFactory /// it accepts. /// /// Subclasses must override the createConnection() - /// method. + /// method, which can refuse connections by returning nullptr. + /// Some examples when an implementation may refuse new connections: + /// + /// - number of connections exceeded a limit + /// - a connection from unwanted client attempted + /// - too many connection attempts in a short timespan + /// - etc. /// /// The TCPServerConnectionFactoryImpl template class /// can be used to automatically instantiate a @@ -45,23 +51,55 @@ class Net_API TCPServerConnectionFactory public: using Ptr = Poco::SharedPtr; + TCPServerConnectionFactory(const TCPServerConnectionFactory&) = delete; + TCPServerConnectionFactory& operator = (const TCPServerConnectionFactory&) = delete; + TCPServerConnectionFactory(TCPServerConnectionFactory&&) = delete; + TCPServerConnectionFactory& operator = (TCPServerConnectionFactory&&) = delete; + virtual ~TCPServerConnectionFactory(); /// Destroys the TCPServerConnectionFactory. virtual TCPServerConnection* createConnection(const StreamSocket& socket) = 0; /// Creates an instance of a subclass of TCPServerConnection, /// using the given StreamSocket. + /// This function is allowed to return nullptr, in which case an accepted + /// socket will be destroyed by the TCPServerDispatcher. + + void stop(); + /// Stops the factory. + /// Normally, this function is called by TCPServerDispatcher + /// to indicate that the server is shutting down; the expected + /// implementation behavior after this call is to return nullptr + /// on all subsequent connection creation attempts. + + bool isStopped() const; + /// Returns true if the factory was stopped, false otherwise. protected: TCPServerConnectionFactory(); /// Creates the TCPServerConnectionFactory. private: - TCPServerConnectionFactory(const TCPServerConnectionFactory&); - TCPServerConnectionFactory& operator = (const TCPServerConnectionFactory&); + std::atomic _stopped; }; +// +// inlines +// + +inline void TCPServerConnectionFactory::stop() +{ + _stopped = true; +} + + +inline bool TCPServerConnectionFactory::isStopped() const +{ + return _stopped; +} + + template class TCPServerConnectionFactoryImpl: public TCPServerConnectionFactory /// This template provides a basic implementation of @@ -78,7 +116,10 @@ public: TCPServerConnection* createConnection(const StreamSocket& socket) { - return new S(socket); + if (!isStopped()) + return new S(socket); + + return nullptr; } }; diff --git a/Net/src/TCPServer.cpp b/Net/src/TCPServer.cpp index ebf48a7e6..9aea0f92b 100644 --- a/Net/src/TCPServer.cpp +++ b/Net/src/TCPServer.cpp @@ -132,6 +132,8 @@ void TCPServer::run() { if (_socket.poll(timeout, Socket::SELECT_READ)) { + if (_stopped) return; + try { StreamSocket ss = _socket.acceptConnection(); @@ -150,24 +152,30 @@ void TCPServer::run() } catch (Poco::Exception& exc) { - ErrorHandler::handle(exc); + if (!_stopped) + ErrorHandler::handle(exc); } catch (std::exception& exc) { - ErrorHandler::handle(exc); + if (!_stopped) + ErrorHandler::handle(exc); } catch (...) { - ErrorHandler::handle(); + if (!_stopped) + ErrorHandler::handle(); } } } catch (Poco::Exception& exc) { - ErrorHandler::handle(exc); - // possibly a resource issue since poll() failed; - // give some time to recover before trying again - Poco::Thread::sleep(50); + if (!_stopped) + { + ErrorHandler::handle(exc); + // possibly a resource issue since poll() failed; + // give some time to recover before trying again + Poco::Thread::sleep(50); + } } } } diff --git a/Net/src/TCPServerConnectionFactory.cpp b/Net/src/TCPServerConnectionFactory.cpp index a861e433d..3227e9bb0 100644 --- a/Net/src/TCPServerConnectionFactory.cpp +++ b/Net/src/TCPServerConnectionFactory.cpp @@ -19,7 +19,8 @@ namespace Poco { namespace Net { -TCPServerConnectionFactory::TCPServerConnectionFactory() +TCPServerConnectionFactory::TCPServerConnectionFactory(): + _stopped(false) { } diff --git a/Net/src/TCPServerDispatcher.cpp b/Net/src/TCPServerDispatcher.cpp index d798ba8f5..d77261377 100644 --- a/Net/src/TCPServerDispatcher.cpp +++ b/Net/src/TCPServerDispatcher.cpp @@ -112,10 +112,12 @@ void TCPServerDispatcher::run() if (pCNf) { std::unique_ptr pConnection(_pConnectionFactory->createConnection(pCNf->socket())); - poco_check_ptr(pConnection.get()); - beginConnection(); - pConnection->start(); - endConnection(); + if (pConnection) + { + beginConnection(); + pConnection->start(); + endConnection(); + } } } } @@ -173,6 +175,7 @@ void TCPServerDispatcher::enqueue(const StreamSocket& socket) void TCPServerDispatcher::stop() { FastMutex::ScopedLock lock(_mutex); + _pConnectionFactory->stop(); _stopped = true; _queue.clear(); for (int i = 0; i < _threadPool.allocated(); i++)