From bf09be3f33cea43b2798b63e5dfb0b90c5f61d6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Sat, 16 Nov 2024 16:50:38 +0100 Subject: [PATCH] fix(NetSSL): Non-blocking sockets support #4773 --- .../include/Poco/Net/SecureSocketImpl.h | 27 ++- .../include/Poco/Net/SecureStreamSocketImpl.h | 6 + NetSSL_OpenSSL/src/SecureSocketImpl.cpp | 25 +- NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp | 24 ++ NetSSL_OpenSSL/testsuite/Makefile | 2 +- .../testsuite/src/NetSSLTestSuite.cpp | 3 +- .../testsuite/src/SecureStreamSocketTest.cpp | 222 ++++++++++++++++++ .../testsuite/src/SecureStreamSocketTest.h | 40 ++++ .../src/SecureStreamSocketTestSuite.cpp | 22 ++ .../src/SecureStreamSocketTestSuite.h | 27 +++ .../testsuite/src/WebSocketTest.cpp | 82 ++++++- NetSSL_OpenSSL/testsuite/src/WebSocketTest.h | 1 + 12 files changed, 447 insertions(+), 34 deletions(-) create mode 100644 NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.cpp create mode 100644 NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.h create mode 100644 NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTestSuite.cpp create mode 100644 NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTestSuite.h diff --git a/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h b/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h index 38b7ea502..10e9c1aee 100644 --- a/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h @@ -161,15 +161,6 @@ public: /// underlying TCP connection. No orderly SSL shutdown /// is performed. - void setBlocking(bool flag); - /// Sets the socket in blocking mode if flag is true, - /// disables blocking mode if flag is false. - - bool getBlocking() const; - /// Returns the blocking mode of the socket. - /// This method will only work if the blocking modes of - /// the socket are changed via the setBlocking method! - int sendBytes(const void* buffer, int length, int flags = 0); /// Sends the contents of the given buffer through /// the socket. Any specified flags are ignored. @@ -239,6 +230,12 @@ public: /// Returns true iff a reused session was negotiated during /// the handshake. + SocketImpl* socket(); + /// Returns the underlying SocketImpl. + + const SocketImpl* socket() const; + /// Returns the underlying SocketImpl. + protected: void acceptSSL(); /// Performs a server-side SSL handshake and certificate verification. @@ -308,6 +305,18 @@ private: // // inlines // +inline SocketImpl* SecureSocketImpl::socket() +{ + return _pSocket.get(); +} + + +inline const SocketImpl* SecureSocketImpl::socket() const +{ + return _pSocket.get(); +} + + inline poco_socket_t SecureSocketImpl::sockfd() { return _pSocket->sockfd(); diff --git a/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h b/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h index 1327a67cd..942a36c45 100644 --- a/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h +++ b/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h @@ -196,6 +196,12 @@ public: /// Returns true iff a reused session was negotiated during /// the handshake. + // SocketImpl + virtual void setBlocking(bool flag) override; + virtual bool getBlocking() const override; + virtual void setRawOption(int level, int option, const void* value, poco_socklen_t length) override; + virtual void getRawOption(int level, int option, void* value, poco_socklen_t& length) override; + protected: void acceptSSL(); /// Performs a SSL server-side handshake. diff --git a/NetSSL_OpenSSL/src/SecureSocketImpl.cpp b/NetSSL_OpenSSL/src/SecureSocketImpl.cpp index a01a2c315..2d70fd57b 100644 --- a/NetSSL_OpenSSL/src/SecureSocketImpl.cpp +++ b/NetSSL_OpenSSL/src/SecureSocketImpl.cpp @@ -332,22 +332,6 @@ void SecureSocketImpl::close() } -void SecureSocketImpl::setBlocking(bool flag) -{ - poco_check_ptr (_pSocket); - - _pSocket->setBlocking(flag); -} - - -bool SecureSocketImpl::getBlocking() const -{ - poco_check_ptr (_pSocket); - - return _pSocket->getBlocking(); -} - - int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) { poco_assert (_pSocket->initialized()); @@ -409,7 +393,14 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) Poco::Timestamp tsStart; while (true) { - rc = ::SSL_read(_pSSL, buffer, length); + if (flags & MSG_PEEK) + { + rc = SSL_peek(_pSSL, buffer, length); + } + else + { + rc = SSL_read(_pSSL, buffer, length); + } if (!mustRetry(rc)) break; diff --git a/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp b/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp index 29d111abe..289279e37 100644 --- a/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp +++ b/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp @@ -225,4 +225,28 @@ int SecureStreamSocketImpl::completeHandshake() } +void SecureStreamSocketImpl::setBlocking(bool flag) +{ + _impl.socket()->setBlocking(flag); +} + + +bool SecureStreamSocketImpl::getBlocking() const +{ + return _impl.socket()->getBlocking(); +} + + +void SecureStreamSocketImpl::setRawOption(int level, int option, const void* value, poco_socklen_t length) +{ + _impl.socket()->setRawOption(level, option, value, length); +} + + +void SecureStreamSocketImpl::getRawOption(int level, int option, void* value, poco_socklen_t& length) +{ + _impl.socket()->getRawOption(level, option, value, length); +} + + } } // namespace Poco::Net diff --git a/NetSSL_OpenSSL/testsuite/Makefile b/NetSSL_OpenSSL/testsuite/Makefile index 6987638d7..826c62347 100644 --- a/NetSSL_OpenSSL/testsuite/Makefile +++ b/NetSSL_OpenSSL/testsuite/Makefile @@ -25,7 +25,7 @@ objects = NetSSLTestSuite Driver \ HTTPSClientSessionTest HTTPSClientTestSuite HTTPSServerTest HTTPSServerTestSuite \ HTTPSStreamFactoryTest HTTPSTestServer TCPServerTest TCPServerTestSuite \ WebSocketTest WebSocketTestSuite FTPSClientSessionTest FTPSClientTestSuite \ - DialogServer + DialogServer SecureStreamSocketTest SecureStreamSocketTestSuite target = testrunner target_version = 1 diff --git a/NetSSL_OpenSSL/testsuite/src/NetSSLTestSuite.cpp b/NetSSL_OpenSSL/testsuite/src/NetSSLTestSuite.cpp index 39681544b..5f216ea12 100644 --- a/NetSSL_OpenSSL/testsuite/src/NetSSLTestSuite.cpp +++ b/NetSSL_OpenSSL/testsuite/src/NetSSLTestSuite.cpp @@ -9,7 +9,7 @@ #include "NetSSLTestSuite.h" - +#include "SecureStreamSocketTestSuite.h" #include "HTTPSClientTestSuite.h" #include "TCPServerTestSuite.h" #include "HTTPSServerTestSuite.h" @@ -21,6 +21,7 @@ CppUnit::Test* NetSSLTestSuite::suite() { CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("OpenSSLTestSuite"); + pSuite->addTest(SecureStreamSocketTestSuite::suite()); pSuite->addTest(HTTPSClientTestSuite::suite()); pSuite->addTest(TCPServerTestSuite::suite()); pSuite->addTest(HTTPSServerTestSuite::suite()); diff --git a/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.cpp b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.cpp new file mode 100644 index 000000000..bf4a6a9b1 --- /dev/null +++ b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.cpp @@ -0,0 +1,222 @@ +// +// SecureStreamSocketTest.cpp +// +// Copyright (c) 2006, Applied Informatics Software Engineering GmbH. +// and Contributors. +// +// SPDX-License-Identifier: BSL-1.0 +// + + +#include "SecureStreamSocketTest.h" +#include "CppUnit/TestCaller.h" +#include "CppUnit/TestSuite.h" +#include "Poco/Net/TCPServer.h" +#include "Poco/Net/TCPServerConnection.h" +#include "Poco/Net/TCPServerConnectionFactory.h" +#include "Poco/Net/TCPServerParams.h" +#include "Poco/Net/SecureStreamSocket.h" +#include "Poco/Net/SecureServerSocket.h" +#include "Poco/Net/Context.h" +#include "Poco/Net/RejectCertificateHandler.h" +#include "Poco/Net/AcceptCertificateHandler.h" +#include "Poco/Net/Session.h" +#include "Poco/Net/SSLManager.h" +#include "Poco/Util/Application.h" +#include "Poco/Util/AbstractConfiguration.h" +#include "Poco/Thread.h" +#include + + +using Poco::Net::TCPServer; +using Poco::Net::TCPServerConnection; +using Poco::Net::TCPServerConnectionFactory; +using Poco::Net::TCPServerConnectionFactoryImpl; +using Poco::Net::TCPServerParams; +using Poco::Net::StreamSocket; +using Poco::Net::SecureStreamSocket; +using Poco::Net::SecureServerSocket; +using Poco::Net::SocketAddress; +using Poco::Net::Context; +using Poco::Net::Session; +using Poco::Net::SSLManager; +using Poco::Thread; +using Poco::Util::Application; + + +namespace +{ + class EchoConnection: public TCPServerConnection + { + public: + EchoConnection(const StreamSocket& s): TCPServerConnection(s) + { + } + + void run() + { + StreamSocket& ss = socket(); + try + { + char buffer[256]; + int n = ss.receiveBytes(buffer, sizeof(buffer)); + while (n > 0) + { + ss.sendBytes(buffer, n); + n = ss.receiveBytes(buffer, sizeof(buffer)); + } + } + catch (Poco::Exception& exc) + { + std::cerr << "EchoConnection: " << exc.displayText() << std::endl; + } + } + }; +} + + +SecureStreamSocketTest::SecureStreamSocketTest(const std::string& name): CppUnit::TestCase(name) +{ +} + + +SecureStreamSocketTest::~SecureStreamSocketTest() +{ +} + + +void SecureStreamSocketTest::testSendReceive() +{ + SecureServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + + SocketAddress sa("127.0.0.1", svs.address().port()); + SecureStreamSocket ss1(sa); + std::string data("hello, world"); + ss1.sendBytes(data.data(), static_cast(data.size())); + char buffer[8192]; + int n = ss1.receiveBytes(buffer, sizeof(buffer)); + assertTrue (n > 0); + assertTrue (std::string(buffer, n) == data); + + const std::vector sizes = {67, 467, 7883, 19937}; + for (const auto n: sizes) + { + data.assign(n, 'X'); + ss1.sendBytes(data.data(), static_cast(data.size())); + std::string received; + while (received.size() < n) + { + int rc = ss1.receiveBytes(buffer, sizeof(buffer)); + if (rc > 0) + { + received.append(buffer, rc); + } + else if (n == 0) + { + break; + } + } + assertTrue (received == data); + } + + ss1.close(); +} + + +void SecureStreamSocketTest::testPeek() +{ + SecureServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + + SocketAddress sa("127.0.0.1", svs.address().port()); + SecureStreamSocket ss(sa); + + int n = ss.sendBytes("hello, world!", 13); + assertTrue (n == 13); + char buffer[256]; + n = ss.receiveBytes(buffer, 5, MSG_PEEK); + assertTrue (n == 5); + assertTrue (std::string(buffer, n) == "hello"); + n = ss.receiveBytes(buffer, sizeof(buffer), MSG_PEEK); + assertTrue (n == 13); + assertTrue (std::string(buffer, n) == "hello, world!"); + n = ss.receiveBytes(buffer, 7); + assertTrue (n == 7); + assertTrue (std::string(buffer, n) == "hello, "); + n = ss.receiveBytes(buffer, 6); + assertTrue (n == 6); + assertTrue (std::string(buffer, n) == "world!"); + ss.close(); +} + + +void SecureStreamSocketTest::testNB() +{ + SecureServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + + SocketAddress sa("127.0.0.1", svs.address().port()); + SecureStreamSocket ss1(sa); + ss1.setBlocking(false); + ss1.setSendBufferSize(32000); + + char buffer[8192]; + const std::vector sizes = {67, 467, 7883, 19937}; + for (const auto n: sizes) + { + std::string data(n, 'X'); + int rc = ss1.sendBytes(data.data(), static_cast(data.size())); + assertTrue (rc == n); + + rc = -1; + while (rc < 0) + { + rc = ss1.receiveBytes(buffer, sizeof(buffer), MSG_PEEK); + } + assertTrue (rc > 0 && rc <= n); + assertTrue (data.compare(0, rc, buffer, rc) == 0); + + std::string received; + while (received.size() < n) + { + int rc = ss1.receiveBytes(buffer, sizeof(buffer)); + if (rc > 0) + { + received.append(buffer, rc); + } + else if (n == 0) + { + break; + } + } + assertTrue (received == data); + } + + ss1.close(); +} + + +void SecureStreamSocketTest::setUp() +{ +} + + +void SecureStreamSocketTest::tearDown() +{ +} + + +CppUnit::Test* SecureStreamSocketTest::suite() +{ + CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("SecureStreamSocketTest"); + + CppUnit_addTest(pSuite, SecureStreamSocketTest, testSendReceive); + CppUnit_addTest(pSuite, SecureStreamSocketTest, testPeek); + CppUnit_addTest(pSuite, SecureStreamSocketTest, testNB); + + return pSuite; +} diff --git a/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.h b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.h new file mode 100644 index 000000000..62ac9e1c7 --- /dev/null +++ b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.h @@ -0,0 +1,40 @@ +// +// SecureStreamSocketTest.h +// +// Definition of the SecureStreamSocketTest class. +// +// Copyright (c) 2006, Applied Informatics Software Engineering GmbH. +// and Contributors. +// +// SPDX-License-Identifier: BSL-1.0 +// + + +#ifndef SecureStreamSocketTest_INCLUDED +#define SecureStreamSocketTest_INCLUDED + + +#include "Poco/Net/Net.h" +#include "CppUnit/TestCase.h" + + +class SecureStreamSocketTest: public CppUnit::TestCase +{ +public: + SecureStreamSocketTest(const std::string& name); + ~SecureStreamSocketTest(); + + void testSendReceive(); + void testPeek(); + void testNB(); + + void setUp(); + void tearDown(); + + static CppUnit::Test* suite(); + +private: +}; + + +#endif // SecureStreamSocketTest diff --git a/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTestSuite.cpp b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTestSuite.cpp new file mode 100644 index 000000000..1aa23be2a --- /dev/null +++ b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTestSuite.cpp @@ -0,0 +1,22 @@ +// +// SecureStreamSocketTestSuite.cpp +// +// Copyright (c) 2006, Applied Informatics Software Engineering GmbH. +// and Contributors. +// +// SPDX-License-Identifier: BSL-1.0 +// + + +#include "SecureStreamSocketTestSuite.h" +#include "SecureStreamSocketTest.h" + + +CppUnit::Test* SecureStreamSocketTestSuite::suite() +{ + CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("SecureStreamSocketTestSuite"); + + pSuite->addTest(SecureStreamSocketTest::suite()); + + return pSuite; +} diff --git a/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTestSuite.h b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTestSuite.h new file mode 100644 index 000000000..0bbf4cc3c --- /dev/null +++ b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTestSuite.h @@ -0,0 +1,27 @@ +// +// SecureStreamSocketTestSuite.h +// +// Definition of the SecureStreamSocketTestSuite class. +// +// Copyright (c) 2006, Applied Informatics Software Engineering GmbH. +// and Contributors. +// +// SPDX-License-Identifier: BSL-1.0 +// + + +#ifndef SecureStreamSocketTestSuite_INCLUDED +#define SecureStreamSocketTestSuite_INCLUDED + + +#include "CppUnit/TestSuite.h" + + +class SecureStreamSocketTestSuite +{ +public: + static CppUnit::Test* suite(); +}; + + +#endif // SecureStreamSocketTestSuite diff --git a/NetSSL_OpenSSL/testsuite/src/WebSocketTest.cpp b/NetSSL_OpenSSL/testsuite/src/WebSocketTest.cpp index bf37e4256..1d7739672 100644 --- a/NetSSL_OpenSSL/testsuite/src/WebSocketTest.cpp +++ b/NetSSL_OpenSSL/testsuite/src/WebSocketTest.cpp @@ -49,16 +49,16 @@ namespace try { WebSocket ws(request, response); - std::unique_ptr pBuffer(new char[_bufSize]); + Poco::Buffer buffer(_bufSize); int flags; int n; do { - n = ws.receiveFrame(pBuffer.get(), static_cast(_bufSize), flags); - Poco::Thread::current()->sleep(handleDelay.totalMilliseconds()); + n = ws.receiveFrame(buffer.begin(), static_cast(_bufSize), flags); if (n == 0) break; - ws.sendFrame(pBuffer.get(), n, flags); + Poco::Thread::current()->sleep(handleDelay.totalMilliseconds()); + ws.sendFrame(buffer.begin(), n, flags); } while ((flags & WebSocket::FRAME_OP_BITMASK) != WebSocket::FRAME_OP_CLOSE); } @@ -156,6 +156,8 @@ void WebSocketTest::testWebSocketTimeout() ws.shutdown(); ws.receiveFrame(buffer, sizeof(buffer), flags); server.stop(); + + WebSocketRequestHandler::handleDelay = 0; } @@ -220,7 +222,7 @@ void WebSocketTest::testWebSocket() assertTrue (n == 2); assertTrue ((flags & WebSocket::FRAME_OP_BITMASK) == WebSocket::FRAME_OP_CLOSE); - server.stop(); + server.stopAll(true); } @@ -256,7 +258,74 @@ void WebSocketTest::testWebSocketLarge() assertTrue (n == payload.size()); assertTrue (payload.compare(0, payload.size(), buffer, 0, n) == 0); - server.stop(); + server.stopAll(true); +} + + +void WebSocketTest::testWebSocketNB() +{ + Poco::Net::SecureServerSocket ss(0); + Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory(256*1024), ss, new Poco::Net::HTTPServerParams); + server.start(); + + Poco::Thread::sleep(200); + + HTTPSClientSession cs("127.0.0.1", ss.address().port()); + HTTPRequest request(HTTPRequest::HTTP_GET, "/ws", HTTPRequest::HTTP_1_1); + HTTPResponse response; + WebSocket ws(cs, request, response); + ws.setBlocking(false); + + int flags; + char buffer[256*1024] = {}; + int n = ws.receiveFrame(buffer, sizeof(buffer), flags); + assertTrue (n < 0); + + std::string payload("x"); + n = ws.sendFrame(payload.data(), (int) payload.size()); + assertTrue (n > 0); + if (ws.poll(1000000, Poco::Net::Socket::SELECT_READ)) + { + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + while (n < 0) + { + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + } + } + assertTrue (n == payload.size()); + assertTrue (payload.compare(0, payload.size(), buffer, n) == 0); + assertTrue (flags == WebSocket::FRAME_TEXT); + + ws.setSendBufferSize(256*1024); + ws.setReceiveBufferSize(256*1024); + + payload.assign(256000, 'z'); + n = ws.sendFrame(payload.data(), (int) payload.size()); + assertTrue (n > 0); + if (ws.poll(1000000, Poco::Net::Socket::SELECT_READ)) + { + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + while (n < 0) + { + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + } + } + assertTrue (n == payload.size()); + assertTrue (payload.compare(0, payload.size(), buffer, n) == 0); + assertTrue (flags == WebSocket::FRAME_TEXT); + + n = ws.shutdown(); + assertTrue (n > 0); + + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + while (n < 0) + { + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + } + assertTrue (n == 2); + assertTrue ((flags & WebSocket::FRAME_OP_BITMASK) == WebSocket::FRAME_OP_CLOSE); + + server.stopAll(true); } @@ -277,6 +346,7 @@ CppUnit::Test* WebSocketTest::suite() CppUnit_addTest(pSuite, WebSocketTest, testWebSocket); CppUnit_addTest(pSuite, WebSocketTest, testWebSocketTimeout); CppUnit_addTest(pSuite, WebSocketTest, testWebSocketLarge); + CppUnit_addTest(pSuite, WebSocketTest, testWebSocketNB); return pSuite; } diff --git a/NetSSL_OpenSSL/testsuite/src/WebSocketTest.h b/NetSSL_OpenSSL/testsuite/src/WebSocketTest.h index f8ea9131c..7947484fe 100644 --- a/NetSSL_OpenSSL/testsuite/src/WebSocketTest.h +++ b/NetSSL_OpenSSL/testsuite/src/WebSocketTest.h @@ -27,6 +27,7 @@ public: void testWebSocketTimeout(); void testWebSocket(); void testWebSocketLarge(); + void testWebSocketNB(); void setUp() override; void tearDown() override;