fix(SecureSocket): Refactor detection of timeout when reading, writing and handshaking (#4510)

* fix(SecureSocket): Refactor detection of timeout when reading, writing or handshaking. (#3725)

* enh(SecureSocket): some trivial C++17 modernisation changes.

* chore: indentation and compiler warning

---------

Co-authored-by: Alex Fabijanic <alex@pocoproject.org>
This commit is contained in:
Matej Kenda 2024-03-27 00:29:58 +01:00 committed by GitHub
parent 5fec3deeb9
commit 482c066307
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 159 additions and 100 deletions

View File

@ -147,7 +147,7 @@ SocketImpl* SocketImpl::acceptConnection(SocketAddress& clientAddr)
return new StreamSocketImpl(sd);
}
error(); // will throw
return 0;
return nullptr;
}
@ -527,7 +527,7 @@ int SocketImpl::sendTo(const SocketBufVec& buffers, const SocketAddress& address
msgHdr.msg_namelen = address.length();
msgHdr.msg_iov = const_cast<iovec*>(&buffers[0]);
msgHdr.msg_iovlen = buffers.size();
msgHdr.msg_control = 0;
msgHdr.msg_control = nullptr;
msgHdr.msg_controllen = 0;
msgHdr.msg_flags = flags;
rc = sendmsg(_sockfd, &msgHdr, flags);
@ -613,7 +613,7 @@ int SocketImpl::receiveFrom(SocketBufVec& buffers, struct sockaddr** pSA, poco_s
msgHdr.msg_namelen = **ppSALen;
msgHdr.msg_iov = &buffers[0];
msgHdr.msg_iovlen = buffers.size();
msgHdr.msg_control = 0;
msgHdr.msg_control = nullptr;
msgHdr.msg_controllen = 0;
msgHdr.msg_flags = flags;
rc = recvmsg(_sockfd, &msgHdr, flags);
@ -652,7 +652,7 @@ int SocketImpl::available()
if (result && (type() == SOCKET_TYPE_DATAGRAM))
{
std::vector<char> buf(result);
result = recvfrom(sockfd(), &buf[0], result, MSG_PEEK, NULL, NULL);
result = recvfrom(sockfd(), &buf[0], result, MSG_PEEK, nullptr, nullptr);
}
#endif
return result;
@ -1114,7 +1114,7 @@ void SocketImpl::setReusePort(bool flag)
int value = flag ? 1 : 0;
setOption(SOL_SOCKET, SO_REUSEPORT, value);
}
catch (IOException&)
catch (const IOException&)
{
// ignore error, since not all implementations
// support SO_REUSEPORT, even if the macro

View File

@ -132,7 +132,7 @@ public:
/// a SecureStreamSocketImpl, otherwise an InvalidArgumentException
/// will be thrown.
virtual ~SecureStreamSocket();
~SecureStreamSocket() override;
/// Destroys the StreamSocket.
SecureStreamSocket& operator = (const Socket& socket);

View File

@ -39,12 +39,12 @@ public:
SecureStreamSocketImpl(StreamSocketImpl* pStreamSocket, Context::Ptr pContext);
/// Creates the SecureStreamSocketImpl.
SocketImpl* acceptConnection(SocketAddress& clientAddr);
SocketImpl* acceptConnection(SocketAddress& clientAddr) override;
/// Not supported by a SecureStreamSocket.
///
/// Throws a Poco::InvalidAccessException.
void connect(const SocketAddress& address);
void connect(const SocketAddress& address) override;
/// Initializes the socket and establishes a connection to
/// the TCP server at the given address.
///
@ -52,57 +52,57 @@ public:
/// connection is established. Instead, incoming and outgoing
/// packets are restricted to the specified address.
void connect(const SocketAddress& address, const Poco::Timespan& timeout);
void connect(const SocketAddress& address, const Poco::Timespan& timeout) override;
/// Initializes the socket, sets the socket timeout and
/// establishes a connection to the TCP server at the given address.
void connectNB(const SocketAddress& address);
void connectNB(const SocketAddress& address) override;
/// Initializes the socket and establishes a connection to
/// the TCP server at the given address. Prior to opening the
/// connection the socket is set to nonblocking mode.
void bind(const SocketAddress& address, bool reuseAddress = false);
void bind(const SocketAddress& address, bool reuseAddress = false) override;
/// Not supported by a SecureStreamSocket.
///
/// Throws a Poco::InvalidAccessException.
void listen(int backlog = 64);
void listen(int backlog = 64) override;
/// Not supported by a SecureStreamSocket.
///
/// Throws a Poco::InvalidAccessException.
void close();
void close() override;
/// Close the socket.
int sendBytes(const void* buffer, int length, int flags = 0);
int sendBytes(const void* buffer, int length, int flags = 0) override;
/// Sends the contents of the given buffer through
/// the socket. Any specified flags are ignored.
///
/// Returns the number of bytes sent, which may be
/// less than the number of bytes specified.
int receiveBytes(void* buffer, int length, int flags = 0);
int receiveBytes(void* buffer, int length, int flags = 0) override;
/// Receives data from the socket and stores it
/// in buffer. Up to length bytes are received.
///
/// Returns the number of bytes received.
int sendTo(const void* buffer, int length, const SocketAddress& address, int flags = 0);
int sendTo(const void* buffer, int length, const SocketAddress& address, int flags = 0) override;
/// Not supported by a SecureStreamSocket.
///
/// Throws a Poco::InvalidAccessException.
int receiveFrom(void* buffer, int length, SocketAddress& address, int flags = 0);
int receiveFrom(void* buffer, int length, SocketAddress& address, int flags = 0) override;
/// Not supported by a SecureStreamSocket.
///
/// Throws a Poco::InvalidAccessException.
void sendUrgent(unsigned char data);
void sendUrgent(unsigned char data) override;
/// Not supported by a SecureStreamSocket.
///
/// Throws a Poco::InvalidAccessException.
int available();
int available() override;
/// Returns the number of bytes available that can be read
/// without causing the socket to block.
///
@ -110,26 +110,26 @@ public:
/// can be read from the currently buffered SSL record,
/// before a new record is read from the underlying socket.
void shutdownReceive();
void shutdownReceive() override;
/// Shuts down the receiving part of the socket connection.
///
/// Since SSL does not support a half shutdown, this does
/// nothing.
void shutdownSend();
void shutdownSend() override;
/// Shuts down the receiving part of the socket connection.
///
/// Since SSL does not support a half shutdown, this does
/// nothing.
void shutdown();
void shutdown() override;
/// Shuts down the SSL connection.
void abort();
/// Aborts the connection by closing the underlying
/// TCP connection. No orderly SSL shutdown is performed.
bool secure() const;
bool secure() const override;
/// Returns true iff the socket's connection is secure
/// (using SSL or TLS).
@ -203,7 +203,7 @@ protected:
void connectSSL();
/// Performs a SSL client-side handshake on an already connected TCP socket.
~SecureStreamSocketImpl();
~SecureStreamSocketImpl() override;
/// Destroys the SecureStreamSocketImpl.
static int lastError();

View File

@ -22,10 +22,7 @@
#include "Poco/Net/SecureStreamSocketImpl.h"
#include "Poco/Net/StreamSocketImpl.h"
#include "Poco/Net/StreamSocket.h"
#include "Poco/Net/NetException.h"
#include "Poco/Net/DNS.h"
#include "Poco/NumberFormatter.h"
#include "Poco/NumberParser.h"
#include "Poco/Format.h"
#include <openssl/x509v3.h>
#include <openssl/err.h>
@ -38,16 +35,12 @@ using Poco::NumberFormatter;
using Poco::Timespan;
// workaround for C++-incompatible macro
#define POCO_BIO_set_nbio_accept(b,n) BIO_ctrl(b,BIO_C_SET_ACCEPT,1,(void*)((n)?"a":NULL))
namespace Poco {
namespace Net {
SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr<SocketImpl> pSocketImpl, Context::Ptr pContext):
_pSSL(0),
_pSSL(nullptr),
_pSocket(pSocketImpl),
_pContext(pContext),
_needHandshake(false)
@ -86,14 +79,14 @@ void SecureSocketImpl::acceptSSL()
{
poco_assert (!_pSSL);
BIO* pBIO = BIO_new(BIO_s_socket());
BIO* pBIO = ::BIO_new(BIO_s_socket());
if (!pBIO) throw SSLException("Cannot create BIO object");
BIO_set_fd(pBIO, static_cast<int>(_pSocket->sockfd()), BIO_NOCLOSE);
_pSSL = SSL_new(_pContext->sslContext());
_pSSL = ::SSL_new(_pContext->sslContext());
if (!_pSSL)
{
BIO_free(pBIO);
::BIO_free(pBIO);
throw SSLException("Cannot create SSL object");
}
@ -105,15 +98,15 @@ void SecureSocketImpl::acceptSSL()
* tickets. */
if (1 != SSL_set_num_tickets(_pSSL, 0))
{
BIO_free(pBIO);
::BIO_free(pBIO);
throw SSLException("Cannot create SSL object");
}
//Otherwise we can perform two-way shutdown. Client must call SSL_read() before the final SSL_shutdown().
#endif
SSL_set_bio(_pSSL, pBIO, pBIO);
SSL_set_accept_state(_pSSL);
SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), this);
::SSL_set_bio(_pSSL, pBIO, pBIO);
::SSL_set_accept_state(_pSSL);
::SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), this);
_needHandshake = true;
}
@ -162,18 +155,18 @@ void SecureSocketImpl::connectSSL(bool performHandshake)
poco_assert (!_pSSL);
poco_assert (_pSocket->initialized());
BIO* pBIO = BIO_new(BIO_s_socket());
::BIO* pBIO = ::BIO_new(BIO_s_socket());
if (!pBIO) throw SSLException("Cannot create SSL BIO object");
BIO_set_fd(pBIO, static_cast<int>(_pSocket->sockfd()), BIO_NOCLOSE);
_pSSL = SSL_new(_pContext->sslContext());
_pSSL = ::SSL_new(_pContext->sslContext());
if (!_pSSL)
{
BIO_free(pBIO);
::BIO_free(pBIO);
throw SSLException("Cannot create SSL object");
}
SSL_set_bio(_pSSL, pBIO, pBIO);
SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), this);
::SSL_set_bio(_pSSL, pBIO, pBIO);
::SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), this);
if (!_peerHostName.empty())
{
@ -189,27 +182,27 @@ void SecureSocketImpl::connectSSL(bool performHandshake)
if (_pSession && _pSession->isResumable())
{
SSL_set_session(_pSSL, _pSession->sslSession());
::SSL_set_session(_pSSL, _pSession->sslSession());
}
try
{
if (performHandshake && _pSocket->getBlocking())
{
int ret = SSL_connect(_pSSL);
int ret = ::SSL_connect(_pSSL);
handleError(ret);
verifyPeerCertificate();
}
else
{
SSL_set_connect_state(_pSSL);
::SSL_set_connect_state(_pSSL);
_needHandshake = true;
}
}
catch (...)
{
SSL_free(_pSSL);
_pSSL = 0;
::SSL_free(_pSSL);
_pSSL = nullptr;
throw;
}
}
@ -259,11 +252,11 @@ void SecureSocketImpl::shutdown()
{
if (_pSSL)
{
// Don't shut down the socket more than once.
int shutdownState = SSL_get_shutdown(_pSSL);
bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN;
if (!shutdownSent)
{
// Don't shut down the socket more than once.
int shutdownState = ::SSL_get_shutdown(_pSSL);
bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN;
if (!shutdownSent)
{
// A proper clean shutdown would require us to
// retry the shutdown if we get a zero return
// value, until SSL_shutdown() returns 1.
@ -274,7 +267,7 @@ void SecureSocketImpl::shutdown()
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
int rc = 0;
if (!_bidirectShutdown)
rc = SSL_shutdown(_pSSL);
rc = ::SSL_shutdown(_pSSL);
else
{
Poco::Timespan recvTimeout = _pSocket->getReceiveTimeout();
@ -282,11 +275,11 @@ void SecureSocketImpl::shutdown()
Poco::Timestamp tsNow;
do
{
rc = SSL_shutdown(_pSSL);
rc = ::SSL_shutdown(_pSSL);
if (rc == 1) break;
if (rc < 0)
{
int err = SSL_get_error(_pSSL, rc);
int err = ::SSL_get_error(_pSSL, rc);
if (err == SSL_ERROR_WANT_READ)
_pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_READ);
else if (err == SSL_ERROR_WANT_WRITE)
@ -294,7 +287,7 @@ void SecureSocketImpl::shutdown()
else
{
int socketError = SocketImpl::lastError();
long lastError = ERR_get_error();
long lastError = ::ERR_get_error();
if ((err == SSL_ERROR_SSL) && (socketError == 0) && (lastError == 0x0A000123))
rc = 0;
break;
@ -304,7 +297,7 @@ void SecureSocketImpl::shutdown()
} while (!tsNow.isElapsed(recvTimeout.totalMicroseconds()));
}
#else
int rc = SSL_shutdown(_pSSL);
int rc = ::SSL_shutdown(_pSSL);
#endif
if (rc < 0) handleError(rc);
if (_pSocket->getBlocking())
@ -361,11 +354,17 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
else
return rc;
}
do
const auto sendTimeout = _pSocket->getSendTimeout();
Poco::Timestamp tsStart;
while (true)
{
rc = SSL_write(_pSSL, buffer, length);
}
while (mustRetry(rc));
rc = ::SSL_write(_pSSL, buffer, length);
if (!mustRetry(rc))
break;
if (tsStart.isElapsed(sendTimeout.totalMicroseconds()))
throw Poco::TimeoutException();
};
if (rc <= 0)
{
rc = handleError(rc);
@ -389,11 +388,18 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
else
return rc;
}
do
const auto recvTimeout = _pSocket->getReceiveTimeout();
Poco::Timestamp tsStart;
while (true)
{
rc = SSL_read(_pSSL, buffer, length);
}
while (mustRetry(rc));
rc = ::SSL_read(_pSSL, buffer, length);
if (!mustRetry(rc))
break;
if (tsStart.isElapsed(recvTimeout.totalMicroseconds()))
throw Poco::TimeoutException();
};
_bidirectShutdown = false;
if (rc <= 0)
{
@ -407,7 +413,7 @@ int SecureSocketImpl::available() const
{
poco_check_ptr (_pSSL);
return SSL_pending(_pSSL);
return ::SSL_pending(_pSSL);
}
@ -417,11 +423,17 @@ int SecureSocketImpl::completeHandshake()
poco_check_ptr (_pSSL);
int rc;
do
const auto recvTimeout = _pSocket->getReceiveTimeout();
Poco::Timestamp tsStart;
while (true)
{
rc = SSL_do_handshake(_pSSL);
}
while (mustRetry(rc));
rc = ::SSL_do_handshake(_pSSL);
if (!mustRetry(rc))
break;
if (tsStart.isElapsed(recvTimeout.totalMicroseconds()))
throw Poco::TimeoutException();
};
if (rc <= 0)
{
return handleError(rc);
@ -460,7 +472,7 @@ long SecureSocketImpl::verifyPeerCertificateImpl(const std::string& hostName)
return X509_V_OK;
}
X509* pCert = SSL_get_peer_certificate(_pSSL);
::X509* pCert = ::SSL_get_peer_certificate(_pSSL);
if (pCert)
{
X509Certificate cert(pCert);
@ -477,7 +489,7 @@ bool SecureSocketImpl::isLocalHost(const std::string& hostName)
SocketAddress addr(hostName, 0);
return addr.host().isLoopback();
}
catch (Poco::Exception&)
catch (const Poco::Exception&)
{
return false;
}
@ -487,9 +499,9 @@ bool SecureSocketImpl::isLocalHost(const std::string& hostName)
X509* SecureSocketImpl::peerCertificate() const
{
if (_pSSL)
return SSL_get_peer_certificate(_pSSL);
return ::SSL_get_peer_certificate(_pSSL);
else
return 0;
return nullptr;
}
@ -497,26 +509,24 @@ bool SecureSocketImpl::mustRetry(int rc)
{
if (rc <= 0)
{
int sslError = SSL_get_error(_pSSL, rc);
static const Poco::Timespan pollTimeout(0, 100000);
int sslError = ::SSL_get_error(_pSSL, rc);
int socketError = _pSocket->lastError();
switch (sslError)
{
case SSL_ERROR_WANT_READ:
if (_pSocket->getBlocking())
{
if (_pSocket->poll(_pSocket->getReceiveTimeout(), Poco::Net::Socket::SELECT_READ))
return true;
else
throw Poco::TimeoutException();
_pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_READ);
return true;
}
break;
case SSL_ERROR_WANT_WRITE:
if (_pSocket->getBlocking())
{
if (_pSocket->poll(_pSocket->getSendTimeout(), Poco::Net::Socket::SELECT_WRITE))
return true;
else
throw Poco::TimeoutException();
_pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_WRITE);
return true;
}
break;
case SSL_ERROR_SYSCALL:
@ -533,7 +543,7 @@ int SecureSocketImpl::handleError(int rc)
{
if (rc > 0) return rc;
int sslError = SSL_get_error(_pSSL, rc);
int sslError = ::SSL_get_error(_pSSL, rc);
int socketError = SocketImpl::lastError();
switch (sslError)
@ -569,12 +579,12 @@ int SecureSocketImpl::handleError(int rc)
// fallthrough
default:
{
long lastError = ERR_get_error();
long lastError = ::ERR_get_error();
std::string msg;
if (lastError)
{
char buffer[256];
ERR_error_string_n(lastError, buffer, sizeof(buffer));
::ERR_error_string_n(lastError, buffer, sizeof(buffer));
msg = buffer;
}
// SSL_GET_ERROR(3ossl):
@ -627,9 +637,9 @@ void SecureSocketImpl::reset()
close();
if (_pSSL)
{
SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), nullptr);
SSL_free(_pSSL);
_pSSL = 0;
::SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), nullptr);
::SSL_free(_pSSL);
_pSSL = nullptr;
}
}
@ -655,7 +665,7 @@ void SecureSocketImpl::useSession(Session::Ptr pSession)
bool SecureSocketImpl::sessionWasReused()
{
if (_pSSL)
return SSL_session_reused(_pSSL) != 0;
return ::SSL_session_reused(_pSSL) != 0;
else
return false;
}
@ -663,7 +673,7 @@ bool SecureSocketImpl::sessionWasReused()
int SecureSocketImpl::onSessionCreated(SSL* pSSL, SSL_SESSION* pSession)
{
void* pEx = SSL_get_ex_data(pSSL, SSLManager::instance().socketIndex());
void* pEx = ::SSL_get_ex_data(pSSL, SSLManager::instance().socketIndex());
if (pEx)
{
SecureSocketImpl* pThis = reinterpret_cast<SecureSocketImpl*>(pEx);

View File

@ -103,7 +103,7 @@ void HTTPSStreamFactoryTest::testError()
uri.setPort(server.port());
try
{
std::istream* pStr = factory.open(uri);
factory.open(uri);
fail("not found - must throw");
}
catch (HTTPException& exc)

View File

@ -23,7 +23,6 @@
#include "Poco/Net/SecureServerSocket.h"
#include "Poco/Net/NetException.h"
#include "Poco/Thread.h"
#include <iostream>
using Poco::Net::HTTPSClientSession;
using Poco::Net::HTTPRequest;
@ -55,13 +54,14 @@ namespace
do
{
n = ws.receiveFrame(pBuffer.get(), static_cast<int>(_bufSize), flags);
Poco::Thread::current()->sleep(handleDelay.totalMilliseconds());
if (n == 0)
break;
ws.sendFrame(pBuffer.get(), n, flags);
}
while ((flags & WebSocket::FRAME_OP_BITMASK) != WebSocket::FRAME_OP_CLOSE);
}
catch (WebSocketException& exc)
catch (const WebSocketException& exc)
{
switch (exc.code())
{
@ -79,10 +79,17 @@ namespace
}
}
public:
static Poco::Timespan handleDelay;
private:
std::size_t _bufSize;
};
Poco::Timespan WebSocketRequestHandler::handleDelay {0};
class WebSocketRequestHandlerFactory: public Poco::Net::HTTPRequestHandlerFactory
{
public:
@ -90,7 +97,7 @@ namespace
{
}
Poco::Net::HTTPRequestHandler* createRequestHandler(const HTTPServerRequest& request)
Poco::Net::HTTPRequestHandler* createRequestHandler(const HTTPServerRequest& request) override
{
return new WebSocketRequestHandler(_bufSize);
}
@ -111,6 +118,46 @@ WebSocketTest::~WebSocketTest()
}
void WebSocketTest::testWebSocketTimeout()
{
Poco::Net::SecureServerSocket ss(0);
Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory, ss, new Poco::Net::HTTPServerParams);
server.start();
Poco::Thread::sleep(200);
HTTPSClientSession cs("127.0.0.1", ss.address().port());
HTTPRequest request(HTTPRequest::HTTP_GET, "/ws");
HTTPResponse response;
WebSocket ws(cs, request, response);
ws.setSendTimeout( Poco::Timespan(2, 0));
ws.setReceiveTimeout( Poco::Timespan(2, 0));
Poco::Timestamp sendStart;
char buffer[1024] = {};
int flags;
try
{
// Server will take long to process and cause WS timeout
WebSocketRequestHandler::handleDelay.assign(3, 0);
std::string payload("x");
ws.sendFrame(payload.data(), (int) payload.size());
ws.receiveFrame(buffer, sizeof(buffer), flags);
failmsg("Data exchange shall time out.");
}
catch (const Poco::TimeoutException& te)
{
assertTrue(sendStart.elapsed() < Poco::Timespan(4, 0).totalMicroseconds());
}
ws.shutdown();
ws.receiveFrame(buffer, sizeof(buffer), flags);
server.stop();
}
void WebSocketTest::testWebSocket()
{
Poco::Net::SecureServerSocket ss(0);
@ -227,6 +274,7 @@ CppUnit::Test* WebSocketTest::suite()
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("WebSocketTest");
CppUnit_addTest(pSuite, WebSocketTest, testWebSocket);
CppUnit_addTest(pSuite, WebSocketTest, testWebSocketTimeout);
CppUnit_addTest(pSuite, WebSocketTest, testWebSocketLarge);
return pSuite;

View File

@ -22,13 +22,14 @@ class WebSocketTest: public CppUnit::TestCase
{
public:
WebSocketTest(const std::string& name);
~WebSocketTest();
~WebSocketTest() override;
void testWebSocketTimeout();
void testWebSocket();
void testWebSocketLarge();
void setUp();
void tearDown();
void setUp() override;
void tearDown() override;
static CppUnit::Test* suite();