chore(NetSSL_Win): refactored state machine (wip)

This commit is contained in:
Günter Obiltschnig 2024-11-24 15:22:17 +01:00
parent 5402032249
commit f1bb63a1f2
2 changed files with 219 additions and 144 deletions

View File

@ -224,16 +224,23 @@ protected:
enum State
{
ST_INITIAL = 0,
// Client
ST_CONNECTING,
ST_CLIENTHANDSHAKESTART,
ST_CLIENTHANDSHAKECONDREAD,
ST_CLIENTHANDSHAKEINCOMPLETE,
ST_CLIENTHANDSHAKEOK,
ST_CLIENTHANDSHAKEEXTERROR,
ST_CLIENTHANDSHAKECONTINUE,
ST_VERIFY,
ST_CLIENT_HSK_START,
ST_CLIENT_HSK_SEND_TOKEN,
ST_CLIENT_HSK_LOOP_INIT,
ST_CLIENT_HSK_LOOP_RECV,
ST_CLIENT_HSK_LOOP_PROCESS,
ST_CLIENT_HSK_LOOP_CONTINUE,
ST_CLIENT_HSK_LOOP_SEND,
ST_CLIENT_HSK_LOOP_INCOMPLETE,
ST_CLIENT_HSK_LOOP_DONE,
ST_CLIENT_HSK_SEND_FINAL,
ST_CLIENT_HSK_SEND_ERROR,
ST_CLIENT_VERIFY,
ST_DONE,
ST_ERROR
ST_ERROR,
ST_MAX
};
enum TLSShutdown
@ -245,7 +252,6 @@ protected:
int sendRawBytes(const void* buffer, int length, int flags = 0);
int receiveRawBytes(void* buffer, int length, int flags = 0);
void clientConnectVerify();
void sendInitialTokenOutBuffer();
void performServerHandshake();
bool serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext);
void clientVerifyCertificate(const std::string& hostName);
@ -253,25 +259,32 @@ protected:
void serverVerifyCertificate();
int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext);
int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext);
bool loadSecurityLibrary();
void initClientContext();
void initServerContext();
PCCERT_CONTEXT loadCertificate(bool mustFindCertificate);
void initCommon();
void cleanup();
void performClientHandshake();
void performClientHandshakeStart();
void performInitialClientHandshake();
SECURITY_STATUS performClientHandshakeLoop();
void performClientHandshakeSendToken();
void performClientHandshakeLoopIncompleteMessage();
void performClientHandshakeLoopCondReceive();
void performClientHandshakeLoopReceive();
void performClientHandshakeLoopOK();
void performClientHandshakeLoopInit();
void performClientHandshakeLoopRecv();
void performClientHandshakeLoopProcess();
void performClientHandshakeLoopSend();
void performClientHandshakeLoopDone();
void performClientHandshakeSendFinal();
void performClientHandshakeExtraBuffer();
void performClientHandshakeSendOutBuffer();
void performClientHandshakeLoopContinueNeeded();
void performClientHandshakeLoopError();
void performClientHandshakeLoopExtError();
void performClientHandshakeLoopContinue();
void performClientHandshakeError();
void performClientHandshakeSendError();
void sendOutSecBufferAndAdvanceState(State state);
void performClientHandshake();
SECURITY_STATUS performClientHandshakeLoop();
void performClientHandshakeLoopCondReceive();
SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra);
SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded);
void stateIllegal();

View File

@ -47,7 +47,7 @@ public:
bool none(SOCKET sockfd);
void select(fd_set* fdRead, fd_set* fdWrite, SOCKET sockfd);
void execute(SecureSocketImpl* pSock);
bool execute(SecureSocketImpl* pSock);
private:
StateMachine(const StateMachine&);
@ -192,19 +192,19 @@ SocketImpl* SecureSocketImpl::acceptConnection(SocketAddress& clientAddr)
void SecureSocketImpl::connect(const SocketAddress& address, bool performHandshake)
{
_state = ST_ERROR;
setState(ST_ERROR);
_pSocket->connect(address);
connectSSL(performHandshake);
_state = ST_DONE;
setState(ST_DONE);
}
void SecureSocketImpl::connect(const SocketAddress& address, const Poco::Timespan& timeout, bool performHandshake)
{
_state = ST_ERROR;
setState(ST_ERROR);
_pSocket->connect(address, timeout);
connectSSL(performHandshake);
_state = ST_DONE;
setState(ST_DONE);
}
@ -212,12 +212,12 @@ void SecureSocketImpl::connectNB(const SocketAddress& address)
{
try
{
_state = ST_CONNECTING;
setState(ST_CONNECTING);
_pSocket->connectNB(address);
}
catch (...)
{
_state = ST_ERROR;
setState(ST_ERROR);
}
}
@ -318,7 +318,7 @@ int SecureSocketImpl::available() const
void SecureSocketImpl::acceptSSL()
{
_state = ST_DONE;
setState(ST_DONE);
initServerContext();
_needHandshake = true;
}
@ -379,17 +379,6 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
if (_state == ST_ERROR) return 0;
if (_sendBufferPending > 0)
{
int sent = sendRawBytes(_sendBuffer.begin() + _sendBufferOffset, _sendBufferPending, flags);
if (sent > 0)
{
_sendBufferOffset += sent;
_sendBufferPending -= sent;
}
return _sendBufferPending == 0 ? length : -1;
}
if (_state != ST_DONE)
{
bool establish = _pSocket->getBlocking();
@ -403,10 +392,21 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags)
else
{
stateMachine();
return -1;
if (_state != ST_DONE) return -1;
}
}
if (_sendBufferPending > 0)
{
int sent = sendRawBytes(_sendBuffer.begin() + _sendBufferOffset, _sendBufferPending, flags);
if (sent > 0)
{
_sendBufferOffset += sent;
_sendBufferPending -= sent;
}
return _sendBufferPending == 0 ? length : -1;
}
int dataToSend = length;
int dataSent = 0;
const char* pBuffer = reinterpret_cast<const char*>(buffer);
@ -490,7 +490,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
else
{
stateMachine();
return -1;
if (_state != ST_DONE) return -1;
}
}
@ -598,7 +598,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
if (securityStatus == SEC_I_RENEGOTIATE)
{
_needData = false;
_state = ST_CLIENTHANDSHAKECONDREAD;
setState(ST_CLIENT_HSK_LOOP_INIT);
if (!_pSocket->getBlocking())
return -1;
@ -831,7 +831,7 @@ void SecureSocketImpl::clientConnectVerify()
throw SSLException("Failed to query stream sizes", Utility::formatError(securityStatus));
_ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer;
_state = ST_DONE;
setState(ST_DONE);
}
catch (...)
{
@ -855,6 +855,7 @@ void SecureSocketImpl::initClientContext()
void SecureSocketImpl::performClientHandshake()
{
performInitialClientHandshake();
performClientHandshakeSendToken();
performClientHandshakeLoop();
clientConnectVerify();
}
@ -897,38 +898,47 @@ void SecureSocketImpl::performInitialClientHandshake()
}
}
// incomplete credentials: more calls to InitializeSecurityContext needed
// send the token
sendInitialTokenOutBuffer();
_extraSecBuffer.pvBuffer = 0;
_extraSecBuffer.cbBuffer = 0;
_needData = true;
if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer)
{
setState(ST_CLIENT_HSK_SEND_TOKEN);
}
else if (_securityStatus == SEC_E_OK)
{
setState(ST_DONE);
}
}
void SecureSocketImpl::performClientHandshakeSendToken()
{
poco_assert_dbg (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer);
int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer);
if (sent < 0)
{
return;
}
else if (sent < _outSecBuffer[0].cbBuffer)
{
_outSecBuffer[0].cbBuffer -= sent;
std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast<char*>(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer);
return;
}
if (_securityStatus == SEC_E_OK)
{
// The security context was successfully initialized.
// There is no need for another InitializeSecurityContext (Schannel) call.
_state = ST_DONE;
return;
setState(ST_DONE);
}
//SEC_I_CONTINUE_NEEDED was returned:
// Wait for a return token. The returned token is then passed in
// another call to InitializeSecurityContext (Schannel). The output token can be empty.
_extraSecBuffer.pvBuffer = 0;
_extraSecBuffer.cbBuffer = 0;
_needData = true;
_state = ST_CLIENTHANDSHAKECONDREAD;
_securityStatus = SEC_E_INCOMPLETE_MESSAGE;
}
void SecureSocketImpl::sendInitialTokenOutBuffer()
{
// send the token
if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer)
else
{
int numBytes = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer);
if (numBytes != _outSecBuffer[0].cbBuffer)
throw SSLException("Failed to send token to the server");
setState(ST_CLIENT_HSK_LOOP_INIT);
_securityStatus = SEC_E_INCOMPLETE_MESSAGE;
}
}
@ -944,11 +954,13 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop()
if (_securityStatus == SEC_E_OK)
{
performClientHandshakeLoopOK();
performClientHandshakeLoopDone();
performClientHandshakeSendFinal();
}
else if (_securityStatus == SEC_I_CONTINUE_NEEDED)
{
performClientHandshakeLoopContinueNeeded();
performClientHandshakeLoopContinue();
performClientHandshakeLoopSend();
}
else if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE)
{
@ -958,11 +970,11 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop()
{
if (_outFlags & ISC_RET_EXTENDED_ERROR)
{
performClientHandshakeLoopExtError();
performClientHandshakeSendError();
}
else
{
performClientHandshakeLoopError();
performClientHandshakeError();
}
}
else
@ -973,44 +985,30 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop()
if (FAILED(_securityStatus))
{
performClientHandshakeLoopError();
performClientHandshakeError();
}
return _securityStatus;
}
void SecureSocketImpl::performClientHandshakeLoopExtError()
void SecureSocketImpl::performClientHandshakeSendError()
{
poco_assert_dbg (FAILED(_securityStatus));
performClientHandshakeSendOutBuffer();
performClientHandshakeLoopError();
sendOutSecBufferAndAdvanceState(ST_ERROR);
}
void SecureSocketImpl::performClientHandshakeLoopError()
void SecureSocketImpl::performClientHandshakeError()
{
poco_assert_dbg (FAILED(_securityStatus));
cleanup();
_state = ST_ERROR;
setState(ST_ERROR);
throw SSLException("Error during handshake", Utility::formatError(_securityStatus));
}
void SecureSocketImpl::performClientHandshakeSendOutBuffer()
{
if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer)
{
int numBytes = sendRawBytes(static_cast<const void*>(_outSecBuffer[0].pvBuffer), _outSecBuffer[0].cbBuffer);
if (numBytes != _outSecBuffer[0].cbBuffer)
throw SSLException("Socket error during handshake");
_outSecBuffer.release(0);
}
}
void SecureSocketImpl::performClientHandshakeExtraBuffer()
{
if (_inSecBuffer[1].BufferType == SECBUFFER_EXTRA)
@ -1022,48 +1020,67 @@ void SecureSocketImpl::performClientHandshakeExtraBuffer()
}
void SecureSocketImpl::performClientHandshakeLoopOK()
void SecureSocketImpl::performClientHandshakeLoopDone()
{
poco_assert_dbg(_securityStatus == SEC_E_OK);
performClientHandshakeSendOutBuffer();
performClientHandshakeExtraBuffer();
_state = ST_VERIFY;
if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer)
{
setState(ST_CLIENT_HSK_SEND_FINAL);
}
else
{
setState(ST_CLIENT_VERIFY);
}
}
void SecureSocketImpl::performClientHandshakeSendFinal()
{
sendOutSecBufferAndAdvanceState(ST_CLIENT_VERIFY);
}
void SecureSocketImpl::performClientHandshakeLoopInit()
{
poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED);
_inSecBuffer.reset(false);
_outSecBuffer.reset(true);
if (_needData)
{
if (_recvBuffer.capacity() != IO_BUFFER_SIZE)
{
_recvBuffer.setCapacity(IO_BUFFER_SIZE);
}
setState(ST_CLIENT_HSK_LOOP_RECV);
}
else
{
setState(ST_CLIENT_HSK_LOOP_PROCESS);
}
}
void SecureSocketImpl::performClientHandshakeLoopReceive()
void SecureSocketImpl::performClientHandshakeLoopRecv()
{
poco_assert_dbg (_needData);
poco_assert (IO_BUFFER_SIZE > _recvBufferOffset);
int n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset);
if (n <= 0) throw SSLException("Error during handshake: failed to read data");
if (n == 0) throw SSLException("Connection unexpectedly closed during handshake");
if (n > 0)
{
_recvBufferOffset += n;
setState(ST_CLIENT_HSK_LOOP_PROCESS);
}
}
void SecureSocketImpl::performClientHandshakeLoopCondReceive()
void SecureSocketImpl::performClientHandshakeLoopProcess()
{
poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED);
performClientHandshakeLoopInit();
if (_needData)
{
if (_recvBuffer.capacity() != IO_BUFFER_SIZE)
_recvBuffer.setCapacity(IO_BUFFER_SIZE);
performClientHandshakeLoopReceive();
}
else _needData = true;
_inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset);
// inbuffer 1 should be empty
_inSecBuffer.setSecBufferEmpty(1);
@ -1089,38 +1106,82 @@ void SecureSocketImpl::performClientHandshakeLoopCondReceive()
if (_securityStatus == SEC_E_OK)
{
_state = ST_CLIENTHANDSHAKEOK;
setState(ST_CLIENT_HSK_LOOP_DONE);
}
else if (_securityStatus == SEC_I_CONTINUE_NEEDED)
{
_state = ST_CLIENTHANDSHAKECONTINUE;
setState(ST_CLIENT_HSK_LOOP_CONTINUE);
}
else if (FAILED(_securityStatus))
{
if (_outFlags & ISC_RET_EXTENDED_ERROR)
_state = ST_CLIENTHANDSHAKEEXTERROR;
setState(ST_CLIENT_HSK_SEND_ERROR);
else
_state = ST_ERROR;
setState(ST_ERROR);
}
else
{
_state = ST_CLIENTHANDSHAKEINCOMPLETE;
setState(ST_CLIENT_HSK_LOOP_INCOMPLETE);
}
}
void SecureSocketImpl::performClientHandshakeLoopContinueNeeded()
void SecureSocketImpl::performClientHandshakeLoopCondReceive()
{
poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED);
performClientHandshakeLoopInit();
if (_state == ST_CLIENT_HSK_LOOP_RECV)
{
performClientHandshakeLoopRecv();
}
performClientHandshakeLoopProcess();
}
void SecureSocketImpl::performClientHandshakeLoopContinue()
{
performClientHandshakeSendOutBuffer();
performClientHandshakeExtraBuffer();
_state = ST_CLIENTHANDSHAKECONDREAD;
setState(ST_CLIENT_HSK_LOOP_SEND);
}
void SecureSocketImpl::performClientHandshakeLoopSend()
{
sendOutSecBufferAndAdvanceState(ST_CLIENT_HSK_LOOP_INIT);
}
void SecureSocketImpl::sendOutSecBufferAndAdvanceState(State state)
{
if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer)
{
int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer);
if (sent < 0) return;
if (sent < _outSecBuffer[0].cbBuffer)
{
_outSecBuffer[0].cbBuffer -= sent;
std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast<char*>(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer);
}
else
{
_outSecBuffer.release(0);
setState(state);
}
}
else
{
setState(state);
}
}
void SecureSocketImpl::performClientHandshakeLoopIncompleteMessage()
{
_needData = true;
_state = ST_CLIENTHANDSHAKECONDREAD;
setState(ST_CLIENT_HSK_LOOP_INIT);
}
@ -1619,6 +1680,12 @@ void SecureSocketImpl::stateIllegal()
void SecureSocketImpl::stateConnected()
{
_peerHostName = _pSocket->peerAddress().host().toString();
setState(ST_CLIENT_HSK_START);
}
void SecureSocketImpl::performClientHandshakeStart()
{
initClientContext();
performInitialClientHandshake();
}
@ -1700,34 +1767,28 @@ void StateMachine::select(fd_set* fdRead, fd_set* fdWrite, SOCKET sockfd)
StateMachine::StateMachine():
_states()
_states(SecureSocketImpl::ST_MAX)
{
//ST_INITIAL: 0, -> this one is illegal, you must call connectNB before
_states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal));
//ST_CONNECTING: connectNB was called, check if the socket is already available for writing
_states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::stateConnected));
//ST_ESTABLISHTUNNELRECEIVED: we got the response, now start the handshake
_states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performInitialClientHandshake));
//ST_CLIENTHANDSHAKECONDREAD: condread
_states.push_back(std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopCondReceive));
//ST_CLIENTHANDSHAKEINCOMPLETE,
_states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage));
//ST_CLIENTHANDSHAKEOK,
_states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopOK));
//ST_CLIENTHANDSHAKEEXTERROR,
_states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopExtError));
//ST_CLIENTHANDSHAKECONTINUE,
_states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopContinueNeeded));
//ST_VERIFY,
_states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::clientConnectVerify));
//ST_DONE,
_states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal));
//ST_ERROR
_states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopError));
_states[SecureSocketImpl::ST_INITIAL] = std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal);
_states[SecureSocketImpl::ST_CONNECTING] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::stateConnected);
_states[SecureSocketImpl::ST_CLIENT_HSK_START] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeStart);
_states[SecureSocketImpl::ST_CLIENT_HSK_SEND_TOKEN] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendToken);
_states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INIT] = std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopInit);
_states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_RECV] = std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopRecv);
_states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_PROCESS] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopProcess);
_states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_SEND] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopSend);
_states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INCOMPLETE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage);
_states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_CONTINUE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopContinue);
_states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_DONE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopDone);
_states[SecureSocketImpl::ST_CLIENT_HSK_SEND_FINAL] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendFinal);
_states[SecureSocketImpl::ST_CLIENT_HSK_SEND_ERROR] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendError);
_states[SecureSocketImpl::ST_CLIENT_VERIFY] = std::make_pair(&StateMachine::none, &SecureSocketImpl::clientConnectVerify);
_states[SecureSocketImpl::ST_DONE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal);
_states[SecureSocketImpl::ST_ERROR] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeError);
}
void StateMachine::execute(SecureSocketImpl* pSock)
bool StateMachine::execute(SecureSocketImpl* pSock)
{
try
{
@ -1737,8 +1798,9 @@ void StateMachine::execute(SecureSocketImpl* pSock)
if ((this->*state.first)(pSock->sockfd()))
{
(pSock->*(state.second))();
(pSock->getState() == SecureSocketImpl::ST_DONE);
return true;
}
else return false;
}
catch (...)
{