fix(NetSSL): Non-blocking sockets support #4773

This commit is contained in:
Günter Obiltschnig 2024-11-16 16:50:38 +01:00
parent cedd086362
commit bf09be3f33
12 changed files with 447 additions and 34 deletions

View File

@ -161,15 +161,6 @@ public:
/// underlying TCP connection. No orderly SSL shutdown
/// is performed.
void setBlocking(bool flag);
/// Sets the socket in blocking mode if flag is true,
/// disables blocking mode if flag is false.
bool getBlocking() const;
/// Returns the blocking mode of the socket.
/// This method will only work if the blocking modes of
/// the socket are changed via the setBlocking method!
int sendBytes(const void* buffer, int length, int flags = 0);
/// Sends the contents of the given buffer through
/// the socket. Any specified flags are ignored.
@ -239,6 +230,12 @@ public:
/// Returns true iff a reused session was negotiated during
/// the handshake.
SocketImpl* socket();
/// Returns the underlying SocketImpl.
const SocketImpl* socket() const;
/// Returns the underlying SocketImpl.
protected:
void acceptSSL();
/// Performs a server-side SSL handshake and certificate verification.
@ -308,6 +305,18 @@ private:
//
// inlines
//
inline SocketImpl* SecureSocketImpl::socket()
{
return _pSocket.get();
}
inline const SocketImpl* SecureSocketImpl::socket() const
{
return _pSocket.get();
}
inline poco_socket_t SecureSocketImpl::sockfd()
{
return _pSocket->sockfd();

View File

@ -196,6 +196,12 @@ public:
/// Returns true iff a reused session was negotiated during
/// the handshake.
// SocketImpl
virtual void setBlocking(bool flag) override;
virtual bool getBlocking() const override;
virtual void setRawOption(int level, int option, const void* value, poco_socklen_t length) override;
virtual void getRawOption(int level, int option, void* value, poco_socklen_t& length) override;
protected:
void acceptSSL();
/// Performs a SSL server-side handshake.

View File

@ -332,22 +332,6 @@ void SecureSocketImpl::close()
}
void SecureSocketImpl::setBlocking(bool flag)
{
poco_check_ptr (_pSocket);
_pSocket->setBlocking(flag);
}
bool SecureSocketImpl::getBlocking() const
{
poco_check_ptr (_pSocket);
return _pSocket->getBlocking();
}
int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
{
poco_assert (_pSocket->initialized());
@ -409,7 +393,14 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
Poco::Timestamp tsStart;
while (true)
{
rc = ::SSL_read(_pSSL, buffer, length);
if (flags & MSG_PEEK)
{
rc = SSL_peek(_pSSL, buffer, length);
}
else
{
rc = SSL_read(_pSSL, buffer, length);
}
if (!mustRetry(rc))
break;

View File

@ -225,4 +225,28 @@ int SecureStreamSocketImpl::completeHandshake()
}
void SecureStreamSocketImpl::setBlocking(bool flag)
{
_impl.socket()->setBlocking(flag);
}
bool SecureStreamSocketImpl::getBlocking() const
{
return _impl.socket()->getBlocking();
}
void SecureStreamSocketImpl::setRawOption(int level, int option, const void* value, poco_socklen_t length)
{
_impl.socket()->setRawOption(level, option, value, length);
}
void SecureStreamSocketImpl::getRawOption(int level, int option, void* value, poco_socklen_t& length)
{
_impl.socket()->getRawOption(level, option, value, length);
}
} } // namespace Poco::Net

View File

@ -25,7 +25,7 @@ objects = NetSSLTestSuite Driver \
HTTPSClientSessionTest HTTPSClientTestSuite HTTPSServerTest HTTPSServerTestSuite \
HTTPSStreamFactoryTest HTTPSTestServer TCPServerTest TCPServerTestSuite \
WebSocketTest WebSocketTestSuite FTPSClientSessionTest FTPSClientTestSuite \
DialogServer
DialogServer SecureStreamSocketTest SecureStreamSocketTestSuite
target = testrunner
target_version = 1

View File

@ -9,7 +9,7 @@
#include "NetSSLTestSuite.h"
#include "SecureStreamSocketTestSuite.h"
#include "HTTPSClientTestSuite.h"
#include "TCPServerTestSuite.h"
#include "HTTPSServerTestSuite.h"
@ -21,6 +21,7 @@ CppUnit::Test* NetSSLTestSuite::suite()
{
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("OpenSSLTestSuite");
pSuite->addTest(SecureStreamSocketTestSuite::suite());
pSuite->addTest(HTTPSClientTestSuite::suite());
pSuite->addTest(TCPServerTestSuite::suite());
pSuite->addTest(HTTPSServerTestSuite::suite());

View File

@ -0,0 +1,222 @@
//
// SecureStreamSocketTest.cpp
//
// Copyright (c) 2006, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#include "SecureStreamSocketTest.h"
#include "CppUnit/TestCaller.h"
#include "CppUnit/TestSuite.h"
#include "Poco/Net/TCPServer.h"
#include "Poco/Net/TCPServerConnection.h"
#include "Poco/Net/TCPServerConnectionFactory.h"
#include "Poco/Net/TCPServerParams.h"
#include "Poco/Net/SecureStreamSocket.h"
#include "Poco/Net/SecureServerSocket.h"
#include "Poco/Net/Context.h"
#include "Poco/Net/RejectCertificateHandler.h"
#include "Poco/Net/AcceptCertificateHandler.h"
#include "Poco/Net/Session.h"
#include "Poco/Net/SSLManager.h"
#include "Poco/Util/Application.h"
#include "Poco/Util/AbstractConfiguration.h"
#include "Poco/Thread.h"
#include <iostream>
using Poco::Net::TCPServer;
using Poco::Net::TCPServerConnection;
using Poco::Net::TCPServerConnectionFactory;
using Poco::Net::TCPServerConnectionFactoryImpl;
using Poco::Net::TCPServerParams;
using Poco::Net::StreamSocket;
using Poco::Net::SecureStreamSocket;
using Poco::Net::SecureServerSocket;
using Poco::Net::SocketAddress;
using Poco::Net::Context;
using Poco::Net::Session;
using Poco::Net::SSLManager;
using Poco::Thread;
using Poco::Util::Application;
namespace
{
class EchoConnection: public TCPServerConnection
{
public:
EchoConnection(const StreamSocket& s): TCPServerConnection(s)
{
}
void run()
{
StreamSocket& ss = socket();
try
{
char buffer[256];
int n = ss.receiveBytes(buffer, sizeof(buffer));
while (n > 0)
{
ss.sendBytes(buffer, n);
n = ss.receiveBytes(buffer, sizeof(buffer));
}
}
catch (Poco::Exception& exc)
{
std::cerr << "EchoConnection: " << exc.displayText() << std::endl;
}
}
};
}
SecureStreamSocketTest::SecureStreamSocketTest(const std::string& name): CppUnit::TestCase(name)
{
}
SecureStreamSocketTest::~SecureStreamSocketTest()
{
}
void SecureStreamSocketTest::testSendReceive()
{
SecureServerSocket svs(0);
TCPServer srv(new TCPServerConnectionFactoryImpl<EchoConnection>(), svs);
srv.start();
SocketAddress sa("127.0.0.1", svs.address().port());
SecureStreamSocket ss1(sa);
std::string data("hello, world");
ss1.sendBytes(data.data(), static_cast<int>(data.size()));
char buffer[8192];
int n = ss1.receiveBytes(buffer, sizeof(buffer));
assertTrue (n > 0);
assertTrue (std::string(buffer, n) == data);
const std::vector<std::size_t> sizes = {67, 467, 7883, 19937};
for (const auto n: sizes)
{
data.assign(n, 'X');
ss1.sendBytes(data.data(), static_cast<int>(data.size()));
std::string received;
while (received.size() < n)
{
int rc = ss1.receiveBytes(buffer, sizeof(buffer));
if (rc > 0)
{
received.append(buffer, rc);
}
else if (n == 0)
{
break;
}
}
assertTrue (received == data);
}
ss1.close();
}
void SecureStreamSocketTest::testPeek()
{
SecureServerSocket svs(0);
TCPServer srv(new TCPServerConnectionFactoryImpl<EchoConnection>(), svs);
srv.start();
SocketAddress sa("127.0.0.1", svs.address().port());
SecureStreamSocket ss(sa);
int n = ss.sendBytes("hello, world!", 13);
assertTrue (n == 13);
char buffer[256];
n = ss.receiveBytes(buffer, 5, MSG_PEEK);
assertTrue (n == 5);
assertTrue (std::string(buffer, n) == "hello");
n = ss.receiveBytes(buffer, sizeof(buffer), MSG_PEEK);
assertTrue (n == 13);
assertTrue (std::string(buffer, n) == "hello, world!");
n = ss.receiveBytes(buffer, 7);
assertTrue (n == 7);
assertTrue (std::string(buffer, n) == "hello, ");
n = ss.receiveBytes(buffer, 6);
assertTrue (n == 6);
assertTrue (std::string(buffer, n) == "world!");
ss.close();
}
void SecureStreamSocketTest::testNB()
{
SecureServerSocket svs(0);
TCPServer srv(new TCPServerConnectionFactoryImpl<EchoConnection>(), svs);
srv.start();
SocketAddress sa("127.0.0.1", svs.address().port());
SecureStreamSocket ss1(sa);
ss1.setBlocking(false);
ss1.setSendBufferSize(32000);
char buffer[8192];
const std::vector<std::size_t> sizes = {67, 467, 7883, 19937};
for (const auto n: sizes)
{
std::string data(n, 'X');
int rc = ss1.sendBytes(data.data(), static_cast<int>(data.size()));
assertTrue (rc == n);
rc = -1;
while (rc < 0)
{
rc = ss1.receiveBytes(buffer, sizeof(buffer), MSG_PEEK);
}
assertTrue (rc > 0 && rc <= n);
assertTrue (data.compare(0, rc, buffer, rc) == 0);
std::string received;
while (received.size() < n)
{
int rc = ss1.receiveBytes(buffer, sizeof(buffer));
if (rc > 0)
{
received.append(buffer, rc);
}
else if (n == 0)
{
break;
}
}
assertTrue (received == data);
}
ss1.close();
}
void SecureStreamSocketTest::setUp()
{
}
void SecureStreamSocketTest::tearDown()
{
}
CppUnit::Test* SecureStreamSocketTest::suite()
{
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("SecureStreamSocketTest");
CppUnit_addTest(pSuite, SecureStreamSocketTest, testSendReceive);
CppUnit_addTest(pSuite, SecureStreamSocketTest, testPeek);
CppUnit_addTest(pSuite, SecureStreamSocketTest, testNB);
return pSuite;
}

View File

@ -0,0 +1,40 @@
//
// SecureStreamSocketTest.h
//
// Definition of the SecureStreamSocketTest class.
//
// Copyright (c) 2006, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#ifndef SecureStreamSocketTest_INCLUDED
#define SecureStreamSocketTest_INCLUDED
#include "Poco/Net/Net.h"
#include "CppUnit/TestCase.h"
class SecureStreamSocketTest: public CppUnit::TestCase
{
public:
SecureStreamSocketTest(const std::string& name);
~SecureStreamSocketTest();
void testSendReceive();
void testPeek();
void testNB();
void setUp();
void tearDown();
static CppUnit::Test* suite();
private:
};
#endif // SecureStreamSocketTest

View File

@ -0,0 +1,22 @@
//
// SecureStreamSocketTestSuite.cpp
//
// Copyright (c) 2006, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#include "SecureStreamSocketTestSuite.h"
#include "SecureStreamSocketTest.h"
CppUnit::Test* SecureStreamSocketTestSuite::suite()
{
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("SecureStreamSocketTestSuite");
pSuite->addTest(SecureStreamSocketTest::suite());
return pSuite;
}

View File

@ -0,0 +1,27 @@
//
// SecureStreamSocketTestSuite.h
//
// Definition of the SecureStreamSocketTestSuite class.
//
// Copyright (c) 2006, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#ifndef SecureStreamSocketTestSuite_INCLUDED
#define SecureStreamSocketTestSuite_INCLUDED
#include "CppUnit/TestSuite.h"
class SecureStreamSocketTestSuite
{
public:
static CppUnit::Test* suite();
};
#endif // SecureStreamSocketTestSuite

View File

@ -49,16 +49,16 @@ namespace
try
{
WebSocket ws(request, response);
std::unique_ptr<char[]> pBuffer(new char[_bufSize]);
Poco::Buffer<char> buffer(_bufSize);
int flags;
int n;
do
{
n = ws.receiveFrame(pBuffer.get(), static_cast<int>(_bufSize), flags);
Poco::Thread::current()->sleep(handleDelay.totalMilliseconds());
n = ws.receiveFrame(buffer.begin(), static_cast<int>(_bufSize), flags);
if (n == 0)
break;
ws.sendFrame(pBuffer.get(), n, flags);
Poco::Thread::current()->sleep(handleDelay.totalMilliseconds());
ws.sendFrame(buffer.begin(), n, flags);
}
while ((flags & WebSocket::FRAME_OP_BITMASK) != WebSocket::FRAME_OP_CLOSE);
}
@ -156,6 +156,8 @@ void WebSocketTest::testWebSocketTimeout()
ws.shutdown();
ws.receiveFrame(buffer, sizeof(buffer), flags);
server.stop();
WebSocketRequestHandler::handleDelay = 0;
}
@ -220,7 +222,7 @@ void WebSocketTest::testWebSocket()
assertTrue (n == 2);
assertTrue ((flags & WebSocket::FRAME_OP_BITMASK) == WebSocket::FRAME_OP_CLOSE);
server.stop();
server.stopAll(true);
}
@ -256,7 +258,74 @@ void WebSocketTest::testWebSocketLarge()
assertTrue (n == payload.size());
assertTrue (payload.compare(0, payload.size(), buffer, 0, n) == 0);
server.stop();
server.stopAll(true);
}
void WebSocketTest::testWebSocketNB()
{
Poco::Net::SecureServerSocket ss(0);
Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory(256*1024), ss, new Poco::Net::HTTPServerParams);
server.start();
Poco::Thread::sleep(200);
HTTPSClientSession cs("127.0.0.1", ss.address().port());
HTTPRequest request(HTTPRequest::HTTP_GET, "/ws", HTTPRequest::HTTP_1_1);
HTTPResponse response;
WebSocket ws(cs, request, response);
ws.setBlocking(false);
int flags;
char buffer[256*1024] = {};
int n = ws.receiveFrame(buffer, sizeof(buffer), flags);
assertTrue (n < 0);
std::string payload("x");
n = ws.sendFrame(payload.data(), (int) payload.size());
assertTrue (n > 0);
if (ws.poll(1000000, Poco::Net::Socket::SELECT_READ))
{
n = ws.receiveFrame(buffer, sizeof(buffer), flags);
while (n < 0)
{
n = ws.receiveFrame(buffer, sizeof(buffer), flags);
}
}
assertTrue (n == payload.size());
assertTrue (payload.compare(0, payload.size(), buffer, n) == 0);
assertTrue (flags == WebSocket::FRAME_TEXT);
ws.setSendBufferSize(256*1024);
ws.setReceiveBufferSize(256*1024);
payload.assign(256000, 'z');
n = ws.sendFrame(payload.data(), (int) payload.size());
assertTrue (n > 0);
if (ws.poll(1000000, Poco::Net::Socket::SELECT_READ))
{
n = ws.receiveFrame(buffer, sizeof(buffer), flags);
while (n < 0)
{
n = ws.receiveFrame(buffer, sizeof(buffer), flags);
}
}
assertTrue (n == payload.size());
assertTrue (payload.compare(0, payload.size(), buffer, n) == 0);
assertTrue (flags == WebSocket::FRAME_TEXT);
n = ws.shutdown();
assertTrue (n > 0);
n = ws.receiveFrame(buffer, sizeof(buffer), flags);
while (n < 0)
{
n = ws.receiveFrame(buffer, sizeof(buffer), flags);
}
assertTrue (n == 2);
assertTrue ((flags & WebSocket::FRAME_OP_BITMASK) == WebSocket::FRAME_OP_CLOSE);
server.stopAll(true);
}
@ -277,6 +346,7 @@ CppUnit::Test* WebSocketTest::suite()
CppUnit_addTest(pSuite, WebSocketTest, testWebSocket);
CppUnit_addTest(pSuite, WebSocketTest, testWebSocketTimeout);
CppUnit_addTest(pSuite, WebSocketTest, testWebSocketLarge);
CppUnit_addTest(pSuite, WebSocketTest, testWebSocketNB);
return pSuite;
}

View File

@ -27,6 +27,7 @@ public:
void testWebSocketTimeout();
void testWebSocket();
void testWebSocketLarge();
void testWebSocketNB();
void setUp() override;
void tearDown() override;