From 663232b8b8a271c17f6d92c9d0bae112d959bd78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Tue, 26 Nov 2024 19:20:03 +0100 Subject: [PATCH] chore(NetSSL_Win): rewrite handshake logic to support non-blocking sockets --- .../include/Poco/Net/SecureSocketImpl.h | 87 +- NetSSL_Win/include/Poco/Net/Utility.h | 2 +- NetSSL_Win/src/SecureSocketImpl.cpp | 881 ++++++++++-------- NetSSL_Win/src/SecureStreamSocketImpl.cpp | 3 +- NetSSL_Win/src/Utility.cpp | 11 +- 5 files changed, 565 insertions(+), 419 deletions(-) diff --git a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h index 7eeedcef0..268ab2fe6 100644 --- a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h @@ -18,6 +18,10 @@ #define NetSSL_SecureSocketImpl_INCLUDED +// Temporary debugging aid, to be removed +// #define ENABLE_PRINT_STATE + + #include "Poco/Net/SocketImpl.h" #include "Poco/Net/NetSSL.h" #include "Poco/Net/Context.h" @@ -35,6 +39,14 @@ #include + +#ifdef ENABLE_PRINT_STATE +#define PRINT_STATE(m) printState(m) +#else +#define PRINT_STATE(m) +#endif + + namespace Poco { namespace Net { @@ -224,20 +236,25 @@ protected: enum State { ST_INITIAL = 0, - // Client ST_CONNECTING, ST_CLIENT_HSK_START, ST_CLIENT_HSK_SEND_TOKEN, ST_CLIENT_HSK_LOOP_INIT, ST_CLIENT_HSK_LOOP_RECV, ST_CLIENT_HSK_LOOP_PROCESS, - ST_CLIENT_HSK_LOOP_CONTINUE, ST_CLIENT_HSK_LOOP_SEND, - ST_CLIENT_HSK_LOOP_INCOMPLETE, ST_CLIENT_HSK_LOOP_DONE, ST_CLIENT_HSK_SEND_FINAL, ST_CLIENT_HSK_SEND_ERROR, ST_CLIENT_VERIFY, + ST_ACCEPTING, + ST_SERVER_HSK_START, + ST_SERVER_HSK_LOOP_INIT, + ST_SERVER_HSK_LOOP_RECV, + ST_SERVER_HSK_LOOP_PROCESS, + ST_SERVER_HSK_LOOP_SEND, + ST_SERVER_HSK_LOOP_DONE, + ST_SERVER_VERIFY, ST_DONE, ST_ERROR, ST_MAX @@ -251,51 +268,64 @@ protected: int sendRawBytes(const void* buffer, int length, int flags = 0); int receiveRawBytes(void* buffer, int length, int flags = 0); - void clientConnectVerify(); - void performServerHandshake(); - bool serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext); void clientVerifyCertificate(const std::string& hostName); void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert); void serverVerifyCertificate(); int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext); int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext); - void initClientContext(); - void initServerContext(); PCCERT_CONTEXT loadCertificate(bool mustFindCertificate); void initCommon(); void cleanup(); - void performClientHandshakeStart(); - void performInitialClientHandshake(); - void performClientHandshakeSendToken(); - void performClientHandshakeLoopIncompleteMessage(); - void performClientHandshakeLoopInit(); - void performClientHandshakeLoopRecv(); - void performClientHandshakeLoopProcess(); - void performClientHandshakeLoopSend(); - void performClientHandshakeLoopDone(); - void performClientHandshakeSendFinal(); - void performClientHandshakeExtraBuffer(); - void performClientHandshakeLoopContinue(); - void performClientHandshakeError(); - void performClientHandshakeSendError(); - void sendOutSecBufferAndAdvanceState(State state); + void stateIllegal(); + void stateError(); - SECURITY_STATUS performClientHandshake(); + void stateClientConnected(); + void stateClientHandshakeStart(); + void stateClientHandshakeSendToken(); + void stateClientHandshakeLoopInit(); + void stateClientHandshakeLoopRecv(); + void stateClientHandshakeLoopProcess(); + void stateClientHandshakeLoopSend(); + void stateClientHandshakeLoopDone(); + void stateClientHandshakeSendFinal(); + void stateClientHandshakeSendError(); + void stateClientVerify(); + + void stateServerAccepted(); + void stateServerHandshakeStart(); + void stateServerHandshakeLoopInit(); + void stateServerHandshakeLoopRecv(); + void stateServerHandshakeLoopProcess(); + void stateServerHandshakeLoopSend(); + void stateServerHandshakeLoopDone(); + void stateServerHandshakeVerify(); + + void sendOutSecBufferAndAdvanceState(State state); + void drainExtraBuffer(); + static int getRecordLength(const BYTE* pBuffer, int length); + static bool bufferHasCompleteRecords(const BYTE* pBuffer, int length); + + void initClientCredentials(); + void initServerCredentials(); + SECURITY_STATUS doHandshake(); + int completeHandshake(); SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra); SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded); - void stateIllegal(); - void stateConnected(); + void acceptSSL(); void connectSSL(bool completeHandshake); - void completeHandshake(); static int lastError(); bool stateMachine(); State getState() const; void setState(State st); static bool isLocalHost(const std::string& hostName); +#ifdef ENABLE_PRINT_STATE + void printState(const std::string& msg); +#endif + private: SecureSocketImpl(const SecureSocketImpl&); SecureSocketImpl& operator = (const SecureSocketImpl&); @@ -331,9 +361,9 @@ private: SecBuffer _extraSecBuffer; SECURITY_STATUS _securityStatus; State _state; - DWORD _outFlags; bool _needData; bool _needHandshake; + bool _initServerContext = false; friend class SecureStreamSocketImpl; friend class StateMachine; @@ -376,6 +406,7 @@ inline SecureSocketImpl::State SecureSocketImpl::getState() const inline void SecureSocketImpl::setState(SecureSocketImpl::State st) { _state = st; + PRINT_STATE("setState: "); } diff --git a/NetSSL_Win/include/Poco/Net/Utility.h b/NetSSL_Win/include/Poco/Net/Utility.h index 0e5e837a2..7a3501fd7 100644 --- a/NetSSL_Win/include/Poco/Net/Utility.h +++ b/NetSSL_Win/include/Poco/Net/Utility.h @@ -35,7 +35,7 @@ public: /// Non-case sensitive conversion of a string to a VerificationMode enum. /// If verMode is illegal an OptionException is thrown. - static const std::string& formatError(long errCode); + static std::string formatError(long errCode); /// Converts an winerror.h code into human readable form. private: diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index f4510acc8..c6d38d861 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -25,6 +25,12 @@ #include +#ifdef ENABLE_PRINT_STATE +#include "Poco/Mutex.h" +#include +#endif + + namespace Poco { namespace Net { @@ -58,8 +64,6 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr pSocketImpl, Contex _securityFunctions(SSLManager::instance().securityFunctions()), _pOwnCertificate(0), _pPeerCertificate(0), - _hCreds(), - _hContext(), _contextFlags(0), _overflowBuffer(0), _sendBuffer(0), @@ -77,11 +81,8 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr pSocketImpl, Contex _needData(true), _needHandshake(false) { - _hCreds.dwLower = 0; - _hCreds.dwUpper = 0; - - _hContext.dwLower = 0; - _hContext.dwUpper = 0; + SecInvalidateHandle(&_hCreds); + SecInvalidateHandle(&_hContext); _streamSizes.cbBlockSize = 0; _streamSizes.cbHeader = 0; @@ -142,14 +143,12 @@ void SecureSocketImpl::cleanup() { _peerHostName.clear(); - _hCreds.dwLower = 0; - _hCreds.dwUpper = 0; + SecInvalidateHandle(&_hCreds); - if (_hContext.dwLower != 0 && _hContext.dwUpper != 0) + if (SecIsValidHandle(&_hContext)) { _securityFunctions.DeleteSecurityContext(&_hContext); - _hContext.dwLower = 0; - _hContext.dwUpper = 0; + SecInvalidateHandle(&_hContext); } if (_pOwnCertificate) @@ -249,9 +248,9 @@ void SecureSocketImpl::listen(int backlog) { _mode = MODE_SERVER; - if (_hCreds.dwLower == 0 && _hCreds.dwUpper == 0) + if (!SecIsValidHandle(&_hCreds)) { - initServerContext(); + initServerCredentials(); } _pSocket->listen(backlog); } @@ -309,8 +308,7 @@ int SecureSocketImpl::available() const void SecureSocketImpl::acceptSSL() { - setState(ST_DONE); - initServerContext(); + setState(ST_ACCEPTING); _needHandshake = true; } @@ -333,11 +331,14 @@ void SecureSocketImpl::verifyPeerCertificate(const std::string& hostName) return; } - if (_mode == MODE_SERVER) + { serverVerifyCertificate(); + } else + { clientVerifyCertificate(hostName); + } } @@ -364,7 +365,7 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) { if (_needHandshake) { - completeHandshake(); + doHandshake(); _needHandshake = false; } @@ -374,7 +375,7 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) { if (_pSocket->getBlocking()) { - performClientHandshake(); + doHandshake(); } else { @@ -462,7 +463,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) { if (_needHandshake) { - completeHandshake(); + doHandshake(); _needHandshake = false; } @@ -471,7 +472,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) { if (_pSocket->getBlocking()) { - performClientHandshake(); + doHandshake(); } else { @@ -591,7 +592,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) if (!_pSocket->getBlocking()) return bytesDecoded > 0 ? bytesDecoded : -1; - securityStatus = performClientHandshake(); + securityStatus = doHandshake(); if (securityStatus != SEC_E_OK) break; @@ -624,6 +625,7 @@ SECURITY_STATUS SecureSocketImpl::decodeMessage(BYTE* pBuffer, DWORD bufSize, Au SECURITY_STATUS securityStatus = _securityFunctions.DecryptMessage(&_hContext, &msg, 0, 0); // TODO: when decrypting the close_notify alert, returns SEC_E_DECRYPT_FAILURE + // instead of SEC_I_CONTEXT_EXPIRED if (securityStatus == SEC_E_OK || securityStatus == SEC_I_RENEGOTIATE) { @@ -781,7 +783,7 @@ void SecureSocketImpl::connectSSL(bool completeHandshake) setState(ST_CONNECTING); if (completeHandshake) { - performClientHandshake(); + doHandshake(); _needHandshake = false; } else @@ -791,16 +793,315 @@ void SecureSocketImpl::connectSSL(bool completeHandshake) } -void SecureSocketImpl::completeHandshake() +void SecureSocketImpl::stateIllegal() { - if (_mode == MODE_SERVER) - performServerHandshake(); - else - performClientHandshake(); + throw Poco::IllegalStateException("SSL state machine"); } -void SecureSocketImpl::clientConnectVerify() +void SecureSocketImpl::stateError() +{ + poco_assert_dbg (FAILED(_securityStatus)); + + cleanup(); + setState(ST_ERROR); + throw SSLException("Error during handshake", Utility::formatError(_securityStatus)); +} + + +void SecureSocketImpl::stateClientConnected() +{ + _peerHostName = _pSocket->peerAddress().host().toString(); + setState(ST_CLIENT_HSK_START); +} + + +void SecureSocketImpl::stateClientHandshakeStart() +{ + initClientCredentials(); + + // get initial security token + _outSecBuffer.reset(true); + _outSecBuffer.setSecBufferToken(0, 0, 0); + _recvBuffer.setCapacity(IO_BUFFER_SIZE); + _recvBufferOffset = 0; + + TimeStamp ts; + DWORD contextAttributes(0); + std::wstring whostName; + Poco::UnicodeConverter::convert(_peerHostName, whostName); + _securityStatus = _securityFunctions.InitializeSecurityContextW( + &_hCreds, + 0, + const_cast(whostName.c_str()), + _contextFlags, + 0, + 0, + 0, + 0, + &_hContext, + &_outSecBuffer, + &contextAttributes, + &ts); + + if (_securityStatus != SEC_E_OK && _securityStatus != SEC_I_CONTINUE_NEEDED) + { + if (_securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) + { + // the server is asking for client credentials, we didn't send one because we were not configured to do so, abort + throw SSLException("Handshake failed: No client credentials configured"); + } + else + { + throw SSLException("Handshake failed", Utility::formatError(_securityStatus)); + } + } + + _extraSecBuffer.pvBuffer = 0; + _extraSecBuffer.cbBuffer = 0; + _needData = true; + + setState(ST_CLIENT_HSK_SEND_TOKEN); +} + + +void SecureSocketImpl::stateClientHandshakeSendToken() +{ + poco_assert_dbg (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer); + + int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); + if (sent < 0) + { + return; + } + else if (sent < _outSecBuffer[0].cbBuffer) + { + _outSecBuffer[0].cbBuffer -= sent; + std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); + return; + } + + if (_securityStatus == SEC_E_OK) + { + // The security context was successfully initialized. + // There is no need for another InitializeSecurityContext (Schannel) call. + setState(ST_DONE); + } + else + { + setState(ST_CLIENT_HSK_LOOP_INIT); + } +} + + +void SecureSocketImpl::stateClientHandshakeSendError() +{ + poco_assert_dbg (FAILED(_securityStatus)); + + sendOutSecBufferAndAdvanceState(ST_ERROR); +} + + +void SecureSocketImpl::stateClientHandshakeLoopInit() +{ + _inSecBuffer.reset(false); + _outSecBuffer.reset(true); + + if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_CONTINUE_NEEDED) + { + setState(ST_CLIENT_HSK_LOOP_RECV); + } + else + { + setState(ST_CLIENT_HSK_LOOP_PROCESS); + } +} + + +int SecureSocketImpl::getRecordLength(const BYTE* pRec, int len) +{ + // TLS Record Header: + // struct { + // ContentType type; /* 1 byte */ + // ProtocolVersion version; /* 2 bytes */ + // uint16 length; /* must be <= 2^14 */ + // opaque fragment[TLSPlaintext.length]; + // } TLSPlaintext; + + if (len >= 5) + { + Poco::UInt16 recLen = (pRec[3] << 8) + pRec[4]; + if (recLen <= 16384) + { + return recLen + 5; + } + else + { + return -1; + } + } + return 0; +} + + +bool SecureSocketImpl::bufferHasCompleteRecords(const BYTE* pBuffer, int length) +{ + const BYTE* pEnd = pBuffer + length; + int recLen = getRecordLength(pBuffer, length); + while (recLen > 0 && length >= recLen) + { + pBuffer += recLen; + length -= recLen; + recLen = getRecordLength(pBuffer, length); + } + return pBuffer == pEnd; +} + + +void SecureSocketImpl::stateClientHandshakeLoopRecv() +{ + poco_assert (IO_BUFFER_SIZE > _recvBufferOffset); + + int n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); + if (n == 0) throw SSLException("Connection unexpectedly closed during handshake"); + if (n > 0) + { + _recvBufferOffset += n; + // NOTE: InitializeSecurityContext() will fail with SEC_E_INVALID_TOKEN + // if provided with a buffer containing incomplete records, rather than returning + // SEC_E_INCOMPLETE_MESSAGE. Therefore we have to make sure that we + // receive complete records before passing the buffer to InitializeSecurityContext(). + if (bufferHasCompleteRecords(_recvBuffer.begin(), _recvBufferOffset)) + { + setState(ST_CLIENT_HSK_LOOP_PROCESS); + } + + } +} + + +void SecureSocketImpl::stateClientHandshakeLoopProcess() +{ + _inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); + _inSecBuffer.setSecBufferEmpty(1); + _outSecBuffer.setSecBufferToken(0, 0, 0); + + PRINT_STATE("stateClientHandshakeLoopProcess before InitializeSecurityContextW: "); + + DWORD outFlags; + TimeStamp ts; + _securityStatus = _securityFunctions.InitializeSecurityContextW( + &_hCreds, + &_hContext, + 0, + _contextFlags, + 0, + 0, + &_inSecBuffer, + 0, + 0, + &_outSecBuffer, + &outFlags, + &ts); + + PRINT_STATE("stateClientHandshakeLoopProcess after InitializeSecurityContextW: "); + + if (_securityStatus == SEC_E_OK) + { + drainExtraBuffer(); + setState(ST_CLIENT_HSK_LOOP_DONE); + } + else if (_securityStatus == SEC_I_CONTINUE_NEEDED) + { + drainExtraBuffer(); + setState(ST_CLIENT_HSK_LOOP_SEND); + } + else if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE) + { + _needData = true; + setState(ST_CLIENT_HSK_LOOP_INIT); + } + else if (FAILED(_securityStatus)) + { + if (outFlags & ISC_RET_EXTENDED_ERROR) + { + setState(ST_CLIENT_HSK_SEND_ERROR); + } + else + { + setState(ST_ERROR); + } + } + else + { + setState(ST_CLIENT_HSK_LOOP_INIT); + } +} + + +void SecureSocketImpl::stateClientHandshakeLoopSend() +{ + sendOutSecBufferAndAdvanceState(ST_CLIENT_HSK_LOOP_INIT); +} + + +void SecureSocketImpl::stateClientHandshakeLoopDone() +{ + poco_assert_dbg(_securityStatus == SEC_E_OK); + + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + setState(ST_CLIENT_HSK_SEND_FINAL); + } + else + { + setState(ST_CLIENT_VERIFY); + } +} + + +void SecureSocketImpl::stateClientHandshakeSendFinal() +{ + sendOutSecBufferAndAdvanceState(ST_CLIENT_VERIFY); +} + + +void SecureSocketImpl::sendOutSecBufferAndAdvanceState(State state) +{ + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); + if (sent < 0) return; + if (sent < _outSecBuffer[0].cbBuffer) + { + _outSecBuffer[0].cbBuffer -= sent; + std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); + } + else + { + _outSecBuffer.release(0); + setState(state); + } + } + else + { + setState(state); + } +} + + +void SecureSocketImpl::drainExtraBuffer() +{ + if (_inSecBuffer[1].BufferType == SECBUFFER_EXTRA && _inSecBuffer[1].pvBuffer && _inSecBuffer[1].cbBuffer > 0) + { + std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - _inSecBuffer[1].cbBuffer), _inSecBuffer[1].cbBuffer); + _recvBufferOffset = _inSecBuffer[1].cbBuffer; + } + else _recvBufferOffset = 0; +} + + +void SecureSocketImpl::stateClientVerify() { poco_assert_dbg(!_pPeerCertificate); poco_assert_dbg(!_peerHostName.empty()); @@ -832,294 +1133,127 @@ void SecureSocketImpl::clientConnectVerify() } -void SecureSocketImpl::initClientContext() +void SecureSocketImpl::stateServerAccepted() { - _pOwnCertificate = loadCertificate(false); - _hCreds = _pContext->credentials(); + initServerCredentials(); + setState(ST_SERVER_HSK_START); } -SECURITY_STATUS SecureSocketImpl::performClientHandshake() +void SecureSocketImpl::stateServerHandshakeStart() { - while (_state != ST_DONE) - { - stateMachine(); - } - return _securityStatus; + _securityStatus = SEC_E_INCOMPLETE_MESSAGE; + _initServerContext = true; + _recvBuffer.setCapacity(IO_BUFFER_SIZE); + _recvBufferOffset = 0; + setState(ST_SERVER_HSK_LOOP_INIT); } -void SecureSocketImpl::performInitialClientHandshake() +void SecureSocketImpl::stateServerHandshakeLoopInit() { - // get initial security token - _outSecBuffer.reset(true); - _outSecBuffer.setSecBufferToken(0, 0, 0); - - TimeStamp ts; - DWORD contextAttributes(0); - std::wstring whostName; - Poco::UnicodeConverter::convert(_peerHostName, whostName); - _securityStatus = _securityFunctions.InitializeSecurityContextW( - &_hCreds, - 0, - const_cast(whostName.c_str()), - _contextFlags, - 0, - 0, - 0, - 0, - &_hContext, - &_outSecBuffer, - &contextAttributes, - &ts); - - if (_securityStatus != SEC_E_OK) + if (_securityStatus == SEC_I_CONTINUE_NEEDED || _securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) { - if (_securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) + if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_CONTINUE_NEEDED) { - // the server is asking for client credentials, we didn't send one because we were not configured to do so, abort - throw SSLException("Handshake failed: No client credentials configured"); + setState(ST_SERVER_HSK_LOOP_RECV); } - else if (_securityStatus != SEC_I_CONTINUE_NEEDED) + else { - throw SSLException("Handshake failed", Utility::formatError(_securityStatus)); + setState(ST_SERVER_HSK_LOOP_PROCESS); } } - - _extraSecBuffer.pvBuffer = 0; - _extraSecBuffer.cbBuffer = 0; - _needData = true; - - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + else if (_securityStatus == SEC_E_CERT_UNKNOWN || _securityStatus == SEC_E_CERT_EXPIRED) { - setState(ST_CLIENT_HSK_SEND_TOKEN); + setState(ST_SERVER_HSK_LOOP_PROCESS); } else if (_securityStatus == SEC_E_OK) { - setState(ST_DONE); + setState(ST_SERVER_HSK_LOOP_DONE); } } -void SecureSocketImpl::performClientHandshakeSendToken() +void SecureSocketImpl::stateServerHandshakeLoopRecv() { - poco_assert_dbg (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer); - - int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); - if (sent < 0) - { - return; - } - else if (sent < _outSecBuffer[0].cbBuffer) - { - _outSecBuffer[0].cbBuffer -= sent; - std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); - return; - } - - if (_securityStatus == SEC_E_OK) - { - // The security context was successfully initialized. - // There is no need for another InitializeSecurityContext (Schannel) call. - setState(ST_DONE); - } - else - { - setState(ST_CLIENT_HSK_LOOP_INIT); - _securityStatus = SEC_E_INCOMPLETE_MESSAGE; - } -} - - -void SecureSocketImpl::performClientHandshakeSendError() -{ - poco_assert_dbg (FAILED(_securityStatus)); - - sendOutSecBufferAndAdvanceState(ST_ERROR); -} - - -void SecureSocketImpl::performClientHandshakeError() -{ - poco_assert_dbg (FAILED(_securityStatus)); - cleanup(); - setState(ST_ERROR); - throw SSLException("Error during handshake", Utility::formatError(_securityStatus)); -} - - -void SecureSocketImpl::performClientHandshakeExtraBuffer() -{ - if (_inSecBuffer[1].BufferType == SECBUFFER_EXTRA) - { - std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - _inSecBuffer[1].cbBuffer), _inSecBuffer[1].cbBuffer); - _recvBufferOffset = _inSecBuffer[1].cbBuffer; - } - else _recvBufferOffset = 0; -} - - -void SecureSocketImpl::performClientHandshakeLoopDone() -{ - poco_assert_dbg(_securityStatus == SEC_E_OK); - - performClientHandshakeExtraBuffer(); - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) - { - setState(ST_CLIENT_HSK_SEND_FINAL); - } - else - { - setState(ST_CLIENT_VERIFY); - } -} - - -void SecureSocketImpl::performClientHandshakeSendFinal() -{ - sendOutSecBufferAndAdvanceState(ST_CLIENT_VERIFY); -} - - -void SecureSocketImpl::performClientHandshakeLoopInit() -{ - poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); - - _inSecBuffer.reset(false); - _outSecBuffer.reset(true); - - if (_needData) - { - if (_recvBuffer.capacity() != IO_BUFFER_SIZE) - { - _recvBuffer.setCapacity(IO_BUFFER_SIZE); - } - setState(ST_CLIENT_HSK_LOOP_RECV); - } - else - { - setState(ST_CLIENT_HSK_LOOP_PROCESS); - } -} - - -void SecureSocketImpl::performClientHandshakeLoopRecv() -{ - poco_assert_dbg (_needData); - poco_assert (IO_BUFFER_SIZE > _recvBufferOffset); + poco_assert_dbg (_recvBuffer.capacity() >= IO_BUFFER_SIZE && IO_BUFFER_SIZE > _recvBufferOffset); int n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); if (n == 0) throw SSLException("Connection unexpectedly closed during handshake"); if (n > 0) { _recvBufferOffset += n; - setState(ST_CLIENT_HSK_LOOP_PROCESS); + setState(ST_SERVER_HSK_LOOP_PROCESS); } } -void SecureSocketImpl::performClientHandshakeLoopProcess() +void SecureSocketImpl::stateServerHandshakeLoopProcess() { + PRINT_STATE("stateServerHandshakeLoopProcess before AcceptSecurityContext: "); _inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); - // inbuffer 1 should be empty _inSecBuffer.setSecBufferEmpty(1); - - // outBuffer[0] should be empty _outSecBuffer.setSecBufferToken(0, 0, 0); - _outFlags = 0; - TimeStamp ts; - _securityStatus = _securityFunctions.InitializeSecurityContextW( - &_hCreds, - &_hContext, - 0, - _contextFlags, - 0, - 0, - &_inSecBuffer, - 0, - 0, - &_outSecBuffer, - &_outFlags, - &ts); + TimeStamp tsExpiry; + DWORD outFlags; + _securityStatus = _securityFunctions.AcceptSecurityContext( + &_hCreds, + _initServerContext ? NULL : &_hContext, + &_inSecBuffer, + _contextFlags, + 0, + _initServerContext ? &_hContext : NULL, + &_outSecBuffer, + &outFlags, + &tsExpiry); - if (_securityStatus == SEC_E_OK) + _initServerContext = !SecIsValidHandle(&_hContext); + + PRINT_STATE("stateServerHandshakeLoopProcess after AcceptSecurityContext: "); + + setState(ST_SERVER_HSK_LOOP_INIT); + if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE) { - setState(ST_CLIENT_HSK_LOOP_DONE); + return; } - else if (_securityStatus == SEC_I_CONTINUE_NEEDED) + else if (_securityStatus == SEC_E_OK || _securityStatus == SEC_I_CONTINUE_NEEDED || (FAILED(_securityStatus) && (0 != (outFlags & ISC_RET_EXTENDED_ERROR)))) { - setState(ST_CLIENT_HSK_LOOP_CONTINUE); + drainExtraBuffer(); + if (_outSecBuffer[0].cbBuffer != 0 && _outSecBuffer[0].pvBuffer != 0) + { + setState(ST_SERVER_HSK_LOOP_SEND); + return; + } + } + else if (_securityStatus == SEC_E_CERT_UNKNOWN || _securityStatus == SEC_E_CERT_EXPIRED || _securityStatus == CRYPT_E_REVOKED || _securityStatus == SEC_E_UNTRUSTED_ROOT) + { + throw SSLException("Handshake failure, client has rejected certificate", Utility::formatError(_securityStatus)); } else if (FAILED(_securityStatus)) { - if (_outFlags & ISC_RET_EXTENDED_ERROR) - setState(ST_CLIENT_HSK_SEND_ERROR); - else - setState(ST_ERROR); - } - else - { - setState(ST_CLIENT_HSK_LOOP_INCOMPLETE); + throw SSLException("Handshake failure:", Utility::formatError(_securityStatus)); } } -void SecureSocketImpl::performClientHandshakeLoopContinue() +void SecureSocketImpl::stateServerHandshakeLoopSend() { - performClientHandshakeExtraBuffer(); - setState(ST_CLIENT_HSK_LOOP_SEND); + sendOutSecBufferAndAdvanceState(ST_SERVER_HSK_LOOP_INIT); } -void SecureSocketImpl::performClientHandshakeLoopSend() +void SecureSocketImpl::stateServerHandshakeLoopDone() { - sendOutSecBufferAndAdvanceState(ST_CLIENT_HSK_LOOP_INIT); + _securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_STREAM_SIZES, &_streamSizes); + if (_securityStatus != SEC_E_OK) throw SSLException("Cannot query stream sizes", Utility::formatError(_securityStatus)); + _ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer; + setState(ST_SERVER_VERIFY); } -void SecureSocketImpl::sendOutSecBufferAndAdvanceState(State state) +void SecureSocketImpl::stateServerHandshakeVerify() { - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) - { - int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); - if (sent < 0) return; - if (sent < _outSecBuffer[0].cbBuffer) - { - _outSecBuffer[0].cbBuffer -= sent; - std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); - } - else - { - _outSecBuffer.release(0); - setState(state); - } - } - else - { - setState(state); - } -} - - -void SecureSocketImpl::performClientHandshakeLoopIncompleteMessage() -{ - _needData = true; - setState(ST_CLIENT_HSK_LOOP_INIT); -} - - -void SecureSocketImpl::initServerContext() -{ - _pOwnCertificate = loadCertificate(true); - _hCreds = _pContext->credentials(); -} - - -void SecureSocketImpl::performServerHandshake() -{ - serverHandshakeLoop(&_hContext, &_hCreds, _clientAuthRequired, true, true); - SECURITY_STATUS securityStatus; if (_clientAuthRequired) { @@ -1140,100 +1274,49 @@ void SecureSocketImpl::performServerHandshake() serverVerifyCertificate(); } } - - securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext,SECPKG_ATTR_STREAM_SIZES, &_streamSizes); - if (securityStatus != SEC_E_OK) throw SSLException("Cannot query stream sizes", Utility::formatError(securityStatus)); - - _ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer; + setState(ST_DONE); } -bool SecureSocketImpl::serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext) +void SecureSocketImpl::initClientCredentials() { - TimeStamp tsExpiry; - int n = 0; - bool doRead = doInitialRead; - bool initContext = newContext; - DWORD outFlags; - SECURITY_STATUS securityStatus = SEC_E_INCOMPLETE_MESSAGE; + _pOwnCertificate = loadCertificate(false); + _hCreds = _pContext->credentials(); +} - while (securityStatus == SEC_I_CONTINUE_NEEDED || securityStatus == SEC_E_INCOMPLETE_MESSAGE || securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) + +void SecureSocketImpl::initServerCredentials() +{ + _pOwnCertificate = loadCertificate(true); + _hCreds = _pContext->credentials(); +} + + +SECURITY_STATUS SecureSocketImpl::doHandshake() +{ + while (_state != ST_DONE && _state != ST_ERROR) { - if (securityStatus == SEC_E_INCOMPLETE_MESSAGE) - { - if (doRead) - { - n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); - - if (n <= 0) - throw SSLException("Failed to receive data in handshake"); - else - _recvBufferOffset += n; - } - else doRead = true; - } - - AutoSecBufferDesc<2> inBuffer(&_securityFunctions, false); - AutoSecBufferDesc<1> outBuffer(&_securityFunctions, true); - inBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); - inBuffer.setSecBufferEmpty(1); - outBuffer.setSecBufferToken(0, 0, 0); - - securityStatus = _securityFunctions.AcceptSecurityContext( - phCred, - initContext ? NULL : phContext, - &inBuffer, - _contextFlags, - 0, - initContext ? phContext : NULL, - &outBuffer, - &outFlags, - &tsExpiry); - - initContext = false; - - if (securityStatus == SEC_E_OK || securityStatus == SEC_I_CONTINUE_NEEDED || (FAILED(securityStatus) && (0 != (outFlags & ISC_RET_EXTENDED_ERROR)))) - { - if (outBuffer[0].cbBuffer != 0 && outBuffer[0].pvBuffer != 0) - { - n = sendRawBytes(outBuffer[0].pvBuffer, outBuffer[0].cbBuffer); - outBuffer.release(0); - } - } - - if (securityStatus == SEC_E_OK ) - { - if (inBuffer[1].BufferType == SECBUFFER_EXTRA) - { - std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - inBuffer[1].cbBuffer), inBuffer[1].cbBuffer); - _recvBufferOffset = inBuffer[1].cbBuffer; - } - else - { - _recvBufferOffset = 0; - } - return true; - } - else if (FAILED(securityStatus) && securityStatus != SEC_E_INCOMPLETE_MESSAGE) - { - throw SSLException("Handshake failure:", Utility::formatError(securityStatus)); - } - - if (securityStatus != SEC_E_INCOMPLETE_MESSAGE && securityStatus != SEC_I_INCOMPLETE_CREDENTIALS) - { - if (inBuffer[1].BufferType == SECBUFFER_EXTRA) - { - std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - inBuffer[1].cbBuffer), inBuffer[1].cbBuffer); - _recvBufferOffset = inBuffer[1].cbBuffer; - } - else - { - _recvBufferOffset = 0; - } - } + stateMachine(); } + return _securityStatus; +} - return false; + +int SecureSocketImpl::completeHandshake() +{ + if (_pSocket->getBlocking()) + { + doHandshake(); + return 0; + } + else + { + while (_state != ST_DONE && stateMachine()) + { + // no-op + } + return (_state == ST_DONE) ? 0 : -1; + } } @@ -1489,7 +1572,7 @@ void SecureSocketImpl::serverVerifyCertificate() int SecureSocketImpl::clientShutdown(PCredHandle phCreds, CtxtHandle* phContext) { - if (phContext->dwLower == 0 && phContext->dwUpper == 0) + if (!SecIsValidHandle(phContext)) { return 0; } @@ -1542,7 +1625,7 @@ int SecureSocketImpl::clientShutdown(PCredHandle phCreds, CtxtHandle* phContext) int SecureSocketImpl::serverShutdown(PCredHandle phCreds, CtxtHandle* phContext) { - if (phContext->dwLower == 0 && phContext->dwUpper == 0) + if (!SecIsValidHandle(phContext)) { // handshake has never been done poco_assert_dbg (_needHandshake); @@ -1595,32 +1678,61 @@ int SecureSocketImpl::serverShutdown(PCredHandle phCreds, CtxtHandle* phContext) } -void SecureSocketImpl::stateIllegal() -{ - throw Poco::IllegalStateException("SSL state machine"); -} - - -void SecureSocketImpl::stateConnected() -{ - _peerHostName = _pSocket->peerAddress().host().toString(); - setState(ST_CLIENT_HSK_START); -} - - -void SecureSocketImpl::performClientHandshakeStart() -{ - initClientContext(); - performInitialClientHandshake(); -} - - bool SecureSocketImpl::stateMachine() { return StateMachine::instance().execute(this); } +#ifdef ENABLE_PRINT_STATE + + +#define CASE_PRINT_ENUM(enum) case enum: std::cout << #enum; break; + + +void SecureSocketImpl::printState(const std::string& message) +{ + Poco::Thread* pThread = Poco::Thread::current(); + int threadId = 0; + if (pThread) threadId = pThread->id(); + + static Poco::FastMutex mutex; + Poco::FastMutex::ScopedLock lock(mutex); + + std::cout << message << " thread=" << threadId << " state="; + switch (_state) + { + CASE_PRINT_ENUM(ST_INITIAL) + CASE_PRINT_ENUM(ST_CONNECTING) + CASE_PRINT_ENUM(ST_CLIENT_HSK_START) + CASE_PRINT_ENUM(ST_CLIENT_HSK_SEND_TOKEN) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_INIT) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_RECV) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_PROCESS) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_SEND) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_DONE) + CASE_PRINT_ENUM(ST_CLIENT_HSK_SEND_FINAL) + CASE_PRINT_ENUM(ST_CLIENT_HSK_SEND_ERROR) + CASE_PRINT_ENUM(ST_CLIENT_VERIFY) + CASE_PRINT_ENUM(ST_ACCEPTING) + CASE_PRINT_ENUM(ST_SERVER_HSK_START) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_INIT) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_RECV) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_PROCESS) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_SEND) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_DONE) + CASE_PRINT_ENUM(ST_SERVER_VERIFY) + CASE_PRINT_ENUM(ST_DONE) + CASE_PRINT_ENUM(ST_ERROR) + } + + std::cout << " securityStatus=" << Utility::formatError(_securityStatus) << " hCreds="<< SecIsValidHandle(&_hCreds) << " hContext=" << SecIsValidHandle(&_hContext) << " recvBufOff=" << _recvBufferOffset << std::endl; +} + + +#endif + + StateMachine& StateMachine::instance() { static StateMachine sm; @@ -1632,21 +1744,30 @@ StateMachine::StateMachine(): _states(SecureSocketImpl::ST_MAX) { _states[SecureSocketImpl::ST_INITIAL] = &SecureSocketImpl::stateIllegal; - _states[SecureSocketImpl::ST_CONNECTING] = &SecureSocketImpl::stateConnected; - _states[SecureSocketImpl::ST_CLIENT_HSK_START] = &SecureSocketImpl::performClientHandshakeStart; - _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_TOKEN] = &SecureSocketImpl::performClientHandshakeSendToken; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INIT] = &SecureSocketImpl::performClientHandshakeLoopInit; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_RECV] = &SecureSocketImpl::performClientHandshakeLoopRecv; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_PROCESS] = &SecureSocketImpl::performClientHandshakeLoopProcess; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_SEND] = &SecureSocketImpl::performClientHandshakeLoopSend; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INCOMPLETE] = &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_CONTINUE] = &SecureSocketImpl::performClientHandshakeLoopContinue; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_DONE] = &SecureSocketImpl::performClientHandshakeLoopDone; - _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_FINAL] = &SecureSocketImpl::performClientHandshakeSendFinal; - _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_ERROR] = &SecureSocketImpl::performClientHandshakeSendError; - _states[SecureSocketImpl::ST_CLIENT_VERIFY] = &SecureSocketImpl::clientConnectVerify; + + _states[SecureSocketImpl::ST_CONNECTING] = &SecureSocketImpl::stateClientConnected; + _states[SecureSocketImpl::ST_CLIENT_HSK_START] = &SecureSocketImpl::stateClientHandshakeStart; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_TOKEN] = &SecureSocketImpl::stateClientHandshakeSendToken; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INIT] = &SecureSocketImpl::stateClientHandshakeLoopInit; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_RECV] = &SecureSocketImpl::stateClientHandshakeLoopRecv; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_PROCESS] = &SecureSocketImpl::stateClientHandshakeLoopProcess; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_SEND] = &SecureSocketImpl::stateClientHandshakeLoopSend; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_DONE] = &SecureSocketImpl::stateClientHandshakeLoopDone; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_FINAL] = &SecureSocketImpl::stateClientHandshakeSendFinal; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_ERROR] = &SecureSocketImpl::stateClientHandshakeSendError; + _states[SecureSocketImpl::ST_CLIENT_VERIFY] = &SecureSocketImpl::stateClientVerify; + + _states[SecureSocketImpl::ST_ACCEPTING] = &SecureSocketImpl::stateServerAccepted; + _states[SecureSocketImpl::ST_SERVER_HSK_START] = &SecureSocketImpl::stateServerHandshakeStart; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_INIT] = &SecureSocketImpl::stateServerHandshakeLoopInit; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_RECV] = &SecureSocketImpl::stateServerHandshakeLoopRecv; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_PROCESS] = &SecureSocketImpl::stateServerHandshakeLoopProcess; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_SEND] = &SecureSocketImpl::stateServerHandshakeLoopSend; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_DONE] = &SecureSocketImpl::stateServerHandshakeLoopDone; + _states[SecureSocketImpl::ST_SERVER_VERIFY] = &SecureSocketImpl::stateServerHandshakeVerify; + _states[SecureSocketImpl::ST_DONE] = &SecureSocketImpl::stateIllegal; - _states[SecureSocketImpl::ST_ERROR] = &SecureSocketImpl::performClientHandshakeError; + _states[SecureSocketImpl::ST_ERROR] = &SecureSocketImpl::stateError; } diff --git a/NetSSL_Win/src/SecureStreamSocketImpl.cpp b/NetSSL_Win/src/SecureStreamSocketImpl.cpp index add82ca23..876162a61 100644 --- a/NetSSL_Win/src/SecureStreamSocketImpl.cpp +++ b/NetSSL_Win/src/SecureStreamSocketImpl.cpp @@ -209,8 +209,7 @@ void SecureStreamSocketImpl::verifyPeerCertificate(const std::string& hostName) int SecureStreamSocketImpl::completeHandshake() { - _impl.completeHandshake(); - return 0; + return _impl.completeHandshake(); } diff --git a/NetSSL_Win/src/Utility.cpp b/NetSSL_Win/src/Utility.cpp index 5970bf8a2..d8ac4eb9c 100644 --- a/NetSSL_Win/src/Utility.cpp +++ b/NetSSL_Win/src/Utility.cpp @@ -24,9 +24,6 @@ namespace Poco { namespace Net { -Poco::FastMutex Utility::_mutex; - - Context::VerificationMode Utility::convertVerificationMode(const std::string& vMode) { std::string mode = Poco::toLower(vMode); @@ -54,6 +51,7 @@ inline void add(std::map& messageMap, long key, const s std::map Utility::initSSPIErr() { std::map messageMap; + add(messageMap, SEC_E_OK, "OK"); add(messageMap, NTE_BAD_UID, "Bad UID"); add(messageMap, NTE_BAD_HASH, "Bad Hash"); add(messageMap, NTE_BAD_KEY, "Bad Key"); @@ -185,18 +183,15 @@ std::map Utility::initSSPIErr() } -const std::string& Utility::formatError(long errCode) +std::string Utility::formatError(long errCode) { - Poco::FastMutex::ScopedLock lock(_mutex); - - static const std::string def("Internal SSPI error"); static const std::map errs(initSSPIErr()); const std::map::const_iterator it = errs.find(errCode); if (it != errs.end()) return it->second; else - return def; + return "0x" + Poco::NumberFormatter::formatHex(errCode, 8); }