diff --git a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h index b1502cf01..6a52239dc 100644 --- a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h @@ -224,16 +224,23 @@ protected: enum State { ST_INITIAL = 0, + // Client ST_CONNECTING, - ST_CLIENTHANDSHAKESTART, - ST_CLIENTHANDSHAKECONDREAD, - ST_CLIENTHANDSHAKEINCOMPLETE, - ST_CLIENTHANDSHAKEOK, - ST_CLIENTHANDSHAKEEXTERROR, - ST_CLIENTHANDSHAKECONTINUE, - ST_VERIFY, + ST_CLIENT_HSK_START, + ST_CLIENT_HSK_SEND_TOKEN, + ST_CLIENT_HSK_LOOP_INIT, + ST_CLIENT_HSK_LOOP_RECV, + ST_CLIENT_HSK_LOOP_PROCESS, + ST_CLIENT_HSK_LOOP_CONTINUE, + ST_CLIENT_HSK_LOOP_SEND, + ST_CLIENT_HSK_LOOP_INCOMPLETE, + ST_CLIENT_HSK_LOOP_DONE, + ST_CLIENT_HSK_SEND_FINAL, + ST_CLIENT_HSK_SEND_ERROR, + ST_CLIENT_VERIFY, ST_DONE, - ST_ERROR + ST_ERROR, + ST_MAX }; enum TLSShutdown @@ -245,7 +252,6 @@ protected: int sendRawBytes(const void* buffer, int length, int flags = 0); int receiveRawBytes(void* buffer, int length, int flags = 0); void clientConnectVerify(); - void sendInitialTokenOutBuffer(); void performServerHandshake(); bool serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext); void clientVerifyCertificate(const std::string& hostName); @@ -253,25 +259,32 @@ protected: void serverVerifyCertificate(); int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext); int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext); - bool loadSecurityLibrary(); void initClientContext(); void initServerContext(); PCCERT_CONTEXT loadCertificate(bool mustFindCertificate); void initCommon(); void cleanup(); - void performClientHandshake(); + + void performClientHandshakeStart(); void performInitialClientHandshake(); - SECURITY_STATUS performClientHandshakeLoop(); + void performClientHandshakeSendToken(); void performClientHandshakeLoopIncompleteMessage(); - void performClientHandshakeLoopCondReceive(); - void performClientHandshakeLoopReceive(); - void performClientHandshakeLoopOK(); void performClientHandshakeLoopInit(); + void performClientHandshakeLoopRecv(); + void performClientHandshakeLoopProcess(); + void performClientHandshakeLoopSend(); + void performClientHandshakeLoopDone(); + void performClientHandshakeSendFinal(); void performClientHandshakeExtraBuffer(); - void performClientHandshakeSendOutBuffer(); - void performClientHandshakeLoopContinueNeeded(); - void performClientHandshakeLoopError(); - void performClientHandshakeLoopExtError(); + void performClientHandshakeLoopContinue(); + void performClientHandshakeError(); + void performClientHandshakeSendError(); + void sendOutSecBufferAndAdvanceState(State state); + + void performClientHandshake(); + SECURITY_STATUS performClientHandshakeLoop(); + void performClientHandshakeLoopCondReceive(); + SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra); SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded); void stateIllegal(); diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index 4e9a3b98a..4a10a7154 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -47,7 +47,7 @@ public: bool none(SOCKET sockfd); void select(fd_set* fdRead, fd_set* fdWrite, SOCKET sockfd); - void execute(SecureSocketImpl* pSock); + bool execute(SecureSocketImpl* pSock); private: StateMachine(const StateMachine&); @@ -192,19 +192,19 @@ SocketImpl* SecureSocketImpl::acceptConnection(SocketAddress& clientAddr) void SecureSocketImpl::connect(const SocketAddress& address, bool performHandshake) { - _state = ST_ERROR; + setState(ST_ERROR); _pSocket->connect(address); connectSSL(performHandshake); - _state = ST_DONE; + setState(ST_DONE); } void SecureSocketImpl::connect(const SocketAddress& address, const Poco::Timespan& timeout, bool performHandshake) { - _state = ST_ERROR; + setState(ST_ERROR); _pSocket->connect(address, timeout); connectSSL(performHandshake); - _state = ST_DONE; + setState(ST_DONE); } @@ -212,12 +212,12 @@ void SecureSocketImpl::connectNB(const SocketAddress& address) { try { - _state = ST_CONNECTING; + setState(ST_CONNECTING); _pSocket->connectNB(address); } catch (...) { - _state = ST_ERROR; + setState(ST_ERROR); } } @@ -318,7 +318,7 @@ int SecureSocketImpl::available() const void SecureSocketImpl::acceptSSL() { - _state = ST_DONE; + setState(ST_DONE); initServerContext(); _needHandshake = true; } @@ -379,17 +379,6 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) if (_state == ST_ERROR) return 0; - if (_sendBufferPending > 0) - { - int sent = sendRawBytes(_sendBuffer.begin() + _sendBufferOffset, _sendBufferPending, flags); - if (sent > 0) - { - _sendBufferOffset += sent; - _sendBufferPending -= sent; - } - return _sendBufferPending == 0 ? length : -1; - } - if (_state != ST_DONE) { bool establish = _pSocket->getBlocking(); @@ -403,10 +392,21 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) else { stateMachine(); - return -1; + if (_state != ST_DONE) return -1; } } + if (_sendBufferPending > 0) + { + int sent = sendRawBytes(_sendBuffer.begin() + _sendBufferOffset, _sendBufferPending, flags); + if (sent > 0) + { + _sendBufferOffset += sent; + _sendBufferPending -= sent; + } + return _sendBufferPending == 0 ? length : -1; + } + int dataToSend = length; int dataSent = 0; const char* pBuffer = reinterpret_cast(buffer); @@ -490,7 +490,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) else { stateMachine(); - return -1; + if (_state != ST_DONE) return -1; } } @@ -598,7 +598,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) if (securityStatus == SEC_I_RENEGOTIATE) { _needData = false; - _state = ST_CLIENTHANDSHAKECONDREAD; + setState(ST_CLIENT_HSK_LOOP_INIT); if (!_pSocket->getBlocking()) return -1; @@ -831,7 +831,7 @@ void SecureSocketImpl::clientConnectVerify() throw SSLException("Failed to query stream sizes", Utility::formatError(securityStatus)); _ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer; - _state = ST_DONE; + setState(ST_DONE); } catch (...) { @@ -855,6 +855,7 @@ void SecureSocketImpl::initClientContext() void SecureSocketImpl::performClientHandshake() { performInitialClientHandshake(); + performClientHandshakeSendToken(); performClientHandshakeLoop(); clientConnectVerify(); } @@ -897,38 +898,47 @@ void SecureSocketImpl::performInitialClientHandshake() } } - // incomplete credentials: more calls to InitializeSecurityContext needed - // send the token - sendInitialTokenOutBuffer(); + _extraSecBuffer.pvBuffer = 0; + _extraSecBuffer.cbBuffer = 0; + _needData = true; + + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + setState(ST_CLIENT_HSK_SEND_TOKEN); + } + else if (_securityStatus == SEC_E_OK) + { + setState(ST_DONE); + } +} + + +void SecureSocketImpl::performClientHandshakeSendToken() +{ + poco_assert_dbg (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer); + + int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); + if (sent < 0) + { + return; + } + else if (sent < _outSecBuffer[0].cbBuffer) + { + _outSecBuffer[0].cbBuffer -= sent; + std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); + return; + } if (_securityStatus == SEC_E_OK) { // The security context was successfully initialized. // There is no need for another InitializeSecurityContext (Schannel) call. - _state = ST_DONE; - return; + setState(ST_DONE); } - - //SEC_I_CONTINUE_NEEDED was returned: - // Wait for a return token. The returned token is then passed in - // another call to InitializeSecurityContext (Schannel). The output token can be empty. - - _extraSecBuffer.pvBuffer = 0; - _extraSecBuffer.cbBuffer = 0; - _needData = true; - _state = ST_CLIENTHANDSHAKECONDREAD; - _securityStatus = SEC_E_INCOMPLETE_MESSAGE; -} - - -void SecureSocketImpl::sendInitialTokenOutBuffer() -{ - // send the token - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + else { - int numBytes = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); - if (numBytes != _outSecBuffer[0].cbBuffer) - throw SSLException("Failed to send token to the server"); + setState(ST_CLIENT_HSK_LOOP_INIT); + _securityStatus = SEC_E_INCOMPLETE_MESSAGE; } } @@ -944,11 +954,13 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop() if (_securityStatus == SEC_E_OK) { - performClientHandshakeLoopOK(); + performClientHandshakeLoopDone(); + performClientHandshakeSendFinal(); } else if (_securityStatus == SEC_I_CONTINUE_NEEDED) { - performClientHandshakeLoopContinueNeeded(); + performClientHandshakeLoopContinue(); + performClientHandshakeLoopSend(); } else if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE) { @@ -958,11 +970,11 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop() { if (_outFlags & ISC_RET_EXTENDED_ERROR) { - performClientHandshakeLoopExtError(); + performClientHandshakeSendError(); } else { - performClientHandshakeLoopError(); + performClientHandshakeError(); } } else @@ -973,44 +985,30 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop() if (FAILED(_securityStatus)) { - performClientHandshakeLoopError(); + performClientHandshakeError(); } return _securityStatus; } -void SecureSocketImpl::performClientHandshakeLoopExtError() +void SecureSocketImpl::performClientHandshakeSendError() { poco_assert_dbg (FAILED(_securityStatus)); - performClientHandshakeSendOutBuffer(); - performClientHandshakeLoopError(); + sendOutSecBufferAndAdvanceState(ST_ERROR); } -void SecureSocketImpl::performClientHandshakeLoopError() +void SecureSocketImpl::performClientHandshakeError() { poco_assert_dbg (FAILED(_securityStatus)); cleanup(); - _state = ST_ERROR; + setState(ST_ERROR); throw SSLException("Error during handshake", Utility::formatError(_securityStatus)); } -void SecureSocketImpl::performClientHandshakeSendOutBuffer() -{ - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) - { - int numBytes = sendRawBytes(static_cast(_outSecBuffer[0].pvBuffer), _outSecBuffer[0].cbBuffer); - if (numBytes != _outSecBuffer[0].cbBuffer) - throw SSLException("Socket error during handshake"); - - _outSecBuffer.release(0); - } -} - - void SecureSocketImpl::performClientHandshakeExtraBuffer() { if (_inSecBuffer[1].BufferType == SECBUFFER_EXTRA) @@ -1022,48 +1020,67 @@ void SecureSocketImpl::performClientHandshakeExtraBuffer() } -void SecureSocketImpl::performClientHandshakeLoopOK() +void SecureSocketImpl::performClientHandshakeLoopDone() { poco_assert_dbg(_securityStatus == SEC_E_OK); - performClientHandshakeSendOutBuffer(); performClientHandshakeExtraBuffer(); - _state = ST_VERIFY; + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + setState(ST_CLIENT_HSK_SEND_FINAL); + } + else + { + setState(ST_CLIENT_VERIFY); + } +} + + +void SecureSocketImpl::performClientHandshakeSendFinal() +{ + sendOutSecBufferAndAdvanceState(ST_CLIENT_VERIFY); } void SecureSocketImpl::performClientHandshakeLoopInit() { + poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); + _inSecBuffer.reset(false); _outSecBuffer.reset(true); + + if (_needData) + { + if (_recvBuffer.capacity() != IO_BUFFER_SIZE) + { + _recvBuffer.setCapacity(IO_BUFFER_SIZE); + } + setState(ST_CLIENT_HSK_LOOP_RECV); + } + else + { + setState(ST_CLIENT_HSK_LOOP_PROCESS); + } } -void SecureSocketImpl::performClientHandshakeLoopReceive() +void SecureSocketImpl::performClientHandshakeLoopRecv() { poco_assert_dbg (_needData); poco_assert (IO_BUFFER_SIZE > _recvBufferOffset); int n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); - if (n <= 0) throw SSLException("Error during handshake: failed to read data"); - - _recvBufferOffset += n; + if (n == 0) throw SSLException("Connection unexpectedly closed during handshake"); + if (n > 0) + { + _recvBufferOffset += n; + setState(ST_CLIENT_HSK_LOOP_PROCESS); + } } -void SecureSocketImpl::performClientHandshakeLoopCondReceive() +void SecureSocketImpl::performClientHandshakeLoopProcess() { - poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); - - performClientHandshakeLoopInit(); - if (_needData) - { - if (_recvBuffer.capacity() != IO_BUFFER_SIZE) - _recvBuffer.setCapacity(IO_BUFFER_SIZE); - performClientHandshakeLoopReceive(); - } - else _needData = true; - _inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); // inbuffer 1 should be empty _inSecBuffer.setSecBufferEmpty(1); @@ -1089,38 +1106,82 @@ void SecureSocketImpl::performClientHandshakeLoopCondReceive() if (_securityStatus == SEC_E_OK) { - _state = ST_CLIENTHANDSHAKEOK; + setState(ST_CLIENT_HSK_LOOP_DONE); } else if (_securityStatus == SEC_I_CONTINUE_NEEDED) { - _state = ST_CLIENTHANDSHAKECONTINUE; + setState(ST_CLIENT_HSK_LOOP_CONTINUE); } else if (FAILED(_securityStatus)) { if (_outFlags & ISC_RET_EXTENDED_ERROR) - _state = ST_CLIENTHANDSHAKEEXTERROR; + setState(ST_CLIENT_HSK_SEND_ERROR); else - _state = ST_ERROR; + setState(ST_ERROR); } else { - _state = ST_CLIENTHANDSHAKEINCOMPLETE; + setState(ST_CLIENT_HSK_LOOP_INCOMPLETE); } } -void SecureSocketImpl::performClientHandshakeLoopContinueNeeded() +void SecureSocketImpl::performClientHandshakeLoopCondReceive() +{ + poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); + + performClientHandshakeLoopInit(); + + if (_state == ST_CLIENT_HSK_LOOP_RECV) + { + performClientHandshakeLoopRecv(); + } + + performClientHandshakeLoopProcess(); +} + + +void SecureSocketImpl::performClientHandshakeLoopContinue() { - performClientHandshakeSendOutBuffer(); performClientHandshakeExtraBuffer(); - _state = ST_CLIENTHANDSHAKECONDREAD; + setState(ST_CLIENT_HSK_LOOP_SEND); +} + + +void SecureSocketImpl::performClientHandshakeLoopSend() +{ + sendOutSecBufferAndAdvanceState(ST_CLIENT_HSK_LOOP_INIT); +} + + +void SecureSocketImpl::sendOutSecBufferAndAdvanceState(State state) +{ + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); + if (sent < 0) return; + if (sent < _outSecBuffer[0].cbBuffer) + { + _outSecBuffer[0].cbBuffer -= sent; + std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); + } + else + { + _outSecBuffer.release(0); + setState(state); + } + } + else + { + setState(state); + } } void SecureSocketImpl::performClientHandshakeLoopIncompleteMessage() { _needData = true; - _state = ST_CLIENTHANDSHAKECONDREAD; + setState(ST_CLIENT_HSK_LOOP_INIT); } @@ -1619,6 +1680,12 @@ void SecureSocketImpl::stateIllegal() void SecureSocketImpl::stateConnected() { _peerHostName = _pSocket->peerAddress().host().toString(); + setState(ST_CLIENT_HSK_START); +} + + +void SecureSocketImpl::performClientHandshakeStart() +{ initClientContext(); performInitialClientHandshake(); } @@ -1700,34 +1767,28 @@ void StateMachine::select(fd_set* fdRead, fd_set* fdWrite, SOCKET sockfd) StateMachine::StateMachine(): - _states() -{ - //ST_INITIAL: 0, -> this one is illegal, you must call connectNB before - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal)); - //ST_CONNECTING: connectNB was called, check if the socket is already available for writing - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::stateConnected)); - //ST_ESTABLISHTUNNELRECEIVED: we got the response, now start the handshake - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performInitialClientHandshake)); - //ST_CLIENTHANDSHAKECONDREAD: condread - _states.push_back(std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopCondReceive)); - //ST_CLIENTHANDSHAKEINCOMPLETE, - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage)); - //ST_CLIENTHANDSHAKEOK, - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopOK)); - //ST_CLIENTHANDSHAKEEXTERROR, - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopExtError)); - //ST_CLIENTHANDSHAKECONTINUE, - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopContinueNeeded)); - //ST_VERIFY, - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::clientConnectVerify)); - //ST_DONE, - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal)); - //ST_ERROR - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopError)); + _states(SecureSocketImpl::ST_MAX) +{ + _states[SecureSocketImpl::ST_INITIAL] = std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal); + _states[SecureSocketImpl::ST_CONNECTING] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::stateConnected); + _states[SecureSocketImpl::ST_CLIENT_HSK_START] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeStart); + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_TOKEN] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendToken); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INIT] = std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopInit); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_RECV] = std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopRecv); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_PROCESS] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopProcess); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_SEND] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopSend); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INCOMPLETE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_CONTINUE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopContinue); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_DONE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopDone); + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_FINAL] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendFinal); + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_ERROR] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendError); + _states[SecureSocketImpl::ST_CLIENT_VERIFY] = std::make_pair(&StateMachine::none, &SecureSocketImpl::clientConnectVerify); + _states[SecureSocketImpl::ST_DONE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal); + _states[SecureSocketImpl::ST_ERROR] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeError); } -void StateMachine::execute(SecureSocketImpl* pSock) +bool StateMachine::execute(SecureSocketImpl* pSock) { try { @@ -1737,8 +1798,9 @@ void StateMachine::execute(SecureSocketImpl* pSock) if ((this->*state.first)(pSock->sockfd())) { (pSock->*(state.second))(); - (pSock->getState() == SecureSocketImpl::ST_DONE); + return true; } + else return false; } catch (...) {