fix(NetSSL_Win): Non-blocking sockets support #4773

This commit is contained in:
Günter Obiltschnig
2024-11-16 16:50:52 +01:00
parent bf09be3f33
commit 7780f379ea
11 changed files with 479 additions and 28 deletions

View File

@@ -207,6 +207,12 @@ public:
int available() const;
/// Returns the number of bytes available in the buffer.
SocketImpl* socket();
/// Returns the underlying SocketImpl.
const SocketImpl* socket() const;
/// Returns the underlying SocketImpl.
protected:
enum
{
@@ -294,6 +300,8 @@ private:
Poco::Buffer<BYTE> _overflowBuffer;
Poco::Buffer<BYTE> _sendBuffer;
DWORD _sendBufferOffset;
DWORD _sendBufferPending;
Poco::Buffer<BYTE> _recvBuffer;
DWORD _recvBufferOffset;
DWORD _ioBufferSize;
@@ -316,6 +324,18 @@ private:
//
// inlines
//
inline SocketImpl* SecureSocketImpl::socket()
{
return _pSocket.get();
}
inline const SocketImpl* SecureSocketImpl::socket() const
{
return _pSocket.get();
}
inline poco_socket_t SecureSocketImpl::sockfd()
{
return _pSocket->sockfd();

View File

@@ -197,6 +197,12 @@ public:
/// Returns true iff a reused session was negotiated during
/// the handshake.
// SocketImpl
virtual void setBlocking(bool flag) override;
virtual bool getBlocking() const override;
virtual void setRawOption(int level, int option, const void* value, poco_socklen_t length) override;
virtual void getRawOption(int level, int option, void* value, poco_socklen_t& length) override;
protected:
void acceptSSL();
/// Performs a SSL server-side handshake.

View File

@@ -71,6 +71,8 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr<SocketImpl> pSocketImpl, Contex
_contextFlags(0),
_overflowBuffer(0),
_sendBuffer(0),
_sendBufferOffset(0),
_sendBufferPending(0),
_recvBuffer(IO_BUFFER_SIZE),
_recvBufferOffset(0),
_ioBufferSize(0),
@@ -362,6 +364,17 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
if (_state == ST_ERROR) return 0;
if (_sendBufferPending > 0)
{
int sent = sendRawBytes(_sendBuffer.begin() + _sendBufferOffset, _sendBufferPending, flags);
if (sent > 0)
{
_sendBufferOffset += sent;
_sendBufferPending -= sent;
}
return _sendBufferPending == 0 ? length : -1;
}
if (_state != ST_DONE)
{
bool establish = _pSocket->getBlocking();
@@ -379,14 +392,13 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
}
}
int rc = 0;
int dataToSend = length;
int dataSent = 0;
const char* pBuffer = reinterpret_cast<const char*>(buffer);
if (_sendBuffer.capacity() != _ioBufferSize)
_sendBuffer.setCapacity(_ioBufferSize);
_sendBuffer.resize(length + 1024);
_sendBufferOffset = 0;
_sendBufferPending = 0;
while (dataToSend > 0)
{
AutoSecBufferDesc<4> msg(&_securityFunctions, false);
@@ -396,11 +408,16 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
SecBuffer* pDataBuffer = 0;
SecBuffer* pExtraBuffer = 0;
std::memcpy(_sendBuffer.begin() + _streamSizes.cbHeader, pBuffer + dataSent, dataSize);
if (_sendBuffer.size() < _sendBufferPending + dataSize + _streamSizes.cbHeader + _streamSizes.cbTrailer)
{
_sendBuffer.resize(_sendBufferPending + dataSize + _streamSizes.cbHeader + _streamSizes.cbTrailer, true);
}
msg.setSecBufferStreamHeader(0, _sendBuffer.begin(), _streamSizes.cbHeader);
msg.setSecBufferData(1, _sendBuffer.begin() + _streamSizes.cbHeader, dataSize);
msg.setSecBufferStreamTrailer(2, _sendBuffer.begin() + _streamSizes.cbHeader + dataSize, _streamSizes.cbTrailer);
std::memcpy(_sendBuffer.begin() + _sendBufferPending + _streamSizes.cbHeader, pBuffer + dataSent, dataSize);
msg.setSecBufferStreamHeader(0, _sendBuffer.begin() + _sendBufferPending, _streamSizes.cbHeader);
msg.setSecBufferData(1, _sendBuffer.begin() + _sendBufferPending + _streamSizes.cbHeader, dataSize);
msg.setSecBufferStreamTrailer(2, _sendBuffer.begin() + _sendBufferPending + _streamSizes.cbHeader + dataSize, _streamSizes.cbTrailer);
msg.setSecBufferEmpty(3);
SECURITY_STATUS securityStatus = _securityFunctions.EncryptMessage(&_hContext, 0, &msg, 0);
@@ -408,24 +425,31 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
if (FAILED(securityStatus) && securityStatus != SEC_E_CONTEXT_EXPIRED)
throw SSLException("Failed to encrypt message", Utility::formatError(securityStatus));
int outBufferLen = msg[0].cbBuffer + msg[1].cbBuffer + msg[2].cbBuffer;
int sent = sendRawBytes(_sendBuffer.begin(), outBufferLen, flags);
if (_pSocket->getBlocking() && sent == -1)
{
if (dataSent == 0)
return -1;
else
return dataSent;
}
if (sent != outBufferLen)
throw SSLException("Failed to send encrypted message");
poco_assert_dbg (_streamSizes.cbHeader == msg[0].cbBuffer);
poco_assert_dbg (_streamSizes.cbTrailer == msg[2].cbBuffer);
poco_assert_dbg (dataSize == msg[1].cbBuffer);
_sendBufferPending += msg[0].cbBuffer + msg[1].cbBuffer + msg[2].cbBuffer;
dataToSend -= dataSize;
dataSent += dataSize;
rc += sent;
poco_assert (_sendBufferPending <= _sendBuffer.size());
}
int sent = sendRawBytes(_sendBuffer.begin(), _sendBufferPending, flags);
if (sent >= 0)
{
_sendBufferOffset += sent;
_sendBufferPending -= sent;
if (_sendBufferPending > 0)
return -1;
else
return dataSent;
}
else
{
return -1;
}
return dataSent;
}
@@ -463,14 +487,20 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
{
rc = length;
std::memcpy(buffer, _overflowBuffer.begin(), rc);
std::memmove(_overflowBuffer.begin(), _overflowBuffer.begin() + rc, overflowSize - rc);
_overflowBuffer.resize(overflowSize - rc);
if ((flags & MSG_PEEK) == 0)
{
std::memmove(_overflowBuffer.begin(), _overflowBuffer.begin() + rc, overflowSize - rc);
_overflowBuffer.resize(overflowSize - rc);
}
}
else
{
rc = static_cast<int>(overflowSize);
std::memcpy(buffer, _overflowBuffer.begin(), rc);
_overflowBuffer.resize(0);
if ((flags & MSG_PEEK) == 0)
{
_overflowBuffer.resize(0);
}
}
}
else
@@ -521,13 +551,21 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
rc = bytesDecoded;
if (rc > length)
rc = length;
if (flags & MSG_PEEK)
{
// prepend returned data to overflow buffer
std::size_t oldSize = _overflowBuffer.size();
_overflowBuffer.resize(oldSize + rc);
std::memmove(_overflowBuffer.begin() + rc, _overflowBuffer.begin(), oldSize);
std::memcpy(_overflowBuffer.begin(), buffer, rc);
}
return rc;
}
if (securityStatus == SEC_E_INCOMPLETE_MESSAGE)
{
if (!_pSocket->getBlocking())
return -1;
continue;
}

View File

@@ -213,4 +213,28 @@ int SecureStreamSocketImpl::completeHandshake()
}
void SecureStreamSocketImpl::setBlocking(bool flag)
{
_impl.socket()->setBlocking(flag);
}
bool SecureStreamSocketImpl::getBlocking() const
{
return _impl.socket()->getBlocking();
}
void SecureStreamSocketImpl::setRawOption(int level, int option, const void* value, poco_socklen_t length)
{
_impl.socket()->setRawOption(level, option, value, length);
}
void SecureStreamSocketImpl::getRawOption(int level, int option, void* value, poco_socklen_t& length)
{
_impl.socket()->getRawOption(level, option, value, length);
}
} } // namespace Poco::Net

View File

@@ -9,19 +9,22 @@
#include "NetSSLTestSuite.h"
#include "SecureStreamSocketTestSuite.h"
#include "HTTPSClientTestSuite.h"
#include "TCPServerTestSuite.h"
#include "HTTPSServerTestSuite.h"
#include "WebSocketTestSuite.h"
CppUnit::Test* NetSSLTestSuite::suite()
{
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("NetSSLTestSuite");
pSuite->addTest(SecureStreamSocketTestSuite::suite());
pSuite->addTest(HTTPSClientTestSuite::suite());
pSuite->addTest(TCPServerTestSuite::suite());
pSuite->addTest(HTTPSServerTestSuite::suite());
pSuite->addTest(WebSocketTestSuite::suite());
return pSuite;
}

View File

@@ -0,0 +1,222 @@
//
// SecureStreamSocketTest.cpp
//
// Copyright (c) 2006, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#include "SecureStreamSocketTest.h"
#include "CppUnit/TestCaller.h"
#include "CppUnit/TestSuite.h"
#include "Poco/Net/TCPServer.h"
#include "Poco/Net/TCPServerConnection.h"
#include "Poco/Net/TCPServerConnectionFactory.h"
#include "Poco/Net/TCPServerParams.h"
#include "Poco/Net/SecureStreamSocket.h"
#include "Poco/Net/SecureServerSocket.h"
#include "Poco/Net/Context.h"
#include "Poco/Net/RejectCertificateHandler.h"
#include "Poco/Net/AcceptCertificateHandler.h"
#include "Poco/Net/Session.h"
#include "Poco/Net/SSLManager.h"
#include "Poco/Util/Application.h"
#include "Poco/Util/AbstractConfiguration.h"
#include "Poco/Thread.h"
#include <iostream>
using Poco::Net::TCPServer;
using Poco::Net::TCPServerConnection;
using Poco::Net::TCPServerConnectionFactory;
using Poco::Net::TCPServerConnectionFactoryImpl;
using Poco::Net::TCPServerParams;
using Poco::Net::StreamSocket;
using Poco::Net::SecureStreamSocket;
using Poco::Net::SecureServerSocket;
using Poco::Net::SocketAddress;
using Poco::Net::Context;
using Poco::Net::Session;
using Poco::Net::SSLManager;
using Poco::Thread;
using Poco::Util::Application;
namespace
{
class EchoConnection: public TCPServerConnection
{
public:
EchoConnection(const StreamSocket& s): TCPServerConnection(s)
{
}
void run()
{
StreamSocket& ss = socket();
try
{
char buffer[256];
int n = ss.receiveBytes(buffer, sizeof(buffer));
while (n > 0)
{
ss.sendBytes(buffer, n);
n = ss.receiveBytes(buffer, sizeof(buffer));
}
}
catch (Poco::Exception& exc)
{
std::cerr << "EchoConnection: " << exc.displayText() << std::endl;
}
}
};
}
SecureStreamSocketTest::SecureStreamSocketTest(const std::string& name): CppUnit::TestCase(name)
{
}
SecureStreamSocketTest::~SecureStreamSocketTest()
{
}
void SecureStreamSocketTest::testSendReceive()
{
SecureServerSocket svs(0);
TCPServer srv(new TCPServerConnectionFactoryImpl<EchoConnection>(), svs);
srv.start();
SocketAddress sa("127.0.0.1", svs.address().port());
SecureStreamSocket ss1(sa);
std::string data("hello, world");
ss1.sendBytes(data.data(), static_cast<int>(data.size()));
char buffer[8192];
int n = ss1.receiveBytes(buffer, sizeof(buffer));
assertTrue (n > 0);
assertTrue (std::string(buffer, n) == data);
const std::vector<std::size_t> sizes = {67, 467, 7883, 19937};
for (const auto n: sizes)
{
data.assign(n, 'X');
ss1.sendBytes(data.data(), static_cast<int>(data.size()));
std::string received;
while (received.size() < n)
{
int rc = ss1.receiveBytes(buffer, sizeof(buffer));
if (rc > 0)
{
received.append(buffer, rc);
}
else if (n == 0)
{
break;
}
}
assertTrue (received == data);
}
ss1.close();
}
void SecureStreamSocketTest::testPeek()
{
SecureServerSocket svs(0);
TCPServer srv(new TCPServerConnectionFactoryImpl<EchoConnection>(), svs);
srv.start();
SocketAddress sa("127.0.0.1", svs.address().port());
SecureStreamSocket ss(sa);
int n = ss.sendBytes("hello, world!", 13);
assertTrue (n == 13);
char buffer[256];
n = ss.receiveBytes(buffer, 5, MSG_PEEK);
assertTrue (n == 5);
assertTrue (std::string(buffer, n) == "hello");
n = ss.receiveBytes(buffer, sizeof(buffer), MSG_PEEK);
assertTrue (n == 13);
assertTrue (std::string(buffer, n) == "hello, world!");
n = ss.receiveBytes(buffer, 7);
assertTrue (n == 7);
assertTrue (std::string(buffer, n) == "hello, ");
n = ss.receiveBytes(buffer, 6);
assertTrue (n == 6);
assertTrue (std::string(buffer, n) == "world!");
ss.close();
}
void SecureStreamSocketTest::testNB()
{
SecureServerSocket svs(0);
TCPServer srv(new TCPServerConnectionFactoryImpl<EchoConnection>(), svs);
srv.start();
SocketAddress sa("127.0.0.1", svs.address().port());
SecureStreamSocket ss1(sa);
ss1.setBlocking(false);
ss1.setSendBufferSize(32000);
char buffer[8192];
const std::vector<std::size_t> sizes = {67, 467, 7883, 19937};
for (const auto n: sizes)
{
std::string data(n, 'X');
int rc = ss1.sendBytes(data.data(), static_cast<int>(data.size()));
assertTrue (rc == n);
rc = -1;
while (rc < 0)
{
rc = ss1.receiveBytes(buffer, sizeof(buffer), MSG_PEEK);
}
assertTrue (rc > 0 && rc <= n);
assertTrue (data.compare(0, rc, buffer, rc) == 0);
std::string received;
while (received.size() < n)
{
int rc = ss1.receiveBytes(buffer, sizeof(buffer));
if (rc > 0)
{
received.append(buffer, rc);
}
else if (n == 0)
{
break;
}
}
assertTrue (received == data);
}
ss1.close();
}
void SecureStreamSocketTest::setUp()
{
}
void SecureStreamSocketTest::tearDown()
{
}
CppUnit::Test* SecureStreamSocketTest::suite()
{
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("SecureStreamSocketTest");
CppUnit_addTest(pSuite, SecureStreamSocketTest, testSendReceive);
CppUnit_addTest(pSuite, SecureStreamSocketTest, testPeek);
CppUnit_addTest(pSuite, SecureStreamSocketTest, testNB);
return pSuite;
}

View File

@@ -0,0 +1,40 @@
//
// SecureStreamSocketTest.h
//
// Definition of the SecureStreamSocketTest class.
//
// Copyright (c) 2006, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#ifndef SecureStreamSocketTest_INCLUDED
#define SecureStreamSocketTest_INCLUDED
#include "Poco/Net/Net.h"
#include "CppUnit/TestCase.h"
class SecureStreamSocketTest: public CppUnit::TestCase
{
public:
SecureStreamSocketTest(const std::string& name);
~SecureStreamSocketTest();
void testSendReceive();
void testPeek();
void testNB();
void setUp();
void tearDown();
static CppUnit::Test* suite();
private:
};
#endif // SecureStreamSocketTest

View File

@@ -0,0 +1,22 @@
//
// SecureStreamSocketTestSuite.cpp
//
// Copyright (c) 2006, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#include "SecureStreamSocketTestSuite.h"
#include "SecureStreamSocketTest.h"
CppUnit::Test* SecureStreamSocketTestSuite::suite()
{
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("SecureStreamSocketTestSuite");
pSuite->addTest(SecureStreamSocketTest::suite());
return pSuite;
}

View File

@@ -0,0 +1,27 @@
//
// SecureStreamSocketTestSuite.h
//
// Definition of the SecureStreamSocketTestSuite class.
//
// Copyright (c) 2006, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#ifndef SecureStreamSocketTestSuite_INCLUDED
#define SecureStreamSocketTestSuite_INCLUDED
#include "CppUnit/TestSuite.h"
class SecureStreamSocketTestSuite
{
public:
static CppUnit::Test* suite();
};
#endif // SecureStreamSocketTestSuite

View File

@@ -0,0 +1,22 @@
//
// WebSocketTestSuite.cpp
//
// Copyright (c) 2012, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#include "WebSocketTestSuite.h"
#include "WebSocketTest.h"
CppUnit::Test* WebSocketTestSuite::suite()
{
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("WebSocketTestSuite");
pSuite->addTest(WebSocketTest::suite());
return pSuite;
}

View File

@@ -0,0 +1,27 @@
//
// WebSocketTestSuite.h
//
// Definition of the WebSocketTestSuite class.
//
// Copyright (c) 2012, Applied Informatics Software Engineering GmbH.
// and Contributors.
//
// SPDX-License-Identifier: BSL-1.0
//
#ifndef WebSocketTestSuite_INCLUDED
#define WebSocketTestSuite_INCLUDED
#include "CppUnit/TestSuite.h"
class WebSocketTestSuite
{
public:
static CppUnit::Test* suite();
};
#endif // WebSocketTestSuite_INCLUDED