From 25ce19427343723875b56314166a36d093482598 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Sat, 23 Nov 2024 11:10:53 +0100 Subject: [PATCH 1/7] fix(NetSSL): shutdown behavior --- Net/include/Poco/Net/SocketImpl.h | 14 ++++- Net/include/Poco/Net/StreamSocket.h | 14 ++++- Net/include/Poco/Net/WebSocketImpl.h | 4 +- .../HTTPTimeServer/src/HTTPTimeServer.cpp | 3 +- Net/src/SocketImpl.cpp | 6 +- Net/src/StreamSocket.cpp | 8 +-- Net/src/WebSocketImpl.cpp | 8 +-- .../include/Poco/Net/SecureSocketImpl.h | 6 +- .../include/Poco/Net/SecureStreamSocketImpl.h | 19 +++++-- NetSSL_OpenSSL/src/SecureSocketImpl.cpp | 57 ++++--------------- NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp | 7 ++- .../include/Poco/Net/SecureSocketImpl.h | 5 +- .../include/Poco/Net/SecureStreamSocketImpl.h | 19 +++++-- NetSSL_Win/src/SecureSocketImpl.cpp | 18 ++++-- NetSSL_Win/src/SecureStreamSocketImpl.cpp | 3 +- 15 files changed, 104 insertions(+), 87 deletions(-) diff --git a/Net/include/Poco/Net/SocketImpl.h b/Net/include/Poco/Net/SocketImpl.h index 2a71877e1..a2d4c8e62 100644 --- a/Net/include/Poco/Net/SocketImpl.h +++ b/Net/include/Poco/Net/SocketImpl.h @@ -170,12 +170,22 @@ public: virtual void shutdownReceive(); /// Shuts down the receiving part of the socket connection. - virtual void shutdownSend(); + virtual int shutdownSend(); /// Shuts down the sending part of the socket connection. + /// + /// Returns 0 for a non-blocking socket. May return + /// a negative value for a non-blocking socket in case + /// of a TLS connection. In that case, the operation should + /// be retried once the underlying socket becomes writable. - virtual void shutdown(); + virtual int shutdown(); /// Shuts down both the receiving and the sending part /// of the socket connection. + /// + /// Returns 0 for a non-blocking socket. May return + /// a negative value for a non-blocking socket in case + /// of a TLS connection. In that case, the operation should + /// be retried once the underlying socket becomes writable. virtual int sendBytes(const void* buffer, int length, int flags = 0); /// Sends the contents of the given buffer through diff --git a/Net/include/Poco/Net/StreamSocket.h b/Net/include/Poco/Net/StreamSocket.h index 662fb6dd3..fa5706316 100644 --- a/Net/include/Poco/Net/StreamSocket.h +++ b/Net/include/Poco/Net/StreamSocket.h @@ -157,12 +157,22 @@ public: void shutdownReceive(); /// Shuts down the receiving part of the socket connection. - void shutdownSend(); + int shutdownSend(); /// Shuts down the sending part of the socket connection. + /// + /// Returns 0 for a non-blocking socket. May return + /// a negative value for a non-blocking socket in case + /// of a TLS connection. In that case, the operation should + /// be retried once the underlying socket becomes writable. - void shutdown(); + int shutdown(); /// Shuts down both the receiving and the sending part /// of the socket connection. + /// + /// Returns 0 for a non-blocking socket. May return + /// a negative value for a non-blocking socket in case + /// of a TLS connection. In that case, the operation should + /// be retried once the underlying socket becomes writable. int sendBytes(const void* buffer, int length, int flags = 0); /// Sends the contents of the given buffer through diff --git a/Net/include/Poco/Net/WebSocketImpl.h b/Net/include/Poco/Net/WebSocketImpl.h index 4e53d3871..ee8ede76d 100644 --- a/Net/include/Poco/Net/WebSocketImpl.h +++ b/Net/include/Poco/Net/WebSocketImpl.h @@ -69,8 +69,8 @@ public: virtual void listen(int backlog = 64); virtual void close(); virtual void shutdownReceive(); - virtual void shutdownSend(); - virtual void shutdown(); + virtual int shutdownSend(); + virtual int shutdown(); virtual int sendTo(const void* buffer, int length, const SocketAddress& address, int flags = 0); virtual int receiveFrom(void* buffer, int length, SocketAddress& address, int flags = 0); virtual void sendUrgent(unsigned char data); diff --git a/Net/samples/HTTPTimeServer/src/HTTPTimeServer.cpp b/Net/samples/HTTPTimeServer/src/HTTPTimeServer.cpp index 3289c3d1e..9f4ee7667 100644 --- a/Net/samples/HTTPTimeServer/src/HTTPTimeServer.cpp +++ b/Net/samples/HTTPTimeServer/src/HTTPTimeServer.cpp @@ -67,10 +67,11 @@ public: response.setChunkedTransferEncoding(true); response.setContentType("text/html"); + response.set("Clear-Site-Data", "\"cookies\""); std::ostream& ostr = response.send(); ostr << "HTTPTimeServer powered by POCO C++ Libraries"; - ostr << ""; + ostr << ""; ostr << "

"; ostr << dt; ostr << "

"; diff --git a/Net/src/SocketImpl.cpp b/Net/src/SocketImpl.cpp index 579dfc656..b0c828fb3 100644 --- a/Net/src/SocketImpl.cpp +++ b/Net/src/SocketImpl.cpp @@ -327,21 +327,23 @@ void SocketImpl::shutdownReceive() } -void SocketImpl::shutdownSend() +int SocketImpl::shutdownSend() { if (_sockfd == POCO_INVALID_SOCKET) throw InvalidSocketException(); int rc = ::shutdown(_sockfd, 1); if (rc != 0) error(); + return 0; } -void SocketImpl::shutdown() +int SocketImpl::shutdown() { if (_sockfd == POCO_INVALID_SOCKET) throw InvalidSocketException(); int rc = ::shutdown(_sockfd, 2); if (rc != 0) error(); + return 0; } diff --git a/Net/src/StreamSocket.cpp b/Net/src/StreamSocket.cpp index 806c56516..dc8e84858 100644 --- a/Net/src/StreamSocket.cpp +++ b/Net/src/StreamSocket.cpp @@ -146,15 +146,15 @@ void StreamSocket::shutdownReceive() } -void StreamSocket::shutdownSend() +int StreamSocket::shutdownSend() { - impl()->shutdownSend(); + return impl()->shutdownSend(); } -void StreamSocket::shutdown() +int StreamSocket::shutdown() { - impl()->shutdown(); + return impl()->shutdown(); } diff --git a/Net/src/WebSocketImpl.cpp b/Net/src/WebSocketImpl.cpp index 23759d54c..fa2a27fb2 100644 --- a/Net/src/WebSocketImpl.cpp +++ b/Net/src/WebSocketImpl.cpp @@ -539,15 +539,15 @@ void WebSocketImpl::shutdownReceive() } -void WebSocketImpl::shutdownSend() +int WebSocketImpl::shutdownSend() { - _pStreamSocketImpl->shutdownSend(); + return _pStreamSocketImpl->shutdownSend(); } -void WebSocketImpl::shutdown() +int WebSocketImpl::shutdown() { - _pStreamSocketImpl->shutdown(); + return _pStreamSocketImpl->shutdown(); } diff --git a/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h b/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h index 10e9c1aee..477718618 100644 --- a/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h @@ -148,10 +148,11 @@ public: /// number of connections that can be queued /// for this socket. - void shutdown(); + int shutdown(); /// Shuts down the connection by attempting /// an orderly SSL shutdown, then actually - /// shutting down the TCP connection. + /// shutting down the TCP connection in the + /// send direction. void close(); /// Close the socket. @@ -294,7 +295,6 @@ private: bool _needHandshake; std::string _peerHostName; Session::Ptr _pSession; - bool _bidirectShutdown = true; mutable MutexT _mutex; friend class SecureStreamSocketImpl; diff --git a/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h b/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h index 942a36c45..083399166 100644 --- a/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h +++ b/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h @@ -116,14 +116,25 @@ public: /// Since SSL does not support a half shutdown, this does /// nothing. - void shutdownSend() override; + int shutdownSend() override; /// Shuts down the receiving part of the socket connection. /// - /// Since SSL does not support a half shutdown, this does - /// nothing. + /// Sends a close notify shutdown alert message to the peer + /// (if not sent yet), then calls shutdownSend() on the + /// underlying socket. + /// + /// Returns 0 if the message has been sent. + /// Returns 1 if the message has been sent, but the peer + /// has not yet sent its shutdown alert message. + /// In case of a non-blocking socket, returns < 0 if the + /// message cannot be sent at the moment. In this case, + /// the call to shutdownSend() must be retried after the + /// underlying socket becomes writable again. - void shutdown() override; + int shutdown() override; /// Shuts down the SSL connection. + /// + /// Same as shutdownSend(). void abort(); /// Aborts the connection by closing the underlying diff --git a/NetSSL_OpenSSL/src/SecureSocketImpl.cpp b/NetSSL_OpenSSL/src/SecureSocketImpl.cpp index 2d70fd57b..a68cb917a 100644 --- a/NetSSL_OpenSSL/src/SecureSocketImpl.cpp +++ b/NetSSL_OpenSSL/src/SecureSocketImpl.cpp @@ -253,69 +253,33 @@ void SecureSocketImpl::listen(int backlog) } -void SecureSocketImpl::shutdown() +int SecureSocketImpl::shutdown() { if (_pSSL) { UnLockT l(_mutex); - // Don't shut down the socket more than once. int shutdownState = ::SSL_get_shutdown(_pSSL); bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN; if (!shutdownSent) { - // A proper clean shutdown would require us to - // retry the shutdown if we get a zero return - // value, until SSL_shutdown() returns 1. - // However, this will lead to problems with - // most web browsers, so we just set the shutdown - // flag by calling SSL_shutdown() once and be - // done with it. -#if OPENSSL_VERSION_NUMBER >= 0x30000000L - int rc = 0; - if (!_bidirectShutdown) - rc = ::SSL_shutdown(_pSSL); - else - { - Poco::Timespan recvTimeout = _pSocket->getReceiveTimeout(); - Poco::Timespan pollTimeout(0, 100000); - Poco::Timestamp tsNow; - do - { - rc = ::SSL_shutdown(_pSSL); - if (rc == 1) break; - if (rc < 0) - { - int err = ::SSL_get_error(_pSSL, rc); - if (err == SSL_ERROR_WANT_READ) - _pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_READ); - else if (err == SSL_ERROR_WANT_WRITE) - _pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_WRITE); - else - { - int socketError = SocketImpl::lastError(); - long lastError = ::ERR_get_error(); - if ((err == SSL_ERROR_SSL) && (socketError == 0) && (lastError == 0x0A000123)) - rc = 0; - break; - } - } - else _pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_READ); - } while (!tsNow.isElapsed(recvTimeout.totalMicroseconds())); - } -#else int rc = ::SSL_shutdown(_pSSL); -#endif - if (rc < 0) handleError(rc); + if (rc < 0) rc = handleError(rc); l.unlock(); - if (_pSocket->getBlocking()) + if (rc >= 0) { - _pSocket->shutdown(); + _pSocket->shutdownSend(); } + return rc; + } + else + { + return (shutdownState & SSL_RECEIVED_SHUTDOWN) == SSL_RECEIVED_SHUTDOWN; } } + return 1; } @@ -407,7 +371,6 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) if (tsStart.isElapsed(recvTimeout.totalMicroseconds())) throw Poco::TimeoutException(); }; - _bidirectShutdown = false; if (rc <= 0) { return handleError(rc); diff --git a/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp b/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp index 289279e37..3523831e8 100644 --- a/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp +++ b/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp @@ -156,14 +156,15 @@ void SecureStreamSocketImpl::shutdownReceive() } -void SecureStreamSocketImpl::shutdownSend() +int SecureStreamSocketImpl::shutdownSend() { + return _impl.shutdown(); } -void SecureStreamSocketImpl::shutdown() +int SecureStreamSocketImpl::shutdown() { - _impl.shutdown(); + return _impl.shutdown(); } diff --git a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h index b3fa77c29..9fc60b4b3 100644 --- a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h @@ -154,10 +154,11 @@ public: /// number of connections that can be queued /// for this socket. - void shutdown(); + int shutdown(); /// Shuts down the connection by attempting /// an orderly SSL shutdown, then actually - /// shutting down the TCP connection. + /// shutting down the TCP connection in the + /// send direction. void close(); /// Close the socket. diff --git a/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h index 243f0a28a..6760d6b7d 100644 --- a/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h @@ -117,14 +117,25 @@ public: /// Since SSL does not support a half shutdown, this does /// nothing. - void shutdownSend(); + int shutdownSend(); /// Shuts down the receiving part of the socket connection. /// - /// Since SSL does not support a half shutdown, this does - /// nothing. + /// Sends a close notify shutdown alert message to the peer + /// (if not sent yet), then calls shutdownSend() on the + /// underlying socket. + /// + /// Returns 0 if the message has been sent. + /// Returns 1 if the message has been sent, but the peer + /// has not yet sent its shutdown alert message. + /// In case of a non-blocking socket, returns < 0 if the + /// message cannot be sent at the moment. In this case, + /// the call to shutdownSend() must be retried after the + /// underlying socket becomes writable again. - void shutdown(); + int shutdown(); /// Shuts down the SSL connection. + /// + /// Same as shutdownSend(). void abort(); /// Aborts the connection by closing the underlying diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index fb62c3ea4..c9ea76b83 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -265,7 +265,7 @@ void SecureSocketImpl::listen(int backlog) } -void SecureSocketImpl::shutdown() +int SecureSocketImpl::shutdown() { if (_mode == MODE_SERVER) serverDisconnect(&_hCreds, &_hContext); @@ -273,16 +273,22 @@ void SecureSocketImpl::shutdown() clientDisconnect(&_hCreds, &_hContext); _pSocket->shutdown(); + return 0; } void SecureSocketImpl::close() { - if (_mode == MODE_SERVER) - serverDisconnect(&_hCreds, &_hContext); - else - clientDisconnect(&_hCreds, &_hContext); - + try + { + if (_mode == MODE_SERVER) + serverDisconnect(&_hCreds, &_hContext); + else + clientDisconnect(&_hCreds, &_hContext); + } + catch (Poco::Exception&) + { + } _pSocket->close(); cleanup(); } diff --git a/NetSSL_Win/src/SecureStreamSocketImpl.cpp b/NetSSL_Win/src/SecureStreamSocketImpl.cpp index e90a03059..ca7048c9b 100644 --- a/NetSSL_Win/src/SecureStreamSocketImpl.cpp +++ b/NetSSL_Win/src/SecureStreamSocketImpl.cpp @@ -151,12 +151,13 @@ void SecureStreamSocketImpl::shutdownReceive() void SecureStreamSocketImpl::shutdownSend() { + return _impl.shutdown(); } void SecureStreamSocketImpl::shutdown() { - _impl.shutdown(); + return _impl.shutdown(); } From e911884753d0a8698f3da0372a8a7649b9eec985 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Sat, 23 Nov 2024 15:38:13 +0100 Subject: [PATCH 2/7] fix(NetSSL_Win): shutdown behavior --- .../include/Poco/Net/SecureSocketImpl.h | 11 +++- NetSSL_Win/src/HTTPSStreamFactory.cpp | 2 +- NetSSL_Win/src/SecureSocketImpl.cpp | 55 ++++++++++++------- NetSSL_Win/src/SecureStreamSocketImpl.cpp | 4 +- 4 files changed, 48 insertions(+), 24 deletions(-) diff --git a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h index 9fc60b4b3..b1502cf01 100644 --- a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h @@ -236,6 +236,12 @@ protected: ST_ERROR }; + enum TLSShutdown + { + TLS_SHUTDOWN_SENT = 1, + TLS_SHUTDOWN_RECEIVED = 2 + }; + int sendRawBytes(const void* buffer, int length, int flags = 0); int receiveRawBytes(void* buffer, int length, int flags = 0); void clientConnectVerify(); @@ -245,8 +251,8 @@ protected: void clientVerifyCertificate(const std::string& hostName); void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert); void serverVerifyCertificate(); - LONG serverDisconnect(PCredHandle phCreds, CtxtHandle* phContext); - LONG clientDisconnect(PCredHandle phCreds, CtxtHandle* phContext); + int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext); + int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext); bool loadSecurityLibrary(); void initClientContext(); void initServerContext(); @@ -286,6 +292,7 @@ private: Poco::AutoPtr _pSocket; Context::Ptr _pContext; Mode _mode; + int _shutdownFlags; std::string _peerHostName; bool _useMachineStore; bool _clientAuthRequired; diff --git a/NetSSL_Win/src/HTTPSStreamFactory.cpp b/NetSSL_Win/src/HTTPSStreamFactory.cpp index ec1d43f9e..b2cb86821 100644 --- a/NetSSL_Win/src/HTTPSStreamFactory.cpp +++ b/NetSSL_Win/src/HTTPSStreamFactory.cpp @@ -138,7 +138,7 @@ std::istream* HTTPSStreamFactory::open(const URI& uri) { return new HTTPResponseStream(rs, pSession); } - else if (res.getStatus() == HTTPResponse::HTTP_USEPROXY && !retry) + else if (res.getStatus() == HTTPResponse::HTTP_USE_PROXY && !retry) { // The requested resource MUST be accessed through the proxy // given by the Location field. The Location field gives the diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index c9ea76b83..4e9a3b98a 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -62,6 +62,7 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr pSocketImpl, Contex _pSocket(pSocketImpl), _pContext(pContext), _mode(pContext->isForServerUse() ? MODE_SERVER : MODE_CLIENT), + _shutdownFlags(0), _clientAuthRequired(pContext->verificationMode() >= Context::VERIFY_STRICT), _securityFunctions(SSLManager::instance().securityFunctions()), _pOwnCertificate(0), @@ -267,13 +268,24 @@ void SecureSocketImpl::listen(int backlog) int SecureSocketImpl::shutdown() { - if (_mode == MODE_SERVER) - serverDisconnect(&_hCreds, &_hContext); - else - clientDisconnect(&_hCreds, &_hContext); - - _pSocket->shutdown(); - return 0; + int rc = 0; + if ((_shutdownFlags & TLS_SHUTDOWN_SENT) == 0) + { + if (_mode == MODE_SERVER) + { + rc = serverShutdown(&_hCreds, &_hContext); + } + else + { + rc = clientShutdown(&_hCreds, &_hContext); + } + if (rc >= 0) + { + _pSocket->shutdownSend(); + _shutdownFlags |= TLS_SHUTDOWN_SENT; + } + } + return (_shutdownFlags & TLS_SHUTDOWN_RECEIVED) ? 1 : 0; } @@ -281,10 +293,7 @@ void SecureSocketImpl::close() { try { - if (_mode == MODE_SERVER) - serverDisconnect(&_hCreds, &_hContext); - else - clientDisconnect(&_hCreds, &_hContext); + shutdown(); } catch (Poco::Exception&) { @@ -577,13 +586,12 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) if (securityStatus == SEC_I_CONTEXT_EXPIRED) { - SetLastError(securityStatus); + _shutdownFlags |= TLS_SHUTDOWN_RECEIVED; break; } - if (securityStatus != SEC_E_OK && securityStatus != SEC_I_RENEGOTIATE && securityStatus != SEC_I_CONTEXT_EXPIRED) + if (securityStatus != SEC_E_OK && securityStatus != SEC_I_RENEGOTIATE) { - SetLastError(securityStatus); break; } @@ -626,6 +634,7 @@ SECURITY_STATUS SecureSocketImpl::decodeMessage(BYTE* pBuffer, DWORD bufSize, Au pExtraBuffer = 0; SECURITY_STATUS securityStatus = _securityFunctions.DecryptMessage(&_hContext, &msg, 0, 0); + // TODO: when decrypting the close_notify alert, returns SEC_E_DECRYPT_FAILURE if (securityStatus == SEC_E_OK || securityStatus == SEC_I_RENEGOTIATE) { @@ -1493,11 +1502,11 @@ void SecureSocketImpl::serverVerifyCertificate() } -LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phContext) +int SecureSocketImpl::clientShutdown(PCredHandle phCreds, CtxtHandle* phContext) { if (phContext->dwLower == 0 && phContext->dwUpper == 0) { - return SEC_E_OK; + return 0; } AutoSecBufferDesc<1> tokBuffer(&_securityFunctions, false); @@ -1506,7 +1515,7 @@ LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phConte tokBuffer.setSecBufferToken(0, &tokenType, sizeof(tokenType)); DWORD status = _securityFunctions.ApplyControlToken(phContext, &tokBuffer); - if (FAILED(status)) return status; + if (FAILED(status)) throw SSLException(Utility::formatError(status)); DWORD sspiFlags = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT @@ -1534,11 +1543,19 @@ LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phConte &sspiOutFlags, &expiry); - return status; + if (FAILED(status)) throw SSLException(Utility::formatError(status)); + + if (outBuffer[0].pvBuffer && outBuffer[0].cbBuffer) + { + int sent = sendRawBytes(outBuffer[0].pvBuffer, outBuffer[0].cbBuffer); + return sent; + } + + return 0; } -LONG SecureSocketImpl::serverDisconnect(PCredHandle phCreds, CtxtHandle* phContext) +int SecureSocketImpl::serverShutdown(PCredHandle phCreds, CtxtHandle* phContext) { if (phContext->dwLower == 0 && phContext->dwUpper == 0) { diff --git a/NetSSL_Win/src/SecureStreamSocketImpl.cpp b/NetSSL_Win/src/SecureStreamSocketImpl.cpp index ca7048c9b..add82ca23 100644 --- a/NetSSL_Win/src/SecureStreamSocketImpl.cpp +++ b/NetSSL_Win/src/SecureStreamSocketImpl.cpp @@ -149,13 +149,13 @@ void SecureStreamSocketImpl::shutdownReceive() } -void SecureStreamSocketImpl::shutdownSend() +int SecureStreamSocketImpl::shutdownSend() { return _impl.shutdown(); } -void SecureStreamSocketImpl::shutdown() +int SecureStreamSocketImpl::shutdown() { return _impl.shutdown(); } From a48633b567893a1f4eadd553b1bdb90015a89f2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Sun, 24 Nov 2024 15:22:17 +0100 Subject: [PATCH 3/7] chore(NetSSL_Win): refactored state machine (wip) --- .../include/Poco/Net/SecureSocketImpl.h | 51 +-- NetSSL_Win/src/SecureSocketImpl.cpp | 312 +++++++++++------- 2 files changed, 219 insertions(+), 144 deletions(-) diff --git a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h index b1502cf01..6a52239dc 100644 --- a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h @@ -224,16 +224,23 @@ protected: enum State { ST_INITIAL = 0, + // Client ST_CONNECTING, - ST_CLIENTHANDSHAKESTART, - ST_CLIENTHANDSHAKECONDREAD, - ST_CLIENTHANDSHAKEINCOMPLETE, - ST_CLIENTHANDSHAKEOK, - ST_CLIENTHANDSHAKEEXTERROR, - ST_CLIENTHANDSHAKECONTINUE, - ST_VERIFY, + ST_CLIENT_HSK_START, + ST_CLIENT_HSK_SEND_TOKEN, + ST_CLIENT_HSK_LOOP_INIT, + ST_CLIENT_HSK_LOOP_RECV, + ST_CLIENT_HSK_LOOP_PROCESS, + ST_CLIENT_HSK_LOOP_CONTINUE, + ST_CLIENT_HSK_LOOP_SEND, + ST_CLIENT_HSK_LOOP_INCOMPLETE, + ST_CLIENT_HSK_LOOP_DONE, + ST_CLIENT_HSK_SEND_FINAL, + ST_CLIENT_HSK_SEND_ERROR, + ST_CLIENT_VERIFY, ST_DONE, - ST_ERROR + ST_ERROR, + ST_MAX }; enum TLSShutdown @@ -245,7 +252,6 @@ protected: int sendRawBytes(const void* buffer, int length, int flags = 0); int receiveRawBytes(void* buffer, int length, int flags = 0); void clientConnectVerify(); - void sendInitialTokenOutBuffer(); void performServerHandshake(); bool serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext); void clientVerifyCertificate(const std::string& hostName); @@ -253,25 +259,32 @@ protected: void serverVerifyCertificate(); int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext); int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext); - bool loadSecurityLibrary(); void initClientContext(); void initServerContext(); PCCERT_CONTEXT loadCertificate(bool mustFindCertificate); void initCommon(); void cleanup(); - void performClientHandshake(); + + void performClientHandshakeStart(); void performInitialClientHandshake(); - SECURITY_STATUS performClientHandshakeLoop(); + void performClientHandshakeSendToken(); void performClientHandshakeLoopIncompleteMessage(); - void performClientHandshakeLoopCondReceive(); - void performClientHandshakeLoopReceive(); - void performClientHandshakeLoopOK(); void performClientHandshakeLoopInit(); + void performClientHandshakeLoopRecv(); + void performClientHandshakeLoopProcess(); + void performClientHandshakeLoopSend(); + void performClientHandshakeLoopDone(); + void performClientHandshakeSendFinal(); void performClientHandshakeExtraBuffer(); - void performClientHandshakeSendOutBuffer(); - void performClientHandshakeLoopContinueNeeded(); - void performClientHandshakeLoopError(); - void performClientHandshakeLoopExtError(); + void performClientHandshakeLoopContinue(); + void performClientHandshakeError(); + void performClientHandshakeSendError(); + void sendOutSecBufferAndAdvanceState(State state); + + void performClientHandshake(); + SECURITY_STATUS performClientHandshakeLoop(); + void performClientHandshakeLoopCondReceive(); + SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra); SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded); void stateIllegal(); diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index 4e9a3b98a..4a10a7154 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -47,7 +47,7 @@ public: bool none(SOCKET sockfd); void select(fd_set* fdRead, fd_set* fdWrite, SOCKET sockfd); - void execute(SecureSocketImpl* pSock); + bool execute(SecureSocketImpl* pSock); private: StateMachine(const StateMachine&); @@ -192,19 +192,19 @@ SocketImpl* SecureSocketImpl::acceptConnection(SocketAddress& clientAddr) void SecureSocketImpl::connect(const SocketAddress& address, bool performHandshake) { - _state = ST_ERROR; + setState(ST_ERROR); _pSocket->connect(address); connectSSL(performHandshake); - _state = ST_DONE; + setState(ST_DONE); } void SecureSocketImpl::connect(const SocketAddress& address, const Poco::Timespan& timeout, bool performHandshake) { - _state = ST_ERROR; + setState(ST_ERROR); _pSocket->connect(address, timeout); connectSSL(performHandshake); - _state = ST_DONE; + setState(ST_DONE); } @@ -212,12 +212,12 @@ void SecureSocketImpl::connectNB(const SocketAddress& address) { try { - _state = ST_CONNECTING; + setState(ST_CONNECTING); _pSocket->connectNB(address); } catch (...) { - _state = ST_ERROR; + setState(ST_ERROR); } } @@ -318,7 +318,7 @@ int SecureSocketImpl::available() const void SecureSocketImpl::acceptSSL() { - _state = ST_DONE; + setState(ST_DONE); initServerContext(); _needHandshake = true; } @@ -379,17 +379,6 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) if (_state == ST_ERROR) return 0; - if (_sendBufferPending > 0) - { - int sent = sendRawBytes(_sendBuffer.begin() + _sendBufferOffset, _sendBufferPending, flags); - if (sent > 0) - { - _sendBufferOffset += sent; - _sendBufferPending -= sent; - } - return _sendBufferPending == 0 ? length : -1; - } - if (_state != ST_DONE) { bool establish = _pSocket->getBlocking(); @@ -403,10 +392,21 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) else { stateMachine(); - return -1; + if (_state != ST_DONE) return -1; } } + if (_sendBufferPending > 0) + { + int sent = sendRawBytes(_sendBuffer.begin() + _sendBufferOffset, _sendBufferPending, flags); + if (sent > 0) + { + _sendBufferOffset += sent; + _sendBufferPending -= sent; + } + return _sendBufferPending == 0 ? length : -1; + } + int dataToSend = length; int dataSent = 0; const char* pBuffer = reinterpret_cast(buffer); @@ -490,7 +490,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) else { stateMachine(); - return -1; + if (_state != ST_DONE) return -1; } } @@ -598,7 +598,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) if (securityStatus == SEC_I_RENEGOTIATE) { _needData = false; - _state = ST_CLIENTHANDSHAKECONDREAD; + setState(ST_CLIENT_HSK_LOOP_INIT); if (!_pSocket->getBlocking()) return -1; @@ -831,7 +831,7 @@ void SecureSocketImpl::clientConnectVerify() throw SSLException("Failed to query stream sizes", Utility::formatError(securityStatus)); _ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer; - _state = ST_DONE; + setState(ST_DONE); } catch (...) { @@ -855,6 +855,7 @@ void SecureSocketImpl::initClientContext() void SecureSocketImpl::performClientHandshake() { performInitialClientHandshake(); + performClientHandshakeSendToken(); performClientHandshakeLoop(); clientConnectVerify(); } @@ -897,38 +898,47 @@ void SecureSocketImpl::performInitialClientHandshake() } } - // incomplete credentials: more calls to InitializeSecurityContext needed - // send the token - sendInitialTokenOutBuffer(); + _extraSecBuffer.pvBuffer = 0; + _extraSecBuffer.cbBuffer = 0; + _needData = true; + + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + setState(ST_CLIENT_HSK_SEND_TOKEN); + } + else if (_securityStatus == SEC_E_OK) + { + setState(ST_DONE); + } +} + + +void SecureSocketImpl::performClientHandshakeSendToken() +{ + poco_assert_dbg (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer); + + int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); + if (sent < 0) + { + return; + } + else if (sent < _outSecBuffer[0].cbBuffer) + { + _outSecBuffer[0].cbBuffer -= sent; + std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); + return; + } if (_securityStatus == SEC_E_OK) { // The security context was successfully initialized. // There is no need for another InitializeSecurityContext (Schannel) call. - _state = ST_DONE; - return; + setState(ST_DONE); } - - //SEC_I_CONTINUE_NEEDED was returned: - // Wait for a return token. The returned token is then passed in - // another call to InitializeSecurityContext (Schannel). The output token can be empty. - - _extraSecBuffer.pvBuffer = 0; - _extraSecBuffer.cbBuffer = 0; - _needData = true; - _state = ST_CLIENTHANDSHAKECONDREAD; - _securityStatus = SEC_E_INCOMPLETE_MESSAGE; -} - - -void SecureSocketImpl::sendInitialTokenOutBuffer() -{ - // send the token - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + else { - int numBytes = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); - if (numBytes != _outSecBuffer[0].cbBuffer) - throw SSLException("Failed to send token to the server"); + setState(ST_CLIENT_HSK_LOOP_INIT); + _securityStatus = SEC_E_INCOMPLETE_MESSAGE; } } @@ -944,11 +954,13 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop() if (_securityStatus == SEC_E_OK) { - performClientHandshakeLoopOK(); + performClientHandshakeLoopDone(); + performClientHandshakeSendFinal(); } else if (_securityStatus == SEC_I_CONTINUE_NEEDED) { - performClientHandshakeLoopContinueNeeded(); + performClientHandshakeLoopContinue(); + performClientHandshakeLoopSend(); } else if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE) { @@ -958,11 +970,11 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop() { if (_outFlags & ISC_RET_EXTENDED_ERROR) { - performClientHandshakeLoopExtError(); + performClientHandshakeSendError(); } else { - performClientHandshakeLoopError(); + performClientHandshakeError(); } } else @@ -973,44 +985,30 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop() if (FAILED(_securityStatus)) { - performClientHandshakeLoopError(); + performClientHandshakeError(); } return _securityStatus; } -void SecureSocketImpl::performClientHandshakeLoopExtError() +void SecureSocketImpl::performClientHandshakeSendError() { poco_assert_dbg (FAILED(_securityStatus)); - performClientHandshakeSendOutBuffer(); - performClientHandshakeLoopError(); + sendOutSecBufferAndAdvanceState(ST_ERROR); } -void SecureSocketImpl::performClientHandshakeLoopError() +void SecureSocketImpl::performClientHandshakeError() { poco_assert_dbg (FAILED(_securityStatus)); cleanup(); - _state = ST_ERROR; + setState(ST_ERROR); throw SSLException("Error during handshake", Utility::formatError(_securityStatus)); } -void SecureSocketImpl::performClientHandshakeSendOutBuffer() -{ - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) - { - int numBytes = sendRawBytes(static_cast(_outSecBuffer[0].pvBuffer), _outSecBuffer[0].cbBuffer); - if (numBytes != _outSecBuffer[0].cbBuffer) - throw SSLException("Socket error during handshake"); - - _outSecBuffer.release(0); - } -} - - void SecureSocketImpl::performClientHandshakeExtraBuffer() { if (_inSecBuffer[1].BufferType == SECBUFFER_EXTRA) @@ -1022,48 +1020,67 @@ void SecureSocketImpl::performClientHandshakeExtraBuffer() } -void SecureSocketImpl::performClientHandshakeLoopOK() +void SecureSocketImpl::performClientHandshakeLoopDone() { poco_assert_dbg(_securityStatus == SEC_E_OK); - performClientHandshakeSendOutBuffer(); performClientHandshakeExtraBuffer(); - _state = ST_VERIFY; + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + setState(ST_CLIENT_HSK_SEND_FINAL); + } + else + { + setState(ST_CLIENT_VERIFY); + } +} + + +void SecureSocketImpl::performClientHandshakeSendFinal() +{ + sendOutSecBufferAndAdvanceState(ST_CLIENT_VERIFY); } void SecureSocketImpl::performClientHandshakeLoopInit() { + poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); + _inSecBuffer.reset(false); _outSecBuffer.reset(true); + + if (_needData) + { + if (_recvBuffer.capacity() != IO_BUFFER_SIZE) + { + _recvBuffer.setCapacity(IO_BUFFER_SIZE); + } + setState(ST_CLIENT_HSK_LOOP_RECV); + } + else + { + setState(ST_CLIENT_HSK_LOOP_PROCESS); + } } -void SecureSocketImpl::performClientHandshakeLoopReceive() +void SecureSocketImpl::performClientHandshakeLoopRecv() { poco_assert_dbg (_needData); poco_assert (IO_BUFFER_SIZE > _recvBufferOffset); int n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); - if (n <= 0) throw SSLException("Error during handshake: failed to read data"); - - _recvBufferOffset += n; + if (n == 0) throw SSLException("Connection unexpectedly closed during handshake"); + if (n > 0) + { + _recvBufferOffset += n; + setState(ST_CLIENT_HSK_LOOP_PROCESS); + } } -void SecureSocketImpl::performClientHandshakeLoopCondReceive() +void SecureSocketImpl::performClientHandshakeLoopProcess() { - poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); - - performClientHandshakeLoopInit(); - if (_needData) - { - if (_recvBuffer.capacity() != IO_BUFFER_SIZE) - _recvBuffer.setCapacity(IO_BUFFER_SIZE); - performClientHandshakeLoopReceive(); - } - else _needData = true; - _inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); // inbuffer 1 should be empty _inSecBuffer.setSecBufferEmpty(1); @@ -1089,38 +1106,82 @@ void SecureSocketImpl::performClientHandshakeLoopCondReceive() if (_securityStatus == SEC_E_OK) { - _state = ST_CLIENTHANDSHAKEOK; + setState(ST_CLIENT_HSK_LOOP_DONE); } else if (_securityStatus == SEC_I_CONTINUE_NEEDED) { - _state = ST_CLIENTHANDSHAKECONTINUE; + setState(ST_CLIENT_HSK_LOOP_CONTINUE); } else if (FAILED(_securityStatus)) { if (_outFlags & ISC_RET_EXTENDED_ERROR) - _state = ST_CLIENTHANDSHAKEEXTERROR; + setState(ST_CLIENT_HSK_SEND_ERROR); else - _state = ST_ERROR; + setState(ST_ERROR); } else { - _state = ST_CLIENTHANDSHAKEINCOMPLETE; + setState(ST_CLIENT_HSK_LOOP_INCOMPLETE); } } -void SecureSocketImpl::performClientHandshakeLoopContinueNeeded() +void SecureSocketImpl::performClientHandshakeLoopCondReceive() +{ + poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); + + performClientHandshakeLoopInit(); + + if (_state == ST_CLIENT_HSK_LOOP_RECV) + { + performClientHandshakeLoopRecv(); + } + + performClientHandshakeLoopProcess(); +} + + +void SecureSocketImpl::performClientHandshakeLoopContinue() { - performClientHandshakeSendOutBuffer(); performClientHandshakeExtraBuffer(); - _state = ST_CLIENTHANDSHAKECONDREAD; + setState(ST_CLIENT_HSK_LOOP_SEND); +} + + +void SecureSocketImpl::performClientHandshakeLoopSend() +{ + sendOutSecBufferAndAdvanceState(ST_CLIENT_HSK_LOOP_INIT); +} + + +void SecureSocketImpl::sendOutSecBufferAndAdvanceState(State state) +{ + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); + if (sent < 0) return; + if (sent < _outSecBuffer[0].cbBuffer) + { + _outSecBuffer[0].cbBuffer -= sent; + std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); + } + else + { + _outSecBuffer.release(0); + setState(state); + } + } + else + { + setState(state); + } } void SecureSocketImpl::performClientHandshakeLoopIncompleteMessage() { _needData = true; - _state = ST_CLIENTHANDSHAKECONDREAD; + setState(ST_CLIENT_HSK_LOOP_INIT); } @@ -1619,6 +1680,12 @@ void SecureSocketImpl::stateIllegal() void SecureSocketImpl::stateConnected() { _peerHostName = _pSocket->peerAddress().host().toString(); + setState(ST_CLIENT_HSK_START); +} + + +void SecureSocketImpl::performClientHandshakeStart() +{ initClientContext(); performInitialClientHandshake(); } @@ -1700,34 +1767,28 @@ void StateMachine::select(fd_set* fdRead, fd_set* fdWrite, SOCKET sockfd) StateMachine::StateMachine(): - _states() -{ - //ST_INITIAL: 0, -> this one is illegal, you must call connectNB before - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal)); - //ST_CONNECTING: connectNB was called, check if the socket is already available for writing - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::stateConnected)); - //ST_ESTABLISHTUNNELRECEIVED: we got the response, now start the handshake - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performInitialClientHandshake)); - //ST_CLIENTHANDSHAKECONDREAD: condread - _states.push_back(std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopCondReceive)); - //ST_CLIENTHANDSHAKEINCOMPLETE, - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage)); - //ST_CLIENTHANDSHAKEOK, - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopOK)); - //ST_CLIENTHANDSHAKEEXTERROR, - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopExtError)); - //ST_CLIENTHANDSHAKECONTINUE, - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopContinueNeeded)); - //ST_VERIFY, - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::clientConnectVerify)); - //ST_DONE, - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal)); - //ST_ERROR - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopError)); + _states(SecureSocketImpl::ST_MAX) +{ + _states[SecureSocketImpl::ST_INITIAL] = std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal); + _states[SecureSocketImpl::ST_CONNECTING] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::stateConnected); + _states[SecureSocketImpl::ST_CLIENT_HSK_START] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeStart); + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_TOKEN] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendToken); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INIT] = std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopInit); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_RECV] = std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopRecv); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_PROCESS] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopProcess); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_SEND] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopSend); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INCOMPLETE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_CONTINUE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopContinue); + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_DONE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopDone); + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_FINAL] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendFinal); + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_ERROR] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendError); + _states[SecureSocketImpl::ST_CLIENT_VERIFY] = std::make_pair(&StateMachine::none, &SecureSocketImpl::clientConnectVerify); + _states[SecureSocketImpl::ST_DONE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal); + _states[SecureSocketImpl::ST_ERROR] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeError); } -void StateMachine::execute(SecureSocketImpl* pSock) +bool StateMachine::execute(SecureSocketImpl* pSock) { try { @@ -1737,8 +1798,9 @@ void StateMachine::execute(SecureSocketImpl* pSock) if ((this->*state.first)(pSock->sockfd())) { (pSock->*(state.second))(); - (pSock->getState() == SecureSocketImpl::ST_DONE); + return true; } + else return false; } catch (...) { From 3c1c92ce50e51d570c001c8d8075d2febb875452 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Sun, 24 Nov 2024 16:22:57 +0100 Subject: [PATCH 4/7] chore(NetSSL_Win): remove select() calls from state machine --- NetSSL_Win/src/SecureSocketImpl.cpp | 121 ++++++---------------------- 1 file changed, 23 insertions(+), 98 deletions(-) diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index 4a10a7154..ee861e8f4 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -32,7 +32,6 @@ namespace Net { class StateMachine { public: - typedef bool (StateMachine::*ConditionMethod)(SOCKET sockfd); typedef void (SecureSocketImpl::*StateImpl)(void); StateMachine(); @@ -40,21 +39,13 @@ public: static StateMachine& instance(); - // Conditions - bool readable(SOCKET sockfd); - bool writable(SOCKET sockfd); - bool readOrWritable(SOCKET sockfd); - bool none(SOCKET sockfd); - void select(fd_set* fdRead, fd_set* fdWrite, SOCKET sockfd); - bool execute(SecureSocketImpl* pSock); private: - StateMachine(const StateMachine&); - StateMachine& operator = (const StateMachine&); + StateMachine(const StateMachine&) = delete; + StateMachine& operator = (const StateMachine&) = delete; - typedef std::pair ConditionState; - std::vector _states; + std::vector _states; }; @@ -1704,87 +1695,25 @@ StateMachine& StateMachine::instance() } -bool StateMachine::readable(SOCKET sockfd) -{ - fd_set fdRead; - FD_ZERO(&fdRead); - FD_SET(sockfd, &fdRead); - select(&fdRead, 0, sockfd); - return (FD_ISSET(sockfd, &fdRead) != 0); -} - - -bool StateMachine::writable(SOCKET sockfd) -{ - fd_set fdWrite; - FD_ZERO(&fdWrite); - FD_SET(sockfd, &fdWrite); - select(0, &fdWrite, sockfd); - return (FD_ISSET(sockfd, &fdWrite) != 0); -} - - -bool StateMachine::readOrWritable(SOCKET sockfd) -{ - fd_set fdRead, fdWrite; - FD_ZERO(&fdRead); - FD_SET(sockfd, &fdRead); - fdWrite = fdRead; - select(&fdRead, &fdWrite, sockfd); - return (FD_ISSET(sockfd, &fdRead) != 0 || FD_ISSET(sockfd, &fdWrite) != 0); -} - - -bool StateMachine::none(SOCKET sockfd) -{ - return true; -} - - -void StateMachine::select(fd_set* fdRead, fd_set* fdWrite, SOCKET sockfd) -{ - Poco::Timespan remainingTime(((Poco::Timestamp::TimeDiff)SecureSocketImpl::TIMEOUT_MILLISECS)*1000); - int rc(0); - do - { - struct timeval tv; - tv.tv_sec = (long) remainingTime.totalSeconds(); - tv.tv_usec = (long) remainingTime.useconds(); - Poco::Timestamp start; - rc = ::select(int(sockfd) + 1, fdRead, fdWrite, 0, &tv); - if (rc < 0 && SecureSocketImpl::lastError() == POCO_EINTR) - { - Poco::Timestamp end; - Poco::Timespan waited = end - start; - if (waited < remainingTime) - remainingTime -= waited; - else - remainingTime = 0; - } - } - while (rc < 0 && SecureSocketImpl::lastError() == POCO_EINTR); -} - - StateMachine::StateMachine(): _states(SecureSocketImpl::ST_MAX) { - _states[SecureSocketImpl::ST_INITIAL] = std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal); - _states[SecureSocketImpl::ST_CONNECTING] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::stateConnected); - _states[SecureSocketImpl::ST_CLIENT_HSK_START] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeStart); - _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_TOKEN] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendToken); - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INIT] = std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopInit); - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_RECV] = std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopRecv); - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_PROCESS] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopProcess); - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_SEND] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopSend); - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INCOMPLETE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage); - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_CONTINUE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopContinue); - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_DONE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopDone); - _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_FINAL] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendFinal); - _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_ERROR] = std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeSendError); - _states[SecureSocketImpl::ST_CLIENT_VERIFY] = std::make_pair(&StateMachine::none, &SecureSocketImpl::clientConnectVerify); - _states[SecureSocketImpl::ST_DONE] = std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal); - _states[SecureSocketImpl::ST_ERROR] = std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeError); + _states[SecureSocketImpl::ST_INITIAL] = &SecureSocketImpl::stateIllegal; + _states[SecureSocketImpl::ST_CONNECTING] = &SecureSocketImpl::stateConnected; + _states[SecureSocketImpl::ST_CLIENT_HSK_START] = &SecureSocketImpl::performClientHandshakeStart; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_TOKEN] = &SecureSocketImpl::performClientHandshakeSendToken; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INIT] = &SecureSocketImpl::performClientHandshakeLoopInit; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_RECV] = &SecureSocketImpl::performClientHandshakeLoopRecv; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_PROCESS] = &SecureSocketImpl::performClientHandshakeLoopProcess; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_SEND] = &SecureSocketImpl::performClientHandshakeLoopSend; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INCOMPLETE] = &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_CONTINUE] = &SecureSocketImpl::performClientHandshakeLoopContinue; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_DONE] = &SecureSocketImpl::performClientHandshakeLoopDone; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_FINAL] = &SecureSocketImpl::performClientHandshakeSendFinal; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_ERROR] = &SecureSocketImpl::performClientHandshakeSendError; + _states[SecureSocketImpl::ST_CLIENT_VERIFY] = &SecureSocketImpl::clientConnectVerify; + _states[SecureSocketImpl::ST_DONE] = &SecureSocketImpl::stateIllegal; + _states[SecureSocketImpl::ST_ERROR] = &SecureSocketImpl::performClientHandshakeError; } @@ -1793,14 +1722,10 @@ bool StateMachine::execute(SecureSocketImpl* pSock) try { poco_assert_dbg (pSock); - ConditionState& state = _states[pSock->getState()]; - ConditionMethod& meth = state.first; - if ((this->*state.first)(pSock->sockfd())) - { - (pSock->*(state.second))(); - return true; - } - else return false; + auto curState = pSock->getState(); + StateImpl& stateImpl = _states[curState]; + (pSock->*(stateImpl))(); + return pSock->getState() != curState; } catch (...) { From 34852d1c629e3dcaf551eef91d3aabbb3488a0ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Sun, 24 Nov 2024 16:59:40 +0100 Subject: [PATCH 5/7] chore(NetSSL_Win): use state machine also for blocking connections; fix non-blocking renegotiation --- .../include/Poco/Net/SecureSocketImpl.h | 6 +- NetSSL_Win/src/SecureSocketImpl.cpp | 115 ++++-------------- 2 files changed, 26 insertions(+), 95 deletions(-) diff --git a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h index 6a52239dc..7eeedcef0 100644 --- a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h @@ -281,9 +281,7 @@ protected: void performClientHandshakeSendError(); void sendOutSecBufferAndAdvanceState(State state); - void performClientHandshake(); - SECURITY_STATUS performClientHandshakeLoop(); - void performClientHandshakeLoopCondReceive(); + SECURITY_STATUS performClientHandshake(); SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra); SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded); @@ -293,7 +291,7 @@ protected: void connectSSL(bool completeHandshake); void completeHandshake(); static int lastError(); - void stateMachine(); + bool stateMachine(); State getState() const; void setState(State st); static bool isLocalHost(const std::string& hostName); diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index ee861e8f4..f4510acc8 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -372,17 +372,16 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) if (_state != ST_DONE) { - bool establish = _pSocket->getBlocking(); - if (establish) + if (_pSocket->getBlocking()) { - while (_state != ST_DONE) - { - stateMachine(); - } + performClientHandshake(); } else { - stateMachine(); + while (_state != ST_DONE && stateMachine()) + { + // no-op + } if (_state != ST_DONE) return -1; } } @@ -470,17 +469,16 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) if (_state == ST_ERROR) return 0; if (_state != ST_DONE) { - bool establish = _pSocket->getBlocking(); - if (establish) + if (_pSocket->getBlocking()) { - while (_state != ST_DONE) - { - stateMachine(); - } + performClientHandshake(); } else { - stateMachine(); + while (_state != ST_DONE && stateMachine()) + { + // no-op + } if (_state != ST_DONE) return -1; } } @@ -591,9 +589,9 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) _needData = false; setState(ST_CLIENT_HSK_LOOP_INIT); if (!_pSocket->getBlocking()) - return -1; + return bytesDecoded > 0 ? bytesDecoded : -1; - securityStatus = performClientHandshakeLoop(); + securityStatus = performClientHandshake(); if (securityStatus != SEC_E_OK) break; @@ -725,9 +723,7 @@ SECURITY_STATUS SecureSocketImpl::decodeBufferFull(BYTE* pBuffer, DWORD bufSize, if (securityStatus == SEC_I_RENEGOTIATE) { _needData = false; - securityStatus = performClientHandshakeLoop(); - if (securityStatus != SEC_E_OK) - break; + return securityStatus; } } while (securityStatus == SEC_E_OK && pBuffer); @@ -782,7 +778,7 @@ void SecureSocketImpl::connectSSL(bool completeHandshake) _peerHostName = _pSocket->peerAddress().host().toString(); } - initClientContext(); + setState(ST_CONNECTING); if (completeHandshake) { performClientHandshake(); @@ -843,12 +839,13 @@ void SecureSocketImpl::initClientContext() } -void SecureSocketImpl::performClientHandshake() +SECURITY_STATUS SecureSocketImpl::performClientHandshake() { - performInitialClientHandshake(); - performClientHandshakeSendToken(); - performClientHandshakeLoop(); - clientConnectVerify(); + while (_state != ST_DONE) + { + stateMachine(); + } + return _securityStatus; } @@ -934,55 +931,6 @@ void SecureSocketImpl::performClientHandshakeSendToken() } -SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop() -{ - _recvBufferOffset = 0; - _securityStatus = SEC_E_INCOMPLETE_MESSAGE; - - while (_securityStatus == SEC_I_CONTINUE_NEEDED || _securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) - { - performClientHandshakeLoopCondReceive(); - - if (_securityStatus == SEC_E_OK) - { - performClientHandshakeLoopDone(); - performClientHandshakeSendFinal(); - } - else if (_securityStatus == SEC_I_CONTINUE_NEEDED) - { - performClientHandshakeLoopContinue(); - performClientHandshakeLoopSend(); - } - else if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE) - { - performClientHandshakeLoopIncompleteMessage(); - } - else if (FAILED(_securityStatus)) - { - if (_outFlags & ISC_RET_EXTENDED_ERROR) - { - performClientHandshakeSendError(); - } - else - { - performClientHandshakeError(); - } - } - else - { - performClientHandshakeLoopIncompleteMessage(); - } - } - - if (FAILED(_securityStatus)) - { - performClientHandshakeError(); - } - - return _securityStatus; -} - - void SecureSocketImpl::performClientHandshakeSendError() { poco_assert_dbg (FAILED(_securityStatus)); @@ -1117,21 +1065,6 @@ void SecureSocketImpl::performClientHandshakeLoopProcess() } -void SecureSocketImpl::performClientHandshakeLoopCondReceive() -{ - poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); - - performClientHandshakeLoopInit(); - - if (_state == ST_CLIENT_HSK_LOOP_RECV) - { - performClientHandshakeLoopRecv(); - } - - performClientHandshakeLoopProcess(); -} - - void SecureSocketImpl::performClientHandshakeLoopContinue() { performClientHandshakeExtraBuffer(); @@ -1682,9 +1615,9 @@ void SecureSocketImpl::performClientHandshakeStart() } -void SecureSocketImpl::stateMachine() +bool SecureSocketImpl::stateMachine() { - StateMachine::instance().execute(this); + return StateMachine::instance().execute(this); } From bc6ca8603e9d3dd3485008d370b8f8f297d00336 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Sun, 24 Nov 2024 21:04:37 +0100 Subject: [PATCH 6/7] fix(NetSSL): handle EWOULDBLOCK when calling SSL_shutdown() --- NetSSL_OpenSSL/src/SecureSocketImpl.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/NetSSL_OpenSSL/src/SecureSocketImpl.cpp b/NetSSL_OpenSSL/src/SecureSocketImpl.cpp index a68cb917a..ce03396e8 100644 --- a/NetSSL_OpenSSL/src/SecureSocketImpl.cpp +++ b/NetSSL_OpenSSL/src/SecureSocketImpl.cpp @@ -264,7 +264,13 @@ int SecureSocketImpl::shutdown() if (!shutdownSent) { int rc = ::SSL_shutdown(_pSSL); - if (rc < 0) rc = handleError(rc); + if (rc < 0) + { + if (SocketImpl::lastError() == POCO_EWOULDBLOCK) + rc = SecureStreamSocket::ERR_SSL_WANT_WRITE; + else + rc = handleError(rc); + } l.unlock(); From 44ffffb6f75cd391e618864d21542ab54f475382 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Tue, 26 Nov 2024 19:20:03 +0100 Subject: [PATCH 7/7] chore(NetSSL_Win): rewrite handshake logic to support non-blocking sockets --- .../include/Poco/Net/SecureSocketImpl.h | 87 +- NetSSL_Win/include/Poco/Net/Utility.h | 2 +- NetSSL_Win/src/SecureSocketImpl.cpp | 881 ++++++++++-------- NetSSL_Win/src/SecureStreamSocketImpl.cpp | 3 +- NetSSL_Win/src/Utility.cpp | 11 +- 5 files changed, 565 insertions(+), 419 deletions(-) diff --git a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h index 7eeedcef0..268ab2fe6 100644 --- a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h @@ -18,6 +18,10 @@ #define NetSSL_SecureSocketImpl_INCLUDED +// Temporary debugging aid, to be removed +// #define ENABLE_PRINT_STATE + + #include "Poco/Net/SocketImpl.h" #include "Poco/Net/NetSSL.h" #include "Poco/Net/Context.h" @@ -35,6 +39,14 @@ #include + +#ifdef ENABLE_PRINT_STATE +#define PRINT_STATE(m) printState(m) +#else +#define PRINT_STATE(m) +#endif + + namespace Poco { namespace Net { @@ -224,20 +236,25 @@ protected: enum State { ST_INITIAL = 0, - // Client ST_CONNECTING, ST_CLIENT_HSK_START, ST_CLIENT_HSK_SEND_TOKEN, ST_CLIENT_HSK_LOOP_INIT, ST_CLIENT_HSK_LOOP_RECV, ST_CLIENT_HSK_LOOP_PROCESS, - ST_CLIENT_HSK_LOOP_CONTINUE, ST_CLIENT_HSK_LOOP_SEND, - ST_CLIENT_HSK_LOOP_INCOMPLETE, ST_CLIENT_HSK_LOOP_DONE, ST_CLIENT_HSK_SEND_FINAL, ST_CLIENT_HSK_SEND_ERROR, ST_CLIENT_VERIFY, + ST_ACCEPTING, + ST_SERVER_HSK_START, + ST_SERVER_HSK_LOOP_INIT, + ST_SERVER_HSK_LOOP_RECV, + ST_SERVER_HSK_LOOP_PROCESS, + ST_SERVER_HSK_LOOP_SEND, + ST_SERVER_HSK_LOOP_DONE, + ST_SERVER_VERIFY, ST_DONE, ST_ERROR, ST_MAX @@ -251,51 +268,64 @@ protected: int sendRawBytes(const void* buffer, int length, int flags = 0); int receiveRawBytes(void* buffer, int length, int flags = 0); - void clientConnectVerify(); - void performServerHandshake(); - bool serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext); void clientVerifyCertificate(const std::string& hostName); void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert); void serverVerifyCertificate(); int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext); int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext); - void initClientContext(); - void initServerContext(); PCCERT_CONTEXT loadCertificate(bool mustFindCertificate); void initCommon(); void cleanup(); - void performClientHandshakeStart(); - void performInitialClientHandshake(); - void performClientHandshakeSendToken(); - void performClientHandshakeLoopIncompleteMessage(); - void performClientHandshakeLoopInit(); - void performClientHandshakeLoopRecv(); - void performClientHandshakeLoopProcess(); - void performClientHandshakeLoopSend(); - void performClientHandshakeLoopDone(); - void performClientHandshakeSendFinal(); - void performClientHandshakeExtraBuffer(); - void performClientHandshakeLoopContinue(); - void performClientHandshakeError(); - void performClientHandshakeSendError(); - void sendOutSecBufferAndAdvanceState(State state); + void stateIllegal(); + void stateError(); - SECURITY_STATUS performClientHandshake(); + void stateClientConnected(); + void stateClientHandshakeStart(); + void stateClientHandshakeSendToken(); + void stateClientHandshakeLoopInit(); + void stateClientHandshakeLoopRecv(); + void stateClientHandshakeLoopProcess(); + void stateClientHandshakeLoopSend(); + void stateClientHandshakeLoopDone(); + void stateClientHandshakeSendFinal(); + void stateClientHandshakeSendError(); + void stateClientVerify(); + + void stateServerAccepted(); + void stateServerHandshakeStart(); + void stateServerHandshakeLoopInit(); + void stateServerHandshakeLoopRecv(); + void stateServerHandshakeLoopProcess(); + void stateServerHandshakeLoopSend(); + void stateServerHandshakeLoopDone(); + void stateServerHandshakeVerify(); + + void sendOutSecBufferAndAdvanceState(State state); + void drainExtraBuffer(); + static int getRecordLength(const BYTE* pBuffer, int length); + static bool bufferHasCompleteRecords(const BYTE* pBuffer, int length); + + void initClientCredentials(); + void initServerCredentials(); + SECURITY_STATUS doHandshake(); + int completeHandshake(); SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra); SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded); - void stateIllegal(); - void stateConnected(); + void acceptSSL(); void connectSSL(bool completeHandshake); - void completeHandshake(); static int lastError(); bool stateMachine(); State getState() const; void setState(State st); static bool isLocalHost(const std::string& hostName); +#ifdef ENABLE_PRINT_STATE + void printState(const std::string& msg); +#endif + private: SecureSocketImpl(const SecureSocketImpl&); SecureSocketImpl& operator = (const SecureSocketImpl&); @@ -331,9 +361,9 @@ private: SecBuffer _extraSecBuffer; SECURITY_STATUS _securityStatus; State _state; - DWORD _outFlags; bool _needData; bool _needHandshake; + bool _initServerContext = false; friend class SecureStreamSocketImpl; friend class StateMachine; @@ -376,6 +406,7 @@ inline SecureSocketImpl::State SecureSocketImpl::getState() const inline void SecureSocketImpl::setState(SecureSocketImpl::State st) { _state = st; + PRINT_STATE("setState: "); } diff --git a/NetSSL_Win/include/Poco/Net/Utility.h b/NetSSL_Win/include/Poco/Net/Utility.h index 0e5e837a2..7a3501fd7 100644 --- a/NetSSL_Win/include/Poco/Net/Utility.h +++ b/NetSSL_Win/include/Poco/Net/Utility.h @@ -35,7 +35,7 @@ public: /// Non-case sensitive conversion of a string to a VerificationMode enum. /// If verMode is illegal an OptionException is thrown. - static const std::string& formatError(long errCode); + static std::string formatError(long errCode); /// Converts an winerror.h code into human readable form. private: diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index f4510acc8..c6d38d861 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -25,6 +25,12 @@ #include +#ifdef ENABLE_PRINT_STATE +#include "Poco/Mutex.h" +#include +#endif + + namespace Poco { namespace Net { @@ -58,8 +64,6 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr pSocketImpl, Contex _securityFunctions(SSLManager::instance().securityFunctions()), _pOwnCertificate(0), _pPeerCertificate(0), - _hCreds(), - _hContext(), _contextFlags(0), _overflowBuffer(0), _sendBuffer(0), @@ -77,11 +81,8 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr pSocketImpl, Contex _needData(true), _needHandshake(false) { - _hCreds.dwLower = 0; - _hCreds.dwUpper = 0; - - _hContext.dwLower = 0; - _hContext.dwUpper = 0; + SecInvalidateHandle(&_hCreds); + SecInvalidateHandle(&_hContext); _streamSizes.cbBlockSize = 0; _streamSizes.cbHeader = 0; @@ -142,14 +143,12 @@ void SecureSocketImpl::cleanup() { _peerHostName.clear(); - _hCreds.dwLower = 0; - _hCreds.dwUpper = 0; + SecInvalidateHandle(&_hCreds); - if (_hContext.dwLower != 0 && _hContext.dwUpper != 0) + if (SecIsValidHandle(&_hContext)) { _securityFunctions.DeleteSecurityContext(&_hContext); - _hContext.dwLower = 0; - _hContext.dwUpper = 0; + SecInvalidateHandle(&_hContext); } if (_pOwnCertificate) @@ -249,9 +248,9 @@ void SecureSocketImpl::listen(int backlog) { _mode = MODE_SERVER; - if (_hCreds.dwLower == 0 && _hCreds.dwUpper == 0) + if (!SecIsValidHandle(&_hCreds)) { - initServerContext(); + initServerCredentials(); } _pSocket->listen(backlog); } @@ -309,8 +308,7 @@ int SecureSocketImpl::available() const void SecureSocketImpl::acceptSSL() { - setState(ST_DONE); - initServerContext(); + setState(ST_ACCEPTING); _needHandshake = true; } @@ -333,11 +331,14 @@ void SecureSocketImpl::verifyPeerCertificate(const std::string& hostName) return; } - if (_mode == MODE_SERVER) + { serverVerifyCertificate(); + } else + { clientVerifyCertificate(hostName); + } } @@ -364,7 +365,7 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) { if (_needHandshake) { - completeHandshake(); + doHandshake(); _needHandshake = false; } @@ -374,7 +375,7 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) { if (_pSocket->getBlocking()) { - performClientHandshake(); + doHandshake(); } else { @@ -462,7 +463,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) { if (_needHandshake) { - completeHandshake(); + doHandshake(); _needHandshake = false; } @@ -471,7 +472,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) { if (_pSocket->getBlocking()) { - performClientHandshake(); + doHandshake(); } else { @@ -591,7 +592,7 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) if (!_pSocket->getBlocking()) return bytesDecoded > 0 ? bytesDecoded : -1; - securityStatus = performClientHandshake(); + securityStatus = doHandshake(); if (securityStatus != SEC_E_OK) break; @@ -624,6 +625,7 @@ SECURITY_STATUS SecureSocketImpl::decodeMessage(BYTE* pBuffer, DWORD bufSize, Au SECURITY_STATUS securityStatus = _securityFunctions.DecryptMessage(&_hContext, &msg, 0, 0); // TODO: when decrypting the close_notify alert, returns SEC_E_DECRYPT_FAILURE + // instead of SEC_I_CONTEXT_EXPIRED if (securityStatus == SEC_E_OK || securityStatus == SEC_I_RENEGOTIATE) { @@ -781,7 +783,7 @@ void SecureSocketImpl::connectSSL(bool completeHandshake) setState(ST_CONNECTING); if (completeHandshake) { - performClientHandshake(); + doHandshake(); _needHandshake = false; } else @@ -791,16 +793,315 @@ void SecureSocketImpl::connectSSL(bool completeHandshake) } -void SecureSocketImpl::completeHandshake() +void SecureSocketImpl::stateIllegal() { - if (_mode == MODE_SERVER) - performServerHandshake(); - else - performClientHandshake(); + throw Poco::IllegalStateException("SSL state machine"); } -void SecureSocketImpl::clientConnectVerify() +void SecureSocketImpl::stateError() +{ + poco_assert_dbg (FAILED(_securityStatus)); + + cleanup(); + setState(ST_ERROR); + throw SSLException("Error during handshake", Utility::formatError(_securityStatus)); +} + + +void SecureSocketImpl::stateClientConnected() +{ + _peerHostName = _pSocket->peerAddress().host().toString(); + setState(ST_CLIENT_HSK_START); +} + + +void SecureSocketImpl::stateClientHandshakeStart() +{ + initClientCredentials(); + + // get initial security token + _outSecBuffer.reset(true); + _outSecBuffer.setSecBufferToken(0, 0, 0); + _recvBuffer.setCapacity(IO_BUFFER_SIZE); + _recvBufferOffset = 0; + + TimeStamp ts; + DWORD contextAttributes(0); + std::wstring whostName; + Poco::UnicodeConverter::convert(_peerHostName, whostName); + _securityStatus = _securityFunctions.InitializeSecurityContextW( + &_hCreds, + 0, + const_cast(whostName.c_str()), + _contextFlags, + 0, + 0, + 0, + 0, + &_hContext, + &_outSecBuffer, + &contextAttributes, + &ts); + + if (_securityStatus != SEC_E_OK && _securityStatus != SEC_I_CONTINUE_NEEDED) + { + if (_securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) + { + // the server is asking for client credentials, we didn't send one because we were not configured to do so, abort + throw SSLException("Handshake failed: No client credentials configured"); + } + else + { + throw SSLException("Handshake failed", Utility::formatError(_securityStatus)); + } + } + + _extraSecBuffer.pvBuffer = 0; + _extraSecBuffer.cbBuffer = 0; + _needData = true; + + setState(ST_CLIENT_HSK_SEND_TOKEN); +} + + +void SecureSocketImpl::stateClientHandshakeSendToken() +{ + poco_assert_dbg (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer); + + int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); + if (sent < 0) + { + return; + } + else if (sent < _outSecBuffer[0].cbBuffer) + { + _outSecBuffer[0].cbBuffer -= sent; + std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); + return; + } + + if (_securityStatus == SEC_E_OK) + { + // The security context was successfully initialized. + // There is no need for another InitializeSecurityContext (Schannel) call. + setState(ST_DONE); + } + else + { + setState(ST_CLIENT_HSK_LOOP_INIT); + } +} + + +void SecureSocketImpl::stateClientHandshakeSendError() +{ + poco_assert_dbg (FAILED(_securityStatus)); + + sendOutSecBufferAndAdvanceState(ST_ERROR); +} + + +void SecureSocketImpl::stateClientHandshakeLoopInit() +{ + _inSecBuffer.reset(false); + _outSecBuffer.reset(true); + + if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_CONTINUE_NEEDED) + { + setState(ST_CLIENT_HSK_LOOP_RECV); + } + else + { + setState(ST_CLIENT_HSK_LOOP_PROCESS); + } +} + + +int SecureSocketImpl::getRecordLength(const BYTE* pRec, int len) +{ + // TLS Record Header: + // struct { + // ContentType type; /* 1 byte */ + // ProtocolVersion version; /* 2 bytes */ + // uint16 length; /* must be <= 2^14 */ + // opaque fragment[TLSPlaintext.length]; + // } TLSPlaintext; + + if (len >= 5) + { + Poco::UInt16 recLen = (pRec[3] << 8) + pRec[4]; + if (recLen <= 16384) + { + return recLen + 5; + } + else + { + return -1; + } + } + return 0; +} + + +bool SecureSocketImpl::bufferHasCompleteRecords(const BYTE* pBuffer, int length) +{ + const BYTE* pEnd = pBuffer + length; + int recLen = getRecordLength(pBuffer, length); + while (recLen > 0 && length >= recLen) + { + pBuffer += recLen; + length -= recLen; + recLen = getRecordLength(pBuffer, length); + } + return pBuffer == pEnd; +} + + +void SecureSocketImpl::stateClientHandshakeLoopRecv() +{ + poco_assert (IO_BUFFER_SIZE > _recvBufferOffset); + + int n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); + if (n == 0) throw SSLException("Connection unexpectedly closed during handshake"); + if (n > 0) + { + _recvBufferOffset += n; + // NOTE: InitializeSecurityContext() will fail with SEC_E_INVALID_TOKEN + // if provided with a buffer containing incomplete records, rather than returning + // SEC_E_INCOMPLETE_MESSAGE. Therefore we have to make sure that we + // receive complete records before passing the buffer to InitializeSecurityContext(). + if (bufferHasCompleteRecords(_recvBuffer.begin(), _recvBufferOffset)) + { + setState(ST_CLIENT_HSK_LOOP_PROCESS); + } + + } +} + + +void SecureSocketImpl::stateClientHandshakeLoopProcess() +{ + _inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); + _inSecBuffer.setSecBufferEmpty(1); + _outSecBuffer.setSecBufferToken(0, 0, 0); + + PRINT_STATE("stateClientHandshakeLoopProcess before InitializeSecurityContextW: "); + + DWORD outFlags; + TimeStamp ts; + _securityStatus = _securityFunctions.InitializeSecurityContextW( + &_hCreds, + &_hContext, + 0, + _contextFlags, + 0, + 0, + &_inSecBuffer, + 0, + 0, + &_outSecBuffer, + &outFlags, + &ts); + + PRINT_STATE("stateClientHandshakeLoopProcess after InitializeSecurityContextW: "); + + if (_securityStatus == SEC_E_OK) + { + drainExtraBuffer(); + setState(ST_CLIENT_HSK_LOOP_DONE); + } + else if (_securityStatus == SEC_I_CONTINUE_NEEDED) + { + drainExtraBuffer(); + setState(ST_CLIENT_HSK_LOOP_SEND); + } + else if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE) + { + _needData = true; + setState(ST_CLIENT_HSK_LOOP_INIT); + } + else if (FAILED(_securityStatus)) + { + if (outFlags & ISC_RET_EXTENDED_ERROR) + { + setState(ST_CLIENT_HSK_SEND_ERROR); + } + else + { + setState(ST_ERROR); + } + } + else + { + setState(ST_CLIENT_HSK_LOOP_INIT); + } +} + + +void SecureSocketImpl::stateClientHandshakeLoopSend() +{ + sendOutSecBufferAndAdvanceState(ST_CLIENT_HSK_LOOP_INIT); +} + + +void SecureSocketImpl::stateClientHandshakeLoopDone() +{ + poco_assert_dbg(_securityStatus == SEC_E_OK); + + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + setState(ST_CLIENT_HSK_SEND_FINAL); + } + else + { + setState(ST_CLIENT_VERIFY); + } +} + + +void SecureSocketImpl::stateClientHandshakeSendFinal() +{ + sendOutSecBufferAndAdvanceState(ST_CLIENT_VERIFY); +} + + +void SecureSocketImpl::sendOutSecBufferAndAdvanceState(State state) +{ + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); + if (sent < 0) return; + if (sent < _outSecBuffer[0].cbBuffer) + { + _outSecBuffer[0].cbBuffer -= sent; + std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); + } + else + { + _outSecBuffer.release(0); + setState(state); + } + } + else + { + setState(state); + } +} + + +void SecureSocketImpl::drainExtraBuffer() +{ + if (_inSecBuffer[1].BufferType == SECBUFFER_EXTRA && _inSecBuffer[1].pvBuffer && _inSecBuffer[1].cbBuffer > 0) + { + std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - _inSecBuffer[1].cbBuffer), _inSecBuffer[1].cbBuffer); + _recvBufferOffset = _inSecBuffer[1].cbBuffer; + } + else _recvBufferOffset = 0; +} + + +void SecureSocketImpl::stateClientVerify() { poco_assert_dbg(!_pPeerCertificate); poco_assert_dbg(!_peerHostName.empty()); @@ -832,294 +1133,127 @@ void SecureSocketImpl::clientConnectVerify() } -void SecureSocketImpl::initClientContext() +void SecureSocketImpl::stateServerAccepted() { - _pOwnCertificate = loadCertificate(false); - _hCreds = _pContext->credentials(); + initServerCredentials(); + setState(ST_SERVER_HSK_START); } -SECURITY_STATUS SecureSocketImpl::performClientHandshake() +void SecureSocketImpl::stateServerHandshakeStart() { - while (_state != ST_DONE) - { - stateMachine(); - } - return _securityStatus; + _securityStatus = SEC_E_INCOMPLETE_MESSAGE; + _initServerContext = true; + _recvBuffer.setCapacity(IO_BUFFER_SIZE); + _recvBufferOffset = 0; + setState(ST_SERVER_HSK_LOOP_INIT); } -void SecureSocketImpl::performInitialClientHandshake() +void SecureSocketImpl::stateServerHandshakeLoopInit() { - // get initial security token - _outSecBuffer.reset(true); - _outSecBuffer.setSecBufferToken(0, 0, 0); - - TimeStamp ts; - DWORD contextAttributes(0); - std::wstring whostName; - Poco::UnicodeConverter::convert(_peerHostName, whostName); - _securityStatus = _securityFunctions.InitializeSecurityContextW( - &_hCreds, - 0, - const_cast(whostName.c_str()), - _contextFlags, - 0, - 0, - 0, - 0, - &_hContext, - &_outSecBuffer, - &contextAttributes, - &ts); - - if (_securityStatus != SEC_E_OK) + if (_securityStatus == SEC_I_CONTINUE_NEEDED || _securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) { - if (_securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) + if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_CONTINUE_NEEDED) { - // the server is asking for client credentials, we didn't send one because we were not configured to do so, abort - throw SSLException("Handshake failed: No client credentials configured"); + setState(ST_SERVER_HSK_LOOP_RECV); } - else if (_securityStatus != SEC_I_CONTINUE_NEEDED) + else { - throw SSLException("Handshake failed", Utility::formatError(_securityStatus)); + setState(ST_SERVER_HSK_LOOP_PROCESS); } } - - _extraSecBuffer.pvBuffer = 0; - _extraSecBuffer.cbBuffer = 0; - _needData = true; - - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + else if (_securityStatus == SEC_E_CERT_UNKNOWN || _securityStatus == SEC_E_CERT_EXPIRED) { - setState(ST_CLIENT_HSK_SEND_TOKEN); + setState(ST_SERVER_HSK_LOOP_PROCESS); } else if (_securityStatus == SEC_E_OK) { - setState(ST_DONE); + setState(ST_SERVER_HSK_LOOP_DONE); } } -void SecureSocketImpl::performClientHandshakeSendToken() +void SecureSocketImpl::stateServerHandshakeLoopRecv() { - poco_assert_dbg (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer); - - int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); - if (sent < 0) - { - return; - } - else if (sent < _outSecBuffer[0].cbBuffer) - { - _outSecBuffer[0].cbBuffer -= sent; - std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); - return; - } - - if (_securityStatus == SEC_E_OK) - { - // The security context was successfully initialized. - // There is no need for another InitializeSecurityContext (Schannel) call. - setState(ST_DONE); - } - else - { - setState(ST_CLIENT_HSK_LOOP_INIT); - _securityStatus = SEC_E_INCOMPLETE_MESSAGE; - } -} - - -void SecureSocketImpl::performClientHandshakeSendError() -{ - poco_assert_dbg (FAILED(_securityStatus)); - - sendOutSecBufferAndAdvanceState(ST_ERROR); -} - - -void SecureSocketImpl::performClientHandshakeError() -{ - poco_assert_dbg (FAILED(_securityStatus)); - cleanup(); - setState(ST_ERROR); - throw SSLException("Error during handshake", Utility::formatError(_securityStatus)); -} - - -void SecureSocketImpl::performClientHandshakeExtraBuffer() -{ - if (_inSecBuffer[1].BufferType == SECBUFFER_EXTRA) - { - std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - _inSecBuffer[1].cbBuffer), _inSecBuffer[1].cbBuffer); - _recvBufferOffset = _inSecBuffer[1].cbBuffer; - } - else _recvBufferOffset = 0; -} - - -void SecureSocketImpl::performClientHandshakeLoopDone() -{ - poco_assert_dbg(_securityStatus == SEC_E_OK); - - performClientHandshakeExtraBuffer(); - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) - { - setState(ST_CLIENT_HSK_SEND_FINAL); - } - else - { - setState(ST_CLIENT_VERIFY); - } -} - - -void SecureSocketImpl::performClientHandshakeSendFinal() -{ - sendOutSecBufferAndAdvanceState(ST_CLIENT_VERIFY); -} - - -void SecureSocketImpl::performClientHandshakeLoopInit() -{ - poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); - - _inSecBuffer.reset(false); - _outSecBuffer.reset(true); - - if (_needData) - { - if (_recvBuffer.capacity() != IO_BUFFER_SIZE) - { - _recvBuffer.setCapacity(IO_BUFFER_SIZE); - } - setState(ST_CLIENT_HSK_LOOP_RECV); - } - else - { - setState(ST_CLIENT_HSK_LOOP_PROCESS); - } -} - - -void SecureSocketImpl::performClientHandshakeLoopRecv() -{ - poco_assert_dbg (_needData); - poco_assert (IO_BUFFER_SIZE > _recvBufferOffset); + poco_assert_dbg (_recvBuffer.capacity() >= IO_BUFFER_SIZE && IO_BUFFER_SIZE > _recvBufferOffset); int n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); if (n == 0) throw SSLException("Connection unexpectedly closed during handshake"); if (n > 0) { _recvBufferOffset += n; - setState(ST_CLIENT_HSK_LOOP_PROCESS); + setState(ST_SERVER_HSK_LOOP_PROCESS); } } -void SecureSocketImpl::performClientHandshakeLoopProcess() +void SecureSocketImpl::stateServerHandshakeLoopProcess() { + PRINT_STATE("stateServerHandshakeLoopProcess before AcceptSecurityContext: "); _inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); - // inbuffer 1 should be empty _inSecBuffer.setSecBufferEmpty(1); - - // outBuffer[0] should be empty _outSecBuffer.setSecBufferToken(0, 0, 0); - _outFlags = 0; - TimeStamp ts; - _securityStatus = _securityFunctions.InitializeSecurityContextW( - &_hCreds, - &_hContext, - 0, - _contextFlags, - 0, - 0, - &_inSecBuffer, - 0, - 0, - &_outSecBuffer, - &_outFlags, - &ts); + TimeStamp tsExpiry; + DWORD outFlags; + _securityStatus = _securityFunctions.AcceptSecurityContext( + &_hCreds, + _initServerContext ? NULL : &_hContext, + &_inSecBuffer, + _contextFlags, + 0, + _initServerContext ? &_hContext : NULL, + &_outSecBuffer, + &outFlags, + &tsExpiry); - if (_securityStatus == SEC_E_OK) + _initServerContext = !SecIsValidHandle(&_hContext); + + PRINT_STATE("stateServerHandshakeLoopProcess after AcceptSecurityContext: "); + + setState(ST_SERVER_HSK_LOOP_INIT); + if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE) { - setState(ST_CLIENT_HSK_LOOP_DONE); + return; } - else if (_securityStatus == SEC_I_CONTINUE_NEEDED) + else if (_securityStatus == SEC_E_OK || _securityStatus == SEC_I_CONTINUE_NEEDED || (FAILED(_securityStatus) && (0 != (outFlags & ISC_RET_EXTENDED_ERROR)))) { - setState(ST_CLIENT_HSK_LOOP_CONTINUE); + drainExtraBuffer(); + if (_outSecBuffer[0].cbBuffer != 0 && _outSecBuffer[0].pvBuffer != 0) + { + setState(ST_SERVER_HSK_LOOP_SEND); + return; + } + } + else if (_securityStatus == SEC_E_CERT_UNKNOWN || _securityStatus == SEC_E_CERT_EXPIRED || _securityStatus == CRYPT_E_REVOKED || _securityStatus == SEC_E_UNTRUSTED_ROOT) + { + throw SSLException("Handshake failure, client has rejected certificate", Utility::formatError(_securityStatus)); } else if (FAILED(_securityStatus)) { - if (_outFlags & ISC_RET_EXTENDED_ERROR) - setState(ST_CLIENT_HSK_SEND_ERROR); - else - setState(ST_ERROR); - } - else - { - setState(ST_CLIENT_HSK_LOOP_INCOMPLETE); + throw SSLException("Handshake failure:", Utility::formatError(_securityStatus)); } } -void SecureSocketImpl::performClientHandshakeLoopContinue() +void SecureSocketImpl::stateServerHandshakeLoopSend() { - performClientHandshakeExtraBuffer(); - setState(ST_CLIENT_HSK_LOOP_SEND); + sendOutSecBufferAndAdvanceState(ST_SERVER_HSK_LOOP_INIT); } -void SecureSocketImpl::performClientHandshakeLoopSend() +void SecureSocketImpl::stateServerHandshakeLoopDone() { - sendOutSecBufferAndAdvanceState(ST_CLIENT_HSK_LOOP_INIT); + _securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_STREAM_SIZES, &_streamSizes); + if (_securityStatus != SEC_E_OK) throw SSLException("Cannot query stream sizes", Utility::formatError(_securityStatus)); + _ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer; + setState(ST_SERVER_VERIFY); } -void SecureSocketImpl::sendOutSecBufferAndAdvanceState(State state) +void SecureSocketImpl::stateServerHandshakeVerify() { - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) - { - int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); - if (sent < 0) return; - if (sent < _outSecBuffer[0].cbBuffer) - { - _outSecBuffer[0].cbBuffer -= sent; - std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); - } - else - { - _outSecBuffer.release(0); - setState(state); - } - } - else - { - setState(state); - } -} - - -void SecureSocketImpl::performClientHandshakeLoopIncompleteMessage() -{ - _needData = true; - setState(ST_CLIENT_HSK_LOOP_INIT); -} - - -void SecureSocketImpl::initServerContext() -{ - _pOwnCertificate = loadCertificate(true); - _hCreds = _pContext->credentials(); -} - - -void SecureSocketImpl::performServerHandshake() -{ - serverHandshakeLoop(&_hContext, &_hCreds, _clientAuthRequired, true, true); - SECURITY_STATUS securityStatus; if (_clientAuthRequired) { @@ -1140,100 +1274,49 @@ void SecureSocketImpl::performServerHandshake() serverVerifyCertificate(); } } - - securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext,SECPKG_ATTR_STREAM_SIZES, &_streamSizes); - if (securityStatus != SEC_E_OK) throw SSLException("Cannot query stream sizes", Utility::formatError(securityStatus)); - - _ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer; + setState(ST_DONE); } -bool SecureSocketImpl::serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext) +void SecureSocketImpl::initClientCredentials() { - TimeStamp tsExpiry; - int n = 0; - bool doRead = doInitialRead; - bool initContext = newContext; - DWORD outFlags; - SECURITY_STATUS securityStatus = SEC_E_INCOMPLETE_MESSAGE; + _pOwnCertificate = loadCertificate(false); + _hCreds = _pContext->credentials(); +} - while (securityStatus == SEC_I_CONTINUE_NEEDED || securityStatus == SEC_E_INCOMPLETE_MESSAGE || securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) + +void SecureSocketImpl::initServerCredentials() +{ + _pOwnCertificate = loadCertificate(true); + _hCreds = _pContext->credentials(); +} + + +SECURITY_STATUS SecureSocketImpl::doHandshake() +{ + while (_state != ST_DONE && _state != ST_ERROR) { - if (securityStatus == SEC_E_INCOMPLETE_MESSAGE) - { - if (doRead) - { - n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); - - if (n <= 0) - throw SSLException("Failed to receive data in handshake"); - else - _recvBufferOffset += n; - } - else doRead = true; - } - - AutoSecBufferDesc<2> inBuffer(&_securityFunctions, false); - AutoSecBufferDesc<1> outBuffer(&_securityFunctions, true); - inBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); - inBuffer.setSecBufferEmpty(1); - outBuffer.setSecBufferToken(0, 0, 0); - - securityStatus = _securityFunctions.AcceptSecurityContext( - phCred, - initContext ? NULL : phContext, - &inBuffer, - _contextFlags, - 0, - initContext ? phContext : NULL, - &outBuffer, - &outFlags, - &tsExpiry); - - initContext = false; - - if (securityStatus == SEC_E_OK || securityStatus == SEC_I_CONTINUE_NEEDED || (FAILED(securityStatus) && (0 != (outFlags & ISC_RET_EXTENDED_ERROR)))) - { - if (outBuffer[0].cbBuffer != 0 && outBuffer[0].pvBuffer != 0) - { - n = sendRawBytes(outBuffer[0].pvBuffer, outBuffer[0].cbBuffer); - outBuffer.release(0); - } - } - - if (securityStatus == SEC_E_OK ) - { - if (inBuffer[1].BufferType == SECBUFFER_EXTRA) - { - std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - inBuffer[1].cbBuffer), inBuffer[1].cbBuffer); - _recvBufferOffset = inBuffer[1].cbBuffer; - } - else - { - _recvBufferOffset = 0; - } - return true; - } - else if (FAILED(securityStatus) && securityStatus != SEC_E_INCOMPLETE_MESSAGE) - { - throw SSLException("Handshake failure:", Utility::formatError(securityStatus)); - } - - if (securityStatus != SEC_E_INCOMPLETE_MESSAGE && securityStatus != SEC_I_INCOMPLETE_CREDENTIALS) - { - if (inBuffer[1].BufferType == SECBUFFER_EXTRA) - { - std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - inBuffer[1].cbBuffer), inBuffer[1].cbBuffer); - _recvBufferOffset = inBuffer[1].cbBuffer; - } - else - { - _recvBufferOffset = 0; - } - } + stateMachine(); } + return _securityStatus; +} - return false; + +int SecureSocketImpl::completeHandshake() +{ + if (_pSocket->getBlocking()) + { + doHandshake(); + return 0; + } + else + { + while (_state != ST_DONE && stateMachine()) + { + // no-op + } + return (_state == ST_DONE) ? 0 : -1; + } } @@ -1489,7 +1572,7 @@ void SecureSocketImpl::serverVerifyCertificate() int SecureSocketImpl::clientShutdown(PCredHandle phCreds, CtxtHandle* phContext) { - if (phContext->dwLower == 0 && phContext->dwUpper == 0) + if (!SecIsValidHandle(phContext)) { return 0; } @@ -1542,7 +1625,7 @@ int SecureSocketImpl::clientShutdown(PCredHandle phCreds, CtxtHandle* phContext) int SecureSocketImpl::serverShutdown(PCredHandle phCreds, CtxtHandle* phContext) { - if (phContext->dwLower == 0 && phContext->dwUpper == 0) + if (!SecIsValidHandle(phContext)) { // handshake has never been done poco_assert_dbg (_needHandshake); @@ -1595,32 +1678,61 @@ int SecureSocketImpl::serverShutdown(PCredHandle phCreds, CtxtHandle* phContext) } -void SecureSocketImpl::stateIllegal() -{ - throw Poco::IllegalStateException("SSL state machine"); -} - - -void SecureSocketImpl::stateConnected() -{ - _peerHostName = _pSocket->peerAddress().host().toString(); - setState(ST_CLIENT_HSK_START); -} - - -void SecureSocketImpl::performClientHandshakeStart() -{ - initClientContext(); - performInitialClientHandshake(); -} - - bool SecureSocketImpl::stateMachine() { return StateMachine::instance().execute(this); } +#ifdef ENABLE_PRINT_STATE + + +#define CASE_PRINT_ENUM(enum) case enum: std::cout << #enum; break; + + +void SecureSocketImpl::printState(const std::string& message) +{ + Poco::Thread* pThread = Poco::Thread::current(); + int threadId = 0; + if (pThread) threadId = pThread->id(); + + static Poco::FastMutex mutex; + Poco::FastMutex::ScopedLock lock(mutex); + + std::cout << message << " thread=" << threadId << " state="; + switch (_state) + { + CASE_PRINT_ENUM(ST_INITIAL) + CASE_PRINT_ENUM(ST_CONNECTING) + CASE_PRINT_ENUM(ST_CLIENT_HSK_START) + CASE_PRINT_ENUM(ST_CLIENT_HSK_SEND_TOKEN) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_INIT) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_RECV) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_PROCESS) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_SEND) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_DONE) + CASE_PRINT_ENUM(ST_CLIENT_HSK_SEND_FINAL) + CASE_PRINT_ENUM(ST_CLIENT_HSK_SEND_ERROR) + CASE_PRINT_ENUM(ST_CLIENT_VERIFY) + CASE_PRINT_ENUM(ST_ACCEPTING) + CASE_PRINT_ENUM(ST_SERVER_HSK_START) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_INIT) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_RECV) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_PROCESS) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_SEND) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_DONE) + CASE_PRINT_ENUM(ST_SERVER_VERIFY) + CASE_PRINT_ENUM(ST_DONE) + CASE_PRINT_ENUM(ST_ERROR) + } + + std::cout << " securityStatus=" << Utility::formatError(_securityStatus) << " hCreds="<< SecIsValidHandle(&_hCreds) << " hContext=" << SecIsValidHandle(&_hContext) << " recvBufOff=" << _recvBufferOffset << std::endl; +} + + +#endif + + StateMachine& StateMachine::instance() { static StateMachine sm; @@ -1632,21 +1744,30 @@ StateMachine::StateMachine(): _states(SecureSocketImpl::ST_MAX) { _states[SecureSocketImpl::ST_INITIAL] = &SecureSocketImpl::stateIllegal; - _states[SecureSocketImpl::ST_CONNECTING] = &SecureSocketImpl::stateConnected; - _states[SecureSocketImpl::ST_CLIENT_HSK_START] = &SecureSocketImpl::performClientHandshakeStart; - _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_TOKEN] = &SecureSocketImpl::performClientHandshakeSendToken; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INIT] = &SecureSocketImpl::performClientHandshakeLoopInit; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_RECV] = &SecureSocketImpl::performClientHandshakeLoopRecv; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_PROCESS] = &SecureSocketImpl::performClientHandshakeLoopProcess; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_SEND] = &SecureSocketImpl::performClientHandshakeLoopSend; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INCOMPLETE] = &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_CONTINUE] = &SecureSocketImpl::performClientHandshakeLoopContinue; - _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_DONE] = &SecureSocketImpl::performClientHandshakeLoopDone; - _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_FINAL] = &SecureSocketImpl::performClientHandshakeSendFinal; - _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_ERROR] = &SecureSocketImpl::performClientHandshakeSendError; - _states[SecureSocketImpl::ST_CLIENT_VERIFY] = &SecureSocketImpl::clientConnectVerify; + + _states[SecureSocketImpl::ST_CONNECTING] = &SecureSocketImpl::stateClientConnected; + _states[SecureSocketImpl::ST_CLIENT_HSK_START] = &SecureSocketImpl::stateClientHandshakeStart; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_TOKEN] = &SecureSocketImpl::stateClientHandshakeSendToken; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INIT] = &SecureSocketImpl::stateClientHandshakeLoopInit; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_RECV] = &SecureSocketImpl::stateClientHandshakeLoopRecv; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_PROCESS] = &SecureSocketImpl::stateClientHandshakeLoopProcess; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_SEND] = &SecureSocketImpl::stateClientHandshakeLoopSend; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_DONE] = &SecureSocketImpl::stateClientHandshakeLoopDone; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_FINAL] = &SecureSocketImpl::stateClientHandshakeSendFinal; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_ERROR] = &SecureSocketImpl::stateClientHandshakeSendError; + _states[SecureSocketImpl::ST_CLIENT_VERIFY] = &SecureSocketImpl::stateClientVerify; + + _states[SecureSocketImpl::ST_ACCEPTING] = &SecureSocketImpl::stateServerAccepted; + _states[SecureSocketImpl::ST_SERVER_HSK_START] = &SecureSocketImpl::stateServerHandshakeStart; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_INIT] = &SecureSocketImpl::stateServerHandshakeLoopInit; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_RECV] = &SecureSocketImpl::stateServerHandshakeLoopRecv; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_PROCESS] = &SecureSocketImpl::stateServerHandshakeLoopProcess; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_SEND] = &SecureSocketImpl::stateServerHandshakeLoopSend; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_DONE] = &SecureSocketImpl::stateServerHandshakeLoopDone; + _states[SecureSocketImpl::ST_SERVER_VERIFY] = &SecureSocketImpl::stateServerHandshakeVerify; + _states[SecureSocketImpl::ST_DONE] = &SecureSocketImpl::stateIllegal; - _states[SecureSocketImpl::ST_ERROR] = &SecureSocketImpl::performClientHandshakeError; + _states[SecureSocketImpl::ST_ERROR] = &SecureSocketImpl::stateError; } diff --git a/NetSSL_Win/src/SecureStreamSocketImpl.cpp b/NetSSL_Win/src/SecureStreamSocketImpl.cpp index add82ca23..876162a61 100644 --- a/NetSSL_Win/src/SecureStreamSocketImpl.cpp +++ b/NetSSL_Win/src/SecureStreamSocketImpl.cpp @@ -209,8 +209,7 @@ void SecureStreamSocketImpl::verifyPeerCertificate(const std::string& hostName) int SecureStreamSocketImpl::completeHandshake() { - _impl.completeHandshake(); - return 0; + return _impl.completeHandshake(); } diff --git a/NetSSL_Win/src/Utility.cpp b/NetSSL_Win/src/Utility.cpp index 5970bf8a2..d8ac4eb9c 100644 --- a/NetSSL_Win/src/Utility.cpp +++ b/NetSSL_Win/src/Utility.cpp @@ -24,9 +24,6 @@ namespace Poco { namespace Net { -Poco::FastMutex Utility::_mutex; - - Context::VerificationMode Utility::convertVerificationMode(const std::string& vMode) { std::string mode = Poco::toLower(vMode); @@ -54,6 +51,7 @@ inline void add(std::map& messageMap, long key, const s std::map Utility::initSSPIErr() { std::map messageMap; + add(messageMap, SEC_E_OK, "OK"); add(messageMap, NTE_BAD_UID, "Bad UID"); add(messageMap, NTE_BAD_HASH, "Bad Hash"); add(messageMap, NTE_BAD_KEY, "Bad Key"); @@ -185,18 +183,15 @@ std::map Utility::initSSPIErr() } -const std::string& Utility::formatError(long errCode) +std::string Utility::formatError(long errCode) { - Poco::FastMutex::ScopedLock lock(_mutex); - - static const std::string def("Internal SSPI error"); static const std::map errs(initSSPIErr()); const std::map::const_iterator it = errs.find(errCode); if (it != errs.end()) return it->second; else - return def; + return "0x" + Poco::NumberFormatter::formatHex(errCode, 8); }