fix(SecureSocket): Refactor detection of timeout when reading, writing and handshaking (#4510)

* fix(SecureSocket): Refactor detection of timeout when reading, writing or handshaking. (#3725)

* enh(SecureSocket): some trivial C++17 modernisation changes.

* chore: indentation and compiler warning

---------

Co-authored-by: Alex Fabijanic <alex@pocoproject.org>
This commit is contained in:
Matej Kenda
2024-03-27 00:29:58 +01:00
committed by GitHub
parent 5fec3deeb9
commit 482c066307
7 changed files with 159 additions and 100 deletions

View File

@@ -22,10 +22,7 @@
#include "Poco/Net/SecureStreamSocketImpl.h"
#include "Poco/Net/StreamSocketImpl.h"
#include "Poco/Net/StreamSocket.h"
#include "Poco/Net/NetException.h"
#include "Poco/Net/DNS.h"
#include "Poco/NumberFormatter.h"
#include "Poco/NumberParser.h"
#include "Poco/Format.h"
#include <openssl/x509v3.h>
#include <openssl/err.h>
@@ -38,16 +35,12 @@ using Poco::NumberFormatter;
using Poco::Timespan;
// workaround for C++-incompatible macro
#define POCO_BIO_set_nbio_accept(b,n) BIO_ctrl(b,BIO_C_SET_ACCEPT,1,(void*)((n)?"a":NULL))
namespace Poco {
namespace Net {
SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr<SocketImpl> pSocketImpl, Context::Ptr pContext):
_pSSL(0),
_pSSL(nullptr),
_pSocket(pSocketImpl),
_pContext(pContext),
_needHandshake(false)
@@ -86,14 +79,14 @@ void SecureSocketImpl::acceptSSL()
{
poco_assert (!_pSSL);
BIO* pBIO = BIO_new(BIO_s_socket());
BIO* pBIO = ::BIO_new(BIO_s_socket());
if (!pBIO) throw SSLException("Cannot create BIO object");
BIO_set_fd(pBIO, static_cast<int>(_pSocket->sockfd()), BIO_NOCLOSE);
_pSSL = SSL_new(_pContext->sslContext());
_pSSL = ::SSL_new(_pContext->sslContext());
if (!_pSSL)
{
BIO_free(pBIO);
::BIO_free(pBIO);
throw SSLException("Cannot create SSL object");
}
@@ -105,15 +98,15 @@ void SecureSocketImpl::acceptSSL()
* tickets. */
if (1 != SSL_set_num_tickets(_pSSL, 0))
{
BIO_free(pBIO);
::BIO_free(pBIO);
throw SSLException("Cannot create SSL object");
}
//Otherwise we can perform two-way shutdown. Client must call SSL_read() before the final SSL_shutdown().
#endif
SSL_set_bio(_pSSL, pBIO, pBIO);
SSL_set_accept_state(_pSSL);
SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), this);
::SSL_set_bio(_pSSL, pBIO, pBIO);
::SSL_set_accept_state(_pSSL);
::SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), this);
_needHandshake = true;
}
@@ -162,18 +155,18 @@ void SecureSocketImpl::connectSSL(bool performHandshake)
poco_assert (!_pSSL);
poco_assert (_pSocket->initialized());
BIO* pBIO = BIO_new(BIO_s_socket());
::BIO* pBIO = ::BIO_new(BIO_s_socket());
if (!pBIO) throw SSLException("Cannot create SSL BIO object");
BIO_set_fd(pBIO, static_cast<int>(_pSocket->sockfd()), BIO_NOCLOSE);
_pSSL = SSL_new(_pContext->sslContext());
_pSSL = ::SSL_new(_pContext->sslContext());
if (!_pSSL)
{
BIO_free(pBIO);
::BIO_free(pBIO);
throw SSLException("Cannot create SSL object");
}
SSL_set_bio(_pSSL, pBIO, pBIO);
SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), this);
::SSL_set_bio(_pSSL, pBIO, pBIO);
::SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), this);
if (!_peerHostName.empty())
{
@@ -189,27 +182,27 @@ void SecureSocketImpl::connectSSL(bool performHandshake)
if (_pSession && _pSession->isResumable())
{
SSL_set_session(_pSSL, _pSession->sslSession());
::SSL_set_session(_pSSL, _pSession->sslSession());
}
try
{
if (performHandshake && _pSocket->getBlocking())
{
int ret = SSL_connect(_pSSL);
int ret = ::SSL_connect(_pSSL);
handleError(ret);
verifyPeerCertificate();
}
else
{
SSL_set_connect_state(_pSSL);
::SSL_set_connect_state(_pSSL);
_needHandshake = true;
}
}
catch (...)
{
SSL_free(_pSSL);
_pSSL = 0;
::SSL_free(_pSSL);
_pSSL = nullptr;
throw;
}
}
@@ -259,11 +252,11 @@ void SecureSocketImpl::shutdown()
{
if (_pSSL)
{
// Don't shut down the socket more than once.
int shutdownState = SSL_get_shutdown(_pSSL);
bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN;
if (!shutdownSent)
{
// Don't shut down the socket more than once.
int shutdownState = ::SSL_get_shutdown(_pSSL);
bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN;
if (!shutdownSent)
{
// A proper clean shutdown would require us to
// retry the shutdown if we get a zero return
// value, until SSL_shutdown() returns 1.
@@ -274,7 +267,7 @@ void SecureSocketImpl::shutdown()
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
int rc = 0;
if (!_bidirectShutdown)
rc = SSL_shutdown(_pSSL);
rc = ::SSL_shutdown(_pSSL);
else
{
Poco::Timespan recvTimeout = _pSocket->getReceiveTimeout();
@@ -282,11 +275,11 @@ void SecureSocketImpl::shutdown()
Poco::Timestamp tsNow;
do
{
rc = SSL_shutdown(_pSSL);
rc = ::SSL_shutdown(_pSSL);
if (rc == 1) break;
if (rc < 0)
{
int err = SSL_get_error(_pSSL, rc);
int err = ::SSL_get_error(_pSSL, rc);
if (err == SSL_ERROR_WANT_READ)
_pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_READ);
else if (err == SSL_ERROR_WANT_WRITE)
@@ -294,7 +287,7 @@ void SecureSocketImpl::shutdown()
else
{
int socketError = SocketImpl::lastError();
long lastError = ERR_get_error();
long lastError = ::ERR_get_error();
if ((err == SSL_ERROR_SSL) && (socketError == 0) && (lastError == 0x0A000123))
rc = 0;
break;
@@ -304,7 +297,7 @@ void SecureSocketImpl::shutdown()
} while (!tsNow.isElapsed(recvTimeout.totalMicroseconds()));
}
#else
int rc = SSL_shutdown(_pSSL);
int rc = ::SSL_shutdown(_pSSL);
#endif
if (rc < 0) handleError(rc);
if (_pSocket->getBlocking())
@@ -361,11 +354,17 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
else
return rc;
}
do
const auto sendTimeout = _pSocket->getSendTimeout();
Poco::Timestamp tsStart;
while (true)
{
rc = SSL_write(_pSSL, buffer, length);
}
while (mustRetry(rc));
rc = ::SSL_write(_pSSL, buffer, length);
if (!mustRetry(rc))
break;
if (tsStart.isElapsed(sendTimeout.totalMicroseconds()))
throw Poco::TimeoutException();
};
if (rc <= 0)
{
rc = handleError(rc);
@@ -389,11 +388,18 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
else
return rc;
}
do
const auto recvTimeout = _pSocket->getReceiveTimeout();
Poco::Timestamp tsStart;
while (true)
{
rc = SSL_read(_pSSL, buffer, length);
}
while (mustRetry(rc));
rc = ::SSL_read(_pSSL, buffer, length);
if (!mustRetry(rc))
break;
if (tsStart.isElapsed(recvTimeout.totalMicroseconds()))
throw Poco::TimeoutException();
};
_bidirectShutdown = false;
if (rc <= 0)
{
@@ -407,7 +413,7 @@ int SecureSocketImpl::available() const
{
poco_check_ptr (_pSSL);
return SSL_pending(_pSSL);
return ::SSL_pending(_pSSL);
}
@@ -417,11 +423,17 @@ int SecureSocketImpl::completeHandshake()
poco_check_ptr (_pSSL);
int rc;
do
const auto recvTimeout = _pSocket->getReceiveTimeout();
Poco::Timestamp tsStart;
while (true)
{
rc = SSL_do_handshake(_pSSL);
}
while (mustRetry(rc));
rc = ::SSL_do_handshake(_pSSL);
if (!mustRetry(rc))
break;
if (tsStart.isElapsed(recvTimeout.totalMicroseconds()))
throw Poco::TimeoutException();
};
if (rc <= 0)
{
return handleError(rc);
@@ -460,7 +472,7 @@ long SecureSocketImpl::verifyPeerCertificateImpl(const std::string& hostName)
return X509_V_OK;
}
X509* pCert = SSL_get_peer_certificate(_pSSL);
::X509* pCert = ::SSL_get_peer_certificate(_pSSL);
if (pCert)
{
X509Certificate cert(pCert);
@@ -477,7 +489,7 @@ bool SecureSocketImpl::isLocalHost(const std::string& hostName)
SocketAddress addr(hostName, 0);
return addr.host().isLoopback();
}
catch (Poco::Exception&)
catch (const Poco::Exception&)
{
return false;
}
@@ -487,9 +499,9 @@ bool SecureSocketImpl::isLocalHost(const std::string& hostName)
X509* SecureSocketImpl::peerCertificate() const
{
if (_pSSL)
return SSL_get_peer_certificate(_pSSL);
return ::SSL_get_peer_certificate(_pSSL);
else
return 0;
return nullptr;
}
@@ -497,26 +509,24 @@ bool SecureSocketImpl::mustRetry(int rc)
{
if (rc <= 0)
{
int sslError = SSL_get_error(_pSSL, rc);
static const Poco::Timespan pollTimeout(0, 100000);
int sslError = ::SSL_get_error(_pSSL, rc);
int socketError = _pSocket->lastError();
switch (sslError)
{
case SSL_ERROR_WANT_READ:
if (_pSocket->getBlocking())
{
if (_pSocket->poll(_pSocket->getReceiveTimeout(), Poco::Net::Socket::SELECT_READ))
return true;
else
throw Poco::TimeoutException();
_pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_READ);
return true;
}
break;
case SSL_ERROR_WANT_WRITE:
if (_pSocket->getBlocking())
{
if (_pSocket->poll(_pSocket->getSendTimeout(), Poco::Net::Socket::SELECT_WRITE))
return true;
else
throw Poco::TimeoutException();
_pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_WRITE);
return true;
}
break;
case SSL_ERROR_SYSCALL:
@@ -533,7 +543,7 @@ int SecureSocketImpl::handleError(int rc)
{
if (rc > 0) return rc;
int sslError = SSL_get_error(_pSSL, rc);
int sslError = ::SSL_get_error(_pSSL, rc);
int socketError = SocketImpl::lastError();
switch (sslError)
@@ -569,12 +579,12 @@ int SecureSocketImpl::handleError(int rc)
// fallthrough
default:
{
long lastError = ERR_get_error();
long lastError = ::ERR_get_error();
std::string msg;
if (lastError)
{
char buffer[256];
ERR_error_string_n(lastError, buffer, sizeof(buffer));
::ERR_error_string_n(lastError, buffer, sizeof(buffer));
msg = buffer;
}
// SSL_GET_ERROR(3ossl):
@@ -627,9 +637,9 @@ void SecureSocketImpl::reset()
close();
if (_pSSL)
{
SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), nullptr);
SSL_free(_pSSL);
_pSSL = 0;
::SSL_set_ex_data(_pSSL, SSLManager::instance().socketIndex(), nullptr);
::SSL_free(_pSSL);
_pSSL = nullptr;
}
}
@@ -655,7 +665,7 @@ void SecureSocketImpl::useSession(Session::Ptr pSession)
bool SecureSocketImpl::sessionWasReused()
{
if (_pSSL)
return SSL_session_reused(_pSSL) != 0;
return ::SSL_session_reused(_pSSL) != 0;
else
return false;
}
@@ -663,7 +673,7 @@ bool SecureSocketImpl::sessionWasReused()
int SecureSocketImpl::onSessionCreated(SSL* pSSL, SSL_SESSION* pSession)
{
void* pEx = SSL_get_ex_data(pSSL, SSLManager::instance().socketIndex());
void* pEx = ::SSL_get_ex_data(pSSL, SSLManager::instance().socketIndex());
if (pEx)
{
SecureSocketImpl* pThis = reinterpret_cast<SecureSocketImpl*>(pEx);