diff --git a/Net/include/Poco/Net/SocketImpl.h b/Net/include/Poco/Net/SocketImpl.h index 2a71877e1..a2d4c8e62 100644 --- a/Net/include/Poco/Net/SocketImpl.h +++ b/Net/include/Poco/Net/SocketImpl.h @@ -170,12 +170,22 @@ public: virtual void shutdownReceive(); /// Shuts down the receiving part of the socket connection. - virtual void shutdownSend(); + virtual int shutdownSend(); /// Shuts down the sending part of the socket connection. + /// + /// Returns 0 for a non-blocking socket. May return + /// a negative value for a non-blocking socket in case + /// of a TLS connection. In that case, the operation should + /// be retried once the underlying socket becomes writable. - virtual void shutdown(); + virtual int shutdown(); /// Shuts down both the receiving and the sending part /// of the socket connection. + /// + /// Returns 0 for a non-blocking socket. May return + /// a negative value for a non-blocking socket in case + /// of a TLS connection. In that case, the operation should + /// be retried once the underlying socket becomes writable. virtual int sendBytes(const void* buffer, int length, int flags = 0); /// Sends the contents of the given buffer through diff --git a/Net/include/Poco/Net/StreamSocket.h b/Net/include/Poco/Net/StreamSocket.h index 662fb6dd3..fa5706316 100644 --- a/Net/include/Poco/Net/StreamSocket.h +++ b/Net/include/Poco/Net/StreamSocket.h @@ -157,12 +157,22 @@ public: void shutdownReceive(); /// Shuts down the receiving part of the socket connection. - void shutdownSend(); + int shutdownSend(); /// Shuts down the sending part of the socket connection. + /// + /// Returns 0 for a non-blocking socket. May return + /// a negative value for a non-blocking socket in case + /// of a TLS connection. In that case, the operation should + /// be retried once the underlying socket becomes writable. - void shutdown(); + int shutdown(); /// Shuts down both the receiving and the sending part /// of the socket connection. + /// + /// Returns 0 for a non-blocking socket. May return + /// a negative value for a non-blocking socket in case + /// of a TLS connection. In that case, the operation should + /// be retried once the underlying socket becomes writable. int sendBytes(const void* buffer, int length, int flags = 0); /// Sends the contents of the given buffer through diff --git a/Net/include/Poco/Net/WebSocketImpl.h b/Net/include/Poco/Net/WebSocketImpl.h index 4e53d3871..ee8ede76d 100644 --- a/Net/include/Poco/Net/WebSocketImpl.h +++ b/Net/include/Poco/Net/WebSocketImpl.h @@ -69,8 +69,8 @@ public: virtual void listen(int backlog = 64); virtual void close(); virtual void shutdownReceive(); - virtual void shutdownSend(); - virtual void shutdown(); + virtual int shutdownSend(); + virtual int shutdown(); virtual int sendTo(const void* buffer, int length, const SocketAddress& address, int flags = 0); virtual int receiveFrom(void* buffer, int length, SocketAddress& address, int flags = 0); virtual void sendUrgent(unsigned char data); diff --git a/Net/samples/HTTPTimeServer/src/HTTPTimeServer.cpp b/Net/samples/HTTPTimeServer/src/HTTPTimeServer.cpp index 3289c3d1e..9f4ee7667 100644 --- a/Net/samples/HTTPTimeServer/src/HTTPTimeServer.cpp +++ b/Net/samples/HTTPTimeServer/src/HTTPTimeServer.cpp @@ -67,10 +67,11 @@ public: response.setChunkedTransferEncoding(true); response.setContentType("text/html"); + response.set("Clear-Site-Data", "\"cookies\""); std::ostream& ostr = response.send(); ostr << "HTTPTimeServer powered by POCO C++ Libraries"; - ostr << ""; + ostr << ""; ostr << "

"; ostr << dt; ostr << "

"; diff --git a/Net/src/SocketImpl.cpp b/Net/src/SocketImpl.cpp index 579dfc656..b0c828fb3 100644 --- a/Net/src/SocketImpl.cpp +++ b/Net/src/SocketImpl.cpp @@ -327,21 +327,23 @@ void SocketImpl::shutdownReceive() } -void SocketImpl::shutdownSend() +int SocketImpl::shutdownSend() { if (_sockfd == POCO_INVALID_SOCKET) throw InvalidSocketException(); int rc = ::shutdown(_sockfd, 1); if (rc != 0) error(); + return 0; } -void SocketImpl::shutdown() +int SocketImpl::shutdown() { if (_sockfd == POCO_INVALID_SOCKET) throw InvalidSocketException(); int rc = ::shutdown(_sockfd, 2); if (rc != 0) error(); + return 0; } diff --git a/Net/src/StreamSocket.cpp b/Net/src/StreamSocket.cpp index 806c56516..dc8e84858 100644 --- a/Net/src/StreamSocket.cpp +++ b/Net/src/StreamSocket.cpp @@ -146,15 +146,15 @@ void StreamSocket::shutdownReceive() } -void StreamSocket::shutdownSend() +int StreamSocket::shutdownSend() { - impl()->shutdownSend(); + return impl()->shutdownSend(); } -void StreamSocket::shutdown() +int StreamSocket::shutdown() { - impl()->shutdown(); + return impl()->shutdown(); } diff --git a/Net/src/WebSocketImpl.cpp b/Net/src/WebSocketImpl.cpp index 23759d54c..fa2a27fb2 100644 --- a/Net/src/WebSocketImpl.cpp +++ b/Net/src/WebSocketImpl.cpp @@ -539,15 +539,15 @@ void WebSocketImpl::shutdownReceive() } -void WebSocketImpl::shutdownSend() +int WebSocketImpl::shutdownSend() { - _pStreamSocketImpl->shutdownSend(); + return _pStreamSocketImpl->shutdownSend(); } -void WebSocketImpl::shutdown() +int WebSocketImpl::shutdown() { - _pStreamSocketImpl->shutdown(); + return _pStreamSocketImpl->shutdown(); } diff --git a/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h b/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h index 10e9c1aee..477718618 100644 --- a/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_OpenSSL/include/Poco/Net/SecureSocketImpl.h @@ -148,10 +148,11 @@ public: /// number of connections that can be queued /// for this socket. - void shutdown(); + int shutdown(); /// Shuts down the connection by attempting /// an orderly SSL shutdown, then actually - /// shutting down the TCP connection. + /// shutting down the TCP connection in the + /// send direction. void close(); /// Close the socket. @@ -294,7 +295,6 @@ private: bool _needHandshake; std::string _peerHostName; Session::Ptr _pSession; - bool _bidirectShutdown = true; mutable MutexT _mutex; friend class SecureStreamSocketImpl; diff --git a/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h b/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h index 942a36c45..083399166 100644 --- a/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h +++ b/NetSSL_OpenSSL/include/Poco/Net/SecureStreamSocketImpl.h @@ -116,14 +116,25 @@ public: /// Since SSL does not support a half shutdown, this does /// nothing. - void shutdownSend() override; + int shutdownSend() override; /// Shuts down the receiving part of the socket connection. /// - /// Since SSL does not support a half shutdown, this does - /// nothing. + /// Sends a close notify shutdown alert message to the peer + /// (if not sent yet), then calls shutdownSend() on the + /// underlying socket. + /// + /// Returns 0 if the message has been sent. + /// Returns 1 if the message has been sent, but the peer + /// has not yet sent its shutdown alert message. + /// In case of a non-blocking socket, returns < 0 if the + /// message cannot be sent at the moment. In this case, + /// the call to shutdownSend() must be retried after the + /// underlying socket becomes writable again. - void shutdown() override; + int shutdown() override; /// Shuts down the SSL connection. + /// + /// Same as shutdownSend(). void abort(); /// Aborts the connection by closing the underlying diff --git a/NetSSL_OpenSSL/src/SecureSocketImpl.cpp b/NetSSL_OpenSSL/src/SecureSocketImpl.cpp index 2d70fd57b..ce03396e8 100644 --- a/NetSSL_OpenSSL/src/SecureSocketImpl.cpp +++ b/NetSSL_OpenSSL/src/SecureSocketImpl.cpp @@ -253,69 +253,39 @@ void SecureSocketImpl::listen(int backlog) } -void SecureSocketImpl::shutdown() +int SecureSocketImpl::shutdown() { if (_pSSL) { UnLockT l(_mutex); - // Don't shut down the socket more than once. int shutdownState = ::SSL_get_shutdown(_pSSL); bool shutdownSent = (shutdownState & SSL_SENT_SHUTDOWN) == SSL_SENT_SHUTDOWN; if (!shutdownSent) { - // A proper clean shutdown would require us to - // retry the shutdown if we get a zero return - // value, until SSL_shutdown() returns 1. - // However, this will lead to problems with - // most web browsers, so we just set the shutdown - // flag by calling SSL_shutdown() once and be - // done with it. -#if OPENSSL_VERSION_NUMBER >= 0x30000000L - int rc = 0; - if (!_bidirectShutdown) - rc = ::SSL_shutdown(_pSSL); - else - { - Poco::Timespan recvTimeout = _pSocket->getReceiveTimeout(); - Poco::Timespan pollTimeout(0, 100000); - Poco::Timestamp tsNow; - do - { - rc = ::SSL_shutdown(_pSSL); - if (rc == 1) break; - if (rc < 0) - { - int err = ::SSL_get_error(_pSSL, rc); - if (err == SSL_ERROR_WANT_READ) - _pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_READ); - else if (err == SSL_ERROR_WANT_WRITE) - _pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_WRITE); - else - { - int socketError = SocketImpl::lastError(); - long lastError = ::ERR_get_error(); - if ((err == SSL_ERROR_SSL) && (socketError == 0) && (lastError == 0x0A000123)) - rc = 0; - break; - } - } - else _pSocket->poll(pollTimeout, Poco::Net::Socket::SELECT_READ); - } while (!tsNow.isElapsed(recvTimeout.totalMicroseconds())); - } -#else int rc = ::SSL_shutdown(_pSSL); -#endif - if (rc < 0) handleError(rc); + if (rc < 0) + { + if (SocketImpl::lastError() == POCO_EWOULDBLOCK) + rc = SecureStreamSocket::ERR_SSL_WANT_WRITE; + else + rc = handleError(rc); + } l.unlock(); - if (_pSocket->getBlocking()) + if (rc >= 0) { - _pSocket->shutdown(); + _pSocket->shutdownSend(); } + return rc; + } + else + { + return (shutdownState & SSL_RECEIVED_SHUTDOWN) == SSL_RECEIVED_SHUTDOWN; } } + return 1; } @@ -407,7 +377,6 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) if (tsStart.isElapsed(recvTimeout.totalMicroseconds())) throw Poco::TimeoutException(); }; - _bidirectShutdown = false; if (rc <= 0) { return handleError(rc); diff --git a/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp b/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp index 289279e37..3523831e8 100644 --- a/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp +++ b/NetSSL_OpenSSL/src/SecureStreamSocketImpl.cpp @@ -156,14 +156,15 @@ void SecureStreamSocketImpl::shutdownReceive() } -void SecureStreamSocketImpl::shutdownSend() +int SecureStreamSocketImpl::shutdownSend() { + return _impl.shutdown(); } -void SecureStreamSocketImpl::shutdown() +int SecureStreamSocketImpl::shutdown() { - _impl.shutdown(); + return _impl.shutdown(); } diff --git a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h index b3fa77c29..268ab2fe6 100644 --- a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h @@ -18,6 +18,10 @@ #define NetSSL_SecureSocketImpl_INCLUDED +// Temporary debugging aid, to be removed +// #define ENABLE_PRINT_STATE + + #include "Poco/Net/SocketImpl.h" #include "Poco/Net/NetSSL.h" #include "Poco/Net/Context.h" @@ -35,6 +39,14 @@ #include + +#ifdef ENABLE_PRINT_STATE +#define PRINT_STATE(m) printState(m) +#else +#define PRINT_STATE(m) +#endif + + namespace Poco { namespace Net { @@ -154,10 +166,11 @@ public: /// number of connections that can be queued /// for this socket. - void shutdown(); + int shutdown(); /// Shuts down the connection by attempting /// an orderly SSL shutdown, then actually - /// shutting down the TCP connection. + /// shutting down the TCP connection in the + /// send direction. void close(); /// Close the socket. @@ -224,60 +237,95 @@ protected: { ST_INITIAL = 0, ST_CONNECTING, - ST_CLIENTHANDSHAKESTART, - ST_CLIENTHANDSHAKECONDREAD, - ST_CLIENTHANDSHAKEINCOMPLETE, - ST_CLIENTHANDSHAKEOK, - ST_CLIENTHANDSHAKEEXTERROR, - ST_CLIENTHANDSHAKECONTINUE, - ST_VERIFY, + ST_CLIENT_HSK_START, + ST_CLIENT_HSK_SEND_TOKEN, + ST_CLIENT_HSK_LOOP_INIT, + ST_CLIENT_HSK_LOOP_RECV, + ST_CLIENT_HSK_LOOP_PROCESS, + ST_CLIENT_HSK_LOOP_SEND, + ST_CLIENT_HSK_LOOP_DONE, + ST_CLIENT_HSK_SEND_FINAL, + ST_CLIENT_HSK_SEND_ERROR, + ST_CLIENT_VERIFY, + ST_ACCEPTING, + ST_SERVER_HSK_START, + ST_SERVER_HSK_LOOP_INIT, + ST_SERVER_HSK_LOOP_RECV, + ST_SERVER_HSK_LOOP_PROCESS, + ST_SERVER_HSK_LOOP_SEND, + ST_SERVER_HSK_LOOP_DONE, + ST_SERVER_VERIFY, ST_DONE, - ST_ERROR + ST_ERROR, + ST_MAX + }; + + enum TLSShutdown + { + TLS_SHUTDOWN_SENT = 1, + TLS_SHUTDOWN_RECEIVED = 2 }; int sendRawBytes(const void* buffer, int length, int flags = 0); int receiveRawBytes(void* buffer, int length, int flags = 0); - void clientConnectVerify(); - void sendInitialTokenOutBuffer(); - void performServerHandshake(); - bool serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext); void clientVerifyCertificate(const std::string& hostName); void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert); void serverVerifyCertificate(); - LONG serverDisconnect(PCredHandle phCreds, CtxtHandle* phContext); - LONG clientDisconnect(PCredHandle phCreds, CtxtHandle* phContext); - bool loadSecurityLibrary(); - void initClientContext(); - void initServerContext(); + int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext); + int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext); PCCERT_CONTEXT loadCertificate(bool mustFindCertificate); void initCommon(); void cleanup(); - void performClientHandshake(); - void performInitialClientHandshake(); - SECURITY_STATUS performClientHandshakeLoop(); - void performClientHandshakeLoopIncompleteMessage(); - void performClientHandshakeLoopCondReceive(); - void performClientHandshakeLoopReceive(); - void performClientHandshakeLoopOK(); - void performClientHandshakeLoopInit(); - void performClientHandshakeExtraBuffer(); - void performClientHandshakeSendOutBuffer(); - void performClientHandshakeLoopContinueNeeded(); - void performClientHandshakeLoopError(); - void performClientHandshakeLoopExtError(); + + void stateIllegal(); + void stateError(); + + void stateClientConnected(); + void stateClientHandshakeStart(); + void stateClientHandshakeSendToken(); + void stateClientHandshakeLoopInit(); + void stateClientHandshakeLoopRecv(); + void stateClientHandshakeLoopProcess(); + void stateClientHandshakeLoopSend(); + void stateClientHandshakeLoopDone(); + void stateClientHandshakeSendFinal(); + void stateClientHandshakeSendError(); + void stateClientVerify(); + + void stateServerAccepted(); + void stateServerHandshakeStart(); + void stateServerHandshakeLoopInit(); + void stateServerHandshakeLoopRecv(); + void stateServerHandshakeLoopProcess(); + void stateServerHandshakeLoopSend(); + void stateServerHandshakeLoopDone(); + void stateServerHandshakeVerify(); + + void sendOutSecBufferAndAdvanceState(State state); + void drainExtraBuffer(); + static int getRecordLength(const BYTE* pBuffer, int length); + static bool bufferHasCompleteRecords(const BYTE* pBuffer, int length); + + void initClientCredentials(); + void initServerCredentials(); + SECURITY_STATUS doHandshake(); + int completeHandshake(); + SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra); SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded); - void stateIllegal(); - void stateConnected(); + void acceptSSL(); void connectSSL(bool completeHandshake); - void completeHandshake(); static int lastError(); - void stateMachine(); + bool stateMachine(); State getState() const; void setState(State st); static bool isLocalHost(const std::string& hostName); +#ifdef ENABLE_PRINT_STATE + void printState(const std::string& msg); +#endif + private: SecureSocketImpl(const SecureSocketImpl&); SecureSocketImpl& operator = (const SecureSocketImpl&); @@ -285,6 +333,7 @@ private: Poco::AutoPtr _pSocket; Context::Ptr _pContext; Mode _mode; + int _shutdownFlags; std::string _peerHostName; bool _useMachineStore; bool _clientAuthRequired; @@ -312,9 +361,9 @@ private: SecBuffer _extraSecBuffer; SECURITY_STATUS _securityStatus; State _state; - DWORD _outFlags; bool _needData; bool _needHandshake; + bool _initServerContext = false; friend class SecureStreamSocketImpl; friend class StateMachine; @@ -357,6 +406,7 @@ inline SecureSocketImpl::State SecureSocketImpl::getState() const inline void SecureSocketImpl::setState(SecureSocketImpl::State st) { _state = st; + PRINT_STATE("setState: "); } diff --git a/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h index 243f0a28a..6760d6b7d 100644 --- a/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h @@ -117,14 +117,25 @@ public: /// Since SSL does not support a half shutdown, this does /// nothing. - void shutdownSend(); + int shutdownSend(); /// Shuts down the receiving part of the socket connection. /// - /// Since SSL does not support a half shutdown, this does - /// nothing. + /// Sends a close notify shutdown alert message to the peer + /// (if not sent yet), then calls shutdownSend() on the + /// underlying socket. + /// + /// Returns 0 if the message has been sent. + /// Returns 1 if the message has been sent, but the peer + /// has not yet sent its shutdown alert message. + /// In case of a non-blocking socket, returns < 0 if the + /// message cannot be sent at the moment. In this case, + /// the call to shutdownSend() must be retried after the + /// underlying socket becomes writable again. - void shutdown(); + int shutdown(); /// Shuts down the SSL connection. + /// + /// Same as shutdownSend(). void abort(); /// Aborts the connection by closing the underlying diff --git a/NetSSL_Win/include/Poco/Net/Utility.h b/NetSSL_Win/include/Poco/Net/Utility.h index 0e5e837a2..7a3501fd7 100644 --- a/NetSSL_Win/include/Poco/Net/Utility.h +++ b/NetSSL_Win/include/Poco/Net/Utility.h @@ -35,7 +35,7 @@ public: /// Non-case sensitive conversion of a string to a VerificationMode enum. /// If verMode is illegal an OptionException is thrown. - static const std::string& formatError(long errCode); + static std::string formatError(long errCode); /// Converts an winerror.h code into human readable form. private: diff --git a/NetSSL_Win/src/HTTPSStreamFactory.cpp b/NetSSL_Win/src/HTTPSStreamFactory.cpp index ec1d43f9e..b2cb86821 100644 --- a/NetSSL_Win/src/HTTPSStreamFactory.cpp +++ b/NetSSL_Win/src/HTTPSStreamFactory.cpp @@ -138,7 +138,7 @@ std::istream* HTTPSStreamFactory::open(const URI& uri) { return new HTTPResponseStream(rs, pSession); } - else if (res.getStatus() == HTTPResponse::HTTP_USEPROXY && !retry) + else if (res.getStatus() == HTTPResponse::HTTP_USE_PROXY && !retry) { // The requested resource MUST be accessed through the proxy // given by the Location field. The Location field gives the diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index fb62c3ea4..c6d38d861 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -25,6 +25,12 @@ #include +#ifdef ENABLE_PRINT_STATE +#include "Poco/Mutex.h" +#include +#endif + + namespace Poco { namespace Net { @@ -32,7 +38,6 @@ namespace Net { class StateMachine { public: - typedef bool (StateMachine::*ConditionMethod)(SOCKET sockfd); typedef void (SecureSocketImpl::*StateImpl)(void); StateMachine(); @@ -40,21 +45,13 @@ public: static StateMachine& instance(); - // Conditions - bool readable(SOCKET sockfd); - bool writable(SOCKET sockfd); - bool readOrWritable(SOCKET sockfd); - bool none(SOCKET sockfd); - void select(fd_set* fdRead, fd_set* fdWrite, SOCKET sockfd); - - void execute(SecureSocketImpl* pSock); + bool execute(SecureSocketImpl* pSock); private: - StateMachine(const StateMachine&); - StateMachine& operator = (const StateMachine&); + StateMachine(const StateMachine&) = delete; + StateMachine& operator = (const StateMachine&) = delete; - typedef std::pair ConditionState; - std::vector _states; + std::vector _states; }; @@ -62,12 +59,11 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr pSocketImpl, Contex _pSocket(pSocketImpl), _pContext(pContext), _mode(pContext->isForServerUse() ? MODE_SERVER : MODE_CLIENT), + _shutdownFlags(0), _clientAuthRequired(pContext->verificationMode() >= Context::VERIFY_STRICT), _securityFunctions(SSLManager::instance().securityFunctions()), _pOwnCertificate(0), _pPeerCertificate(0), - _hCreds(), - _hContext(), _contextFlags(0), _overflowBuffer(0), _sendBuffer(0), @@ -85,11 +81,8 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr pSocketImpl, Contex _needData(true), _needHandshake(false) { - _hCreds.dwLower = 0; - _hCreds.dwUpper = 0; - - _hContext.dwLower = 0; - _hContext.dwUpper = 0; + SecInvalidateHandle(&_hCreds); + SecInvalidateHandle(&_hContext); _streamSizes.cbBlockSize = 0; _streamSizes.cbHeader = 0; @@ -150,14 +143,12 @@ void SecureSocketImpl::cleanup() { _peerHostName.clear(); - _hCreds.dwLower = 0; - _hCreds.dwUpper = 0; + SecInvalidateHandle(&_hCreds); - if (_hContext.dwLower != 0 && _hContext.dwUpper != 0) + if (SecIsValidHandle(&_hContext)) { _securityFunctions.DeleteSecurityContext(&_hContext); - _hContext.dwLower = 0; - _hContext.dwUpper = 0; + SecInvalidateHandle(&_hContext); } if (_pOwnCertificate) @@ -191,19 +182,19 @@ SocketImpl* SecureSocketImpl::acceptConnection(SocketAddress& clientAddr) void SecureSocketImpl::connect(const SocketAddress& address, bool performHandshake) { - _state = ST_ERROR; + setState(ST_ERROR); _pSocket->connect(address); connectSSL(performHandshake); - _state = ST_DONE; + setState(ST_DONE); } void SecureSocketImpl::connect(const SocketAddress& address, const Poco::Timespan& timeout, bool performHandshake) { - _state = ST_ERROR; + setState(ST_ERROR); _pSocket->connect(address, timeout); connectSSL(performHandshake); - _state = ST_DONE; + setState(ST_DONE); } @@ -211,12 +202,12 @@ void SecureSocketImpl::connectNB(const SocketAddress& address) { try { - _state = ST_CONNECTING; + setState(ST_CONNECTING); _pSocket->connectNB(address); } catch (...) { - _state = ST_ERROR; + setState(ST_ERROR); } } @@ -257,32 +248,46 @@ void SecureSocketImpl::listen(int backlog) { _mode = MODE_SERVER; - if (_hCreds.dwLower == 0 && _hCreds.dwUpper == 0) + if (!SecIsValidHandle(&_hCreds)) { - initServerContext(); + initServerCredentials(); } _pSocket->listen(backlog); } -void SecureSocketImpl::shutdown() +int SecureSocketImpl::shutdown() { - if (_mode == MODE_SERVER) - serverDisconnect(&_hCreds, &_hContext); - else - clientDisconnect(&_hCreds, &_hContext); - - _pSocket->shutdown(); + int rc = 0; + if ((_shutdownFlags & TLS_SHUTDOWN_SENT) == 0) + { + if (_mode == MODE_SERVER) + { + rc = serverShutdown(&_hCreds, &_hContext); + } + else + { + rc = clientShutdown(&_hCreds, &_hContext); + } + if (rc >= 0) + { + _pSocket->shutdownSend(); + _shutdownFlags |= TLS_SHUTDOWN_SENT; + } + } + return (_shutdownFlags & TLS_SHUTDOWN_RECEIVED) ? 1 : 0; } void SecureSocketImpl::close() { - if (_mode == MODE_SERVER) - serverDisconnect(&_hCreds, &_hContext); - else - clientDisconnect(&_hCreds, &_hContext); - + try + { + shutdown(); + } + catch (Poco::Exception&) + { + } _pSocket->close(); cleanup(); } @@ -303,8 +308,7 @@ int SecureSocketImpl::available() const void SecureSocketImpl::acceptSSL() { - _state = ST_DONE; - initServerContext(); + setState(ST_ACCEPTING); _needHandshake = true; } @@ -327,11 +331,14 @@ void SecureSocketImpl::verifyPeerCertificate(const std::string& hostName) return; } - if (_mode == MODE_SERVER) + { serverVerifyCertificate(); + } else + { clientVerifyCertificate(hostName); + } } @@ -358,12 +365,28 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) { if (_needHandshake) { - completeHandshake(); + doHandshake(); _needHandshake = false; } if (_state == ST_ERROR) return 0; + if (_state != ST_DONE) + { + if (_pSocket->getBlocking()) + { + doHandshake(); + } + else + { + while (_state != ST_DONE && stateMachine()) + { + // no-op + } + if (_state != ST_DONE) return -1; + } + } + if (_sendBufferPending > 0) { int sent = sendRawBytes(_sendBuffer.begin() + _sendBufferOffset, _sendBufferPending, flags); @@ -375,23 +398,6 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) return _sendBufferPending == 0 ? length : -1; } - if (_state != ST_DONE) - { - bool establish = _pSocket->getBlocking(); - if (establish) - { - while (_state != ST_DONE) - { - stateMachine(); - } - } - else - { - stateMachine(); - return -1; - } - } - int dataToSend = length; int dataSent = 0; const char* pBuffer = reinterpret_cast(buffer); @@ -457,25 +463,24 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) { if (_needHandshake) { - completeHandshake(); + doHandshake(); _needHandshake = false; } if (_state == ST_ERROR) return 0; if (_state != ST_DONE) { - bool establish = _pSocket->getBlocking(); - if (establish) + if (_pSocket->getBlocking()) { - while (_state != ST_DONE) - { - stateMachine(); - } + doHandshake(); } else { - stateMachine(); - return -1; + while (_state != ST_DONE && stateMachine()) + { + // no-op + } + if (_state != ST_DONE) return -1; } } @@ -571,24 +576,23 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) if (securityStatus == SEC_I_CONTEXT_EXPIRED) { - SetLastError(securityStatus); + _shutdownFlags |= TLS_SHUTDOWN_RECEIVED; break; } - if (securityStatus != SEC_E_OK && securityStatus != SEC_I_RENEGOTIATE && securityStatus != SEC_I_CONTEXT_EXPIRED) + if (securityStatus != SEC_E_OK && securityStatus != SEC_I_RENEGOTIATE) { - SetLastError(securityStatus); break; } if (securityStatus == SEC_I_RENEGOTIATE) { _needData = false; - _state = ST_CLIENTHANDSHAKECONDREAD; + setState(ST_CLIENT_HSK_LOOP_INIT); if (!_pSocket->getBlocking()) - return -1; + return bytesDecoded > 0 ? bytesDecoded : -1; - securityStatus = performClientHandshakeLoop(); + securityStatus = doHandshake(); if (securityStatus != SEC_E_OK) break; @@ -620,6 +624,8 @@ SECURITY_STATUS SecureSocketImpl::decodeMessage(BYTE* pBuffer, DWORD bufSize, Au pExtraBuffer = 0; SECURITY_STATUS securityStatus = _securityFunctions.DecryptMessage(&_hContext, &msg, 0, 0); + // TODO: when decrypting the close_notify alert, returns SEC_E_DECRYPT_FAILURE + // instead of SEC_I_CONTEXT_EXPIRED if (securityStatus == SEC_E_OK || securityStatus == SEC_I_RENEGOTIATE) { @@ -719,9 +725,7 @@ SECURITY_STATUS SecureSocketImpl::decodeBufferFull(BYTE* pBuffer, DWORD bufSize, if (securityStatus == SEC_I_RENEGOTIATE) { _needData = false; - securityStatus = performClientHandshakeLoop(); - if (securityStatus != SEC_E_OK) - break; + return securityStatus; } } while (securityStatus == SEC_E_OK && pBuffer); @@ -776,10 +780,10 @@ void SecureSocketImpl::connectSSL(bool completeHandshake) _peerHostName = _pSocket->peerAddress().host().toString(); } - initClientContext(); + setState(ST_CONNECTING); if (completeHandshake) { - performClientHandshake(); + doHandshake(); _needHandshake = false; } else @@ -789,67 +793,38 @@ void SecureSocketImpl::connectSSL(bool completeHandshake) } -void SecureSocketImpl::completeHandshake() +void SecureSocketImpl::stateIllegal() { - if (_mode == MODE_SERVER) - performServerHandshake(); - else - performClientHandshake(); + throw Poco::IllegalStateException("SSL state machine"); } -void SecureSocketImpl::clientConnectVerify() +void SecureSocketImpl::stateError() { - poco_assert_dbg(!_pPeerCertificate); - poco_assert_dbg(!_peerHostName.empty()); + poco_assert_dbg (FAILED(_securityStatus)); - try - { - SECURITY_STATUS securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (PVOID) &_pPeerCertificate); - if (securityStatus != SEC_E_OK) - throw SSLException("Failed to obtain peer certificate", Utility::formatError(securityStatus)); - - clientVerifyCertificate(_peerHostName); - - securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_STREAM_SIZES, &_streamSizes); - if (securityStatus != SEC_E_OK) - throw SSLException("Failed to query stream sizes", Utility::formatError(securityStatus)); - - _ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer; - _state = ST_DONE; - } - catch (...) - { - if (_pPeerCertificate) - { - CertFreeCertificateContext(_pPeerCertificate); - _pPeerCertificate = 0; - } - throw; - } + cleanup(); + setState(ST_ERROR); + throw SSLException("Error during handshake", Utility::formatError(_securityStatus)); } -void SecureSocketImpl::initClientContext() +void SecureSocketImpl::stateClientConnected() { - _pOwnCertificate = loadCertificate(false); - _hCreds = _pContext->credentials(); + _peerHostName = _pSocket->peerAddress().host().toString(); + setState(ST_CLIENT_HSK_START); } -void SecureSocketImpl::performClientHandshake() +void SecureSocketImpl::stateClientHandshakeStart() { - performInitialClientHandshake(); - performClientHandshakeLoop(); - clientConnectVerify(); -} + initClientCredentials(); - -void SecureSocketImpl::performInitialClientHandshake() -{ // get initial security token _outSecBuffer.reset(true); _outSecBuffer.setSecBufferToken(0, 0, 0); + _recvBuffer.setCapacity(IO_BUFFER_SIZE); + _recvBufferOffset = 0; TimeStamp ts; DWORD contextAttributes(0); @@ -869,194 +844,151 @@ void SecureSocketImpl::performInitialClientHandshake() &contextAttributes, &ts); - if (_securityStatus != SEC_E_OK) + if (_securityStatus != SEC_E_OK && _securityStatus != SEC_I_CONTINUE_NEEDED) { if (_securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) { // the server is asking for client credentials, we didn't send one because we were not configured to do so, abort throw SSLException("Handshake failed: No client credentials configured"); } - else if (_securityStatus != SEC_I_CONTINUE_NEEDED) + else { throw SSLException("Handshake failed", Utility::formatError(_securityStatus)); } } - // incomplete credentials: more calls to InitializeSecurityContext needed - // send the token - sendInitialTokenOutBuffer(); + _extraSecBuffer.pvBuffer = 0; + _extraSecBuffer.cbBuffer = 0; + _needData = true; + + setState(ST_CLIENT_HSK_SEND_TOKEN); +} + + +void SecureSocketImpl::stateClientHandshakeSendToken() +{ + poco_assert_dbg (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer); + + int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); + if (sent < 0) + { + return; + } + else if (sent < _outSecBuffer[0].cbBuffer) + { + _outSecBuffer[0].cbBuffer -= sent; + std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); + return; + } if (_securityStatus == SEC_E_OK) { // The security context was successfully initialized. // There is no need for another InitializeSecurityContext (Schannel) call. - _state = ST_DONE; - return; + setState(ST_DONE); } - - //SEC_I_CONTINUE_NEEDED was returned: - // Wait for a return token. The returned token is then passed in - // another call to InitializeSecurityContext (Schannel). The output token can be empty. - - _extraSecBuffer.pvBuffer = 0; - _extraSecBuffer.cbBuffer = 0; - _needData = true; - _state = ST_CLIENTHANDSHAKECONDREAD; - _securityStatus = SEC_E_INCOMPLETE_MESSAGE; -} - - -void SecureSocketImpl::sendInitialTokenOutBuffer() -{ - // send the token - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + else { - int numBytes = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); - if (numBytes != _outSecBuffer[0].cbBuffer) - throw SSLException("Failed to send token to the server"); + setState(ST_CLIENT_HSK_LOOP_INIT); } } -SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop() -{ - _recvBufferOffset = 0; - _securityStatus = SEC_E_INCOMPLETE_MESSAGE; - - while (_securityStatus == SEC_I_CONTINUE_NEEDED || _securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) - { - performClientHandshakeLoopCondReceive(); - - if (_securityStatus == SEC_E_OK) - { - performClientHandshakeLoopOK(); - } - else if (_securityStatus == SEC_I_CONTINUE_NEEDED) - { - performClientHandshakeLoopContinueNeeded(); - } - else if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE) - { - performClientHandshakeLoopIncompleteMessage(); - } - else if (FAILED(_securityStatus)) - { - if (_outFlags & ISC_RET_EXTENDED_ERROR) - { - performClientHandshakeLoopExtError(); - } - else - { - performClientHandshakeLoopError(); - } - } - else - { - performClientHandshakeLoopIncompleteMessage(); - } - } - - if (FAILED(_securityStatus)) - { - performClientHandshakeLoopError(); - } - - return _securityStatus; -} - - -void SecureSocketImpl::performClientHandshakeLoopExtError() +void SecureSocketImpl::stateClientHandshakeSendError() { poco_assert_dbg (FAILED(_securityStatus)); - performClientHandshakeSendOutBuffer(); - performClientHandshakeLoopError(); + sendOutSecBufferAndAdvanceState(ST_ERROR); } -void SecureSocketImpl::performClientHandshakeLoopError() -{ - poco_assert_dbg (FAILED(_securityStatus)); - cleanup(); - _state = ST_ERROR; - throw SSLException("Error during handshake", Utility::formatError(_securityStatus)); -} - - -void SecureSocketImpl::performClientHandshakeSendOutBuffer() -{ - if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) - { - int numBytes = sendRawBytes(static_cast(_outSecBuffer[0].pvBuffer), _outSecBuffer[0].cbBuffer); - if (numBytes != _outSecBuffer[0].cbBuffer) - throw SSLException("Socket error during handshake"); - - _outSecBuffer.release(0); - } -} - - -void SecureSocketImpl::performClientHandshakeExtraBuffer() -{ - if (_inSecBuffer[1].BufferType == SECBUFFER_EXTRA) - { - std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - _inSecBuffer[1].cbBuffer), _inSecBuffer[1].cbBuffer); - _recvBufferOffset = _inSecBuffer[1].cbBuffer; - } - else _recvBufferOffset = 0; -} - - -void SecureSocketImpl::performClientHandshakeLoopOK() -{ - poco_assert_dbg(_securityStatus == SEC_E_OK); - - performClientHandshakeSendOutBuffer(); - performClientHandshakeExtraBuffer(); - _state = ST_VERIFY; -} - - -void SecureSocketImpl::performClientHandshakeLoopInit() +void SecureSocketImpl::stateClientHandshakeLoopInit() { _inSecBuffer.reset(false); _outSecBuffer.reset(true); + + if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_CONTINUE_NEEDED) + { + setState(ST_CLIENT_HSK_LOOP_RECV); + } + else + { + setState(ST_CLIENT_HSK_LOOP_PROCESS); + } } -void SecureSocketImpl::performClientHandshakeLoopReceive() +int SecureSocketImpl::getRecordLength(const BYTE* pRec, int len) +{ + // TLS Record Header: + // struct { + // ContentType type; /* 1 byte */ + // ProtocolVersion version; /* 2 bytes */ + // uint16 length; /* must be <= 2^14 */ + // opaque fragment[TLSPlaintext.length]; + // } TLSPlaintext; + + if (len >= 5) + { + Poco::UInt16 recLen = (pRec[3] << 8) + pRec[4]; + if (recLen <= 16384) + { + return recLen + 5; + } + else + { + return -1; + } + } + return 0; +} + + +bool SecureSocketImpl::bufferHasCompleteRecords(const BYTE* pBuffer, int length) +{ + const BYTE* pEnd = pBuffer + length; + int recLen = getRecordLength(pBuffer, length); + while (recLen > 0 && length >= recLen) + { + pBuffer += recLen; + length -= recLen; + recLen = getRecordLength(pBuffer, length); + } + return pBuffer == pEnd; +} + + +void SecureSocketImpl::stateClientHandshakeLoopRecv() { - poco_assert_dbg (_needData); poco_assert (IO_BUFFER_SIZE > _recvBufferOffset); int n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); - if (n <= 0) throw SSLException("Error during handshake: failed to read data"); - - _recvBufferOffset += n; + if (n == 0) throw SSLException("Connection unexpectedly closed during handshake"); + if (n > 0) + { + _recvBufferOffset += n; + // NOTE: InitializeSecurityContext() will fail with SEC_E_INVALID_TOKEN + // if provided with a buffer containing incomplete records, rather than returning + // SEC_E_INCOMPLETE_MESSAGE. Therefore we have to make sure that we + // receive complete records before passing the buffer to InitializeSecurityContext(). + if (bufferHasCompleteRecords(_recvBuffer.begin(), _recvBufferOffset)) + { + setState(ST_CLIENT_HSK_LOOP_PROCESS); + } + + } } -void SecureSocketImpl::performClientHandshakeLoopCondReceive() +void SecureSocketImpl::stateClientHandshakeLoopProcess() { - poco_assert_dbg (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || SEC_I_CONTINUE_NEEDED); - - performClientHandshakeLoopInit(); - if (_needData) - { - if (_recvBuffer.capacity() != IO_BUFFER_SIZE) - _recvBuffer.setCapacity(IO_BUFFER_SIZE); - performClientHandshakeLoopReceive(); - } - else _needData = true; - _inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); - // inbuffer 1 should be empty _inSecBuffer.setSecBufferEmpty(1); - - // outBuffer[0] should be empty _outSecBuffer.setSecBufferToken(0, 0, 0); - _outFlags = 0; + PRINT_STATE("stateClientHandshakeLoopProcess before InitializeSecurityContextW: "); + + DWORD outFlags; TimeStamp ts; _securityStatus = _securityFunctions.InitializeSecurityContextW( &_hCreds, @@ -1069,57 +1001,259 @@ void SecureSocketImpl::performClientHandshakeLoopCondReceive() 0, 0, &_outSecBuffer, - &_outFlags, + &outFlags, &ts); + PRINT_STATE("stateClientHandshakeLoopProcess after InitializeSecurityContextW: "); + if (_securityStatus == SEC_E_OK) { - _state = ST_CLIENTHANDSHAKEOK; + drainExtraBuffer(); + setState(ST_CLIENT_HSK_LOOP_DONE); } else if (_securityStatus == SEC_I_CONTINUE_NEEDED) { - _state = ST_CLIENTHANDSHAKECONTINUE; + drainExtraBuffer(); + setState(ST_CLIENT_HSK_LOOP_SEND); + } + else if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE) + { + _needData = true; + setState(ST_CLIENT_HSK_LOOP_INIT); } else if (FAILED(_securityStatus)) { - if (_outFlags & ISC_RET_EXTENDED_ERROR) - _state = ST_CLIENTHANDSHAKEEXTERROR; + if (outFlags & ISC_RET_EXTENDED_ERROR) + { + setState(ST_CLIENT_HSK_SEND_ERROR); + } else - _state = ST_ERROR; + { + setState(ST_ERROR); + } } else { - _state = ST_CLIENTHANDSHAKEINCOMPLETE; + setState(ST_CLIENT_HSK_LOOP_INIT); } } -void SecureSocketImpl::performClientHandshakeLoopContinueNeeded() +void SecureSocketImpl::stateClientHandshakeLoopSend() { - performClientHandshakeSendOutBuffer(); - performClientHandshakeExtraBuffer(); - _state = ST_CLIENTHANDSHAKECONDREAD; + sendOutSecBufferAndAdvanceState(ST_CLIENT_HSK_LOOP_INIT); } -void SecureSocketImpl::performClientHandshakeLoopIncompleteMessage() +void SecureSocketImpl::stateClientHandshakeLoopDone() { - _needData = true; - _state = ST_CLIENTHANDSHAKECONDREAD; + poco_assert_dbg(_securityStatus == SEC_E_OK); + + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + setState(ST_CLIENT_HSK_SEND_FINAL); + } + else + { + setState(ST_CLIENT_VERIFY); + } } -void SecureSocketImpl::initServerContext() +void SecureSocketImpl::stateClientHandshakeSendFinal() { - _pOwnCertificate = loadCertificate(true); - _hCreds = _pContext->credentials(); + sendOutSecBufferAndAdvanceState(ST_CLIENT_VERIFY); } -void SecureSocketImpl::performServerHandshake() +void SecureSocketImpl::sendOutSecBufferAndAdvanceState(State state) { - serverHandshakeLoop(&_hContext, &_hCreds, _clientAuthRequired, true, true); + if (_outSecBuffer[0].cbBuffer && _outSecBuffer[0].pvBuffer) + { + int sent = sendRawBytes(_outSecBuffer[0].pvBuffer, _outSecBuffer[0].cbBuffer); + if (sent < 0) return; + if (sent < _outSecBuffer[0].cbBuffer) + { + _outSecBuffer[0].cbBuffer -= sent; + std::memmove(_outSecBuffer[0].pvBuffer, reinterpret_cast(_outSecBuffer[0].pvBuffer) + sent, _outSecBuffer[0].cbBuffer); + } + else + { + _outSecBuffer.release(0); + setState(state); + } + } + else + { + setState(state); + } +} + +void SecureSocketImpl::drainExtraBuffer() +{ + if (_inSecBuffer[1].BufferType == SECBUFFER_EXTRA && _inSecBuffer[1].pvBuffer && _inSecBuffer[1].cbBuffer > 0) + { + std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - _inSecBuffer[1].cbBuffer), _inSecBuffer[1].cbBuffer); + _recvBufferOffset = _inSecBuffer[1].cbBuffer; + } + else _recvBufferOffset = 0; +} + + +void SecureSocketImpl::stateClientVerify() +{ + poco_assert_dbg(!_pPeerCertificate); + poco_assert_dbg(!_peerHostName.empty()); + + try + { + SECURITY_STATUS securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (PVOID) &_pPeerCertificate); + if (securityStatus != SEC_E_OK) + throw SSLException("Failed to obtain peer certificate", Utility::formatError(securityStatus)); + + clientVerifyCertificate(_peerHostName); + + securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_STREAM_SIZES, &_streamSizes); + if (securityStatus != SEC_E_OK) + throw SSLException("Failed to query stream sizes", Utility::formatError(securityStatus)); + + _ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer; + setState(ST_DONE); + } + catch (...) + { + if (_pPeerCertificate) + { + CertFreeCertificateContext(_pPeerCertificate); + _pPeerCertificate = 0; + } + throw; + } +} + + +void SecureSocketImpl::stateServerAccepted() +{ + initServerCredentials(); + setState(ST_SERVER_HSK_START); +} + + +void SecureSocketImpl::stateServerHandshakeStart() +{ + _securityStatus = SEC_E_INCOMPLETE_MESSAGE; + _initServerContext = true; + _recvBuffer.setCapacity(IO_BUFFER_SIZE); + _recvBufferOffset = 0; + setState(ST_SERVER_HSK_LOOP_INIT); +} + + +void SecureSocketImpl::stateServerHandshakeLoopInit() +{ + if (_securityStatus == SEC_I_CONTINUE_NEEDED || _securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) + { + if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE || _securityStatus == SEC_I_CONTINUE_NEEDED) + { + setState(ST_SERVER_HSK_LOOP_RECV); + } + else + { + setState(ST_SERVER_HSK_LOOP_PROCESS); + } + } + else if (_securityStatus == SEC_E_CERT_UNKNOWN || _securityStatus == SEC_E_CERT_EXPIRED) + { + setState(ST_SERVER_HSK_LOOP_PROCESS); + } + else if (_securityStatus == SEC_E_OK) + { + setState(ST_SERVER_HSK_LOOP_DONE); + } +} + + +void SecureSocketImpl::stateServerHandshakeLoopRecv() +{ + poco_assert_dbg (_recvBuffer.capacity() >= IO_BUFFER_SIZE && IO_BUFFER_SIZE > _recvBufferOffset); + + int n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); + if (n == 0) throw SSLException("Connection unexpectedly closed during handshake"); + if (n > 0) + { + _recvBufferOffset += n; + setState(ST_SERVER_HSK_LOOP_PROCESS); + } +} + + +void SecureSocketImpl::stateServerHandshakeLoopProcess() +{ + PRINT_STATE("stateServerHandshakeLoopProcess before AcceptSecurityContext: "); + _inSecBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); + _inSecBuffer.setSecBufferEmpty(1); + _outSecBuffer.setSecBufferToken(0, 0, 0); + + TimeStamp tsExpiry; + DWORD outFlags; + _securityStatus = _securityFunctions.AcceptSecurityContext( + &_hCreds, + _initServerContext ? NULL : &_hContext, + &_inSecBuffer, + _contextFlags, + 0, + _initServerContext ? &_hContext : NULL, + &_outSecBuffer, + &outFlags, + &tsExpiry); + + _initServerContext = !SecIsValidHandle(&_hContext); + + PRINT_STATE("stateServerHandshakeLoopProcess after AcceptSecurityContext: "); + + setState(ST_SERVER_HSK_LOOP_INIT); + if (_securityStatus == SEC_E_INCOMPLETE_MESSAGE) + { + return; + } + else if (_securityStatus == SEC_E_OK || _securityStatus == SEC_I_CONTINUE_NEEDED || (FAILED(_securityStatus) && (0 != (outFlags & ISC_RET_EXTENDED_ERROR)))) + { + drainExtraBuffer(); + if (_outSecBuffer[0].cbBuffer != 0 && _outSecBuffer[0].pvBuffer != 0) + { + setState(ST_SERVER_HSK_LOOP_SEND); + return; + } + } + else if (_securityStatus == SEC_E_CERT_UNKNOWN || _securityStatus == SEC_E_CERT_EXPIRED || _securityStatus == CRYPT_E_REVOKED || _securityStatus == SEC_E_UNTRUSTED_ROOT) + { + throw SSLException("Handshake failure, client has rejected certificate", Utility::formatError(_securityStatus)); + } + else if (FAILED(_securityStatus)) + { + throw SSLException("Handshake failure:", Utility::formatError(_securityStatus)); + } +} + + +void SecureSocketImpl::stateServerHandshakeLoopSend() +{ + sendOutSecBufferAndAdvanceState(ST_SERVER_HSK_LOOP_INIT); +} + + +void SecureSocketImpl::stateServerHandshakeLoopDone() +{ + _securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext, SECPKG_ATTR_STREAM_SIZES, &_streamSizes); + if (_securityStatus != SEC_E_OK) throw SSLException("Cannot query stream sizes", Utility::formatError(_securityStatus)); + _ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer; + setState(ST_SERVER_VERIFY); +} + + +void SecureSocketImpl::stateServerHandshakeVerify() +{ SECURITY_STATUS securityStatus; if (_clientAuthRequired) { @@ -1140,100 +1274,49 @@ void SecureSocketImpl::performServerHandshake() serverVerifyCertificate(); } } - - securityStatus = _securityFunctions.QueryContextAttributesW(&_hContext,SECPKG_ATTR_STREAM_SIZES, &_streamSizes); - if (securityStatus != SEC_E_OK) throw SSLException("Cannot query stream sizes", Utility::formatError(securityStatus)); - - _ioBufferSize = _streamSizes.cbHeader + _streamSizes.cbMaximumMessage + _streamSizes.cbTrailer; + setState(ST_DONE); } -bool SecureSocketImpl::serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext) +void SecureSocketImpl::initClientCredentials() { - TimeStamp tsExpiry; - int n = 0; - bool doRead = doInitialRead; - bool initContext = newContext; - DWORD outFlags; - SECURITY_STATUS securityStatus = SEC_E_INCOMPLETE_MESSAGE; + _pOwnCertificate = loadCertificate(false); + _hCreds = _pContext->credentials(); +} - while (securityStatus == SEC_I_CONTINUE_NEEDED || securityStatus == SEC_E_INCOMPLETE_MESSAGE || securityStatus == SEC_I_INCOMPLETE_CREDENTIALS) + +void SecureSocketImpl::initServerCredentials() +{ + _pOwnCertificate = loadCertificate(true); + _hCreds = _pContext->credentials(); +} + + +SECURITY_STATUS SecureSocketImpl::doHandshake() +{ + while (_state != ST_DONE && _state != ST_ERROR) { - if (securityStatus == SEC_E_INCOMPLETE_MESSAGE) - { - if (doRead) - { - n = receiveRawBytes(_recvBuffer.begin() + _recvBufferOffset, IO_BUFFER_SIZE - _recvBufferOffset); - - if (n <= 0) - throw SSLException("Failed to receive data in handshake"); - else - _recvBufferOffset += n; - } - else doRead = true; - } - - AutoSecBufferDesc<2> inBuffer(&_securityFunctions, false); - AutoSecBufferDesc<1> outBuffer(&_securityFunctions, true); - inBuffer.setSecBufferToken(0, _recvBuffer.begin(), _recvBufferOffset); - inBuffer.setSecBufferEmpty(1); - outBuffer.setSecBufferToken(0, 0, 0); - - securityStatus = _securityFunctions.AcceptSecurityContext( - phCred, - initContext ? NULL : phContext, - &inBuffer, - _contextFlags, - 0, - initContext ? phContext : NULL, - &outBuffer, - &outFlags, - &tsExpiry); - - initContext = false; - - if (securityStatus == SEC_E_OK || securityStatus == SEC_I_CONTINUE_NEEDED || (FAILED(securityStatus) && (0 != (outFlags & ISC_RET_EXTENDED_ERROR)))) - { - if (outBuffer[0].cbBuffer != 0 && outBuffer[0].pvBuffer != 0) - { - n = sendRawBytes(outBuffer[0].pvBuffer, outBuffer[0].cbBuffer); - outBuffer.release(0); - } - } - - if (securityStatus == SEC_E_OK ) - { - if (inBuffer[1].BufferType == SECBUFFER_EXTRA) - { - std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - inBuffer[1].cbBuffer), inBuffer[1].cbBuffer); - _recvBufferOffset = inBuffer[1].cbBuffer; - } - else - { - _recvBufferOffset = 0; - } - return true; - } - else if (FAILED(securityStatus) && securityStatus != SEC_E_INCOMPLETE_MESSAGE) - { - throw SSLException("Handshake failure:", Utility::formatError(securityStatus)); - } - - if (securityStatus != SEC_E_INCOMPLETE_MESSAGE && securityStatus != SEC_I_INCOMPLETE_CREDENTIALS) - { - if (inBuffer[1].BufferType == SECBUFFER_EXTRA) - { - std::memmove(_recvBuffer.begin(), _recvBuffer.begin() + (_recvBufferOffset - inBuffer[1].cbBuffer), inBuffer[1].cbBuffer); - _recvBufferOffset = inBuffer[1].cbBuffer; - } - else - { - _recvBufferOffset = 0; - } - } + stateMachine(); } + return _securityStatus; +} - return false; + +int SecureSocketImpl::completeHandshake() +{ + if (_pSocket->getBlocking()) + { + doHandshake(); + return 0; + } + else + { + while (_state != ST_DONE && stateMachine()) + { + // no-op + } + return (_state == ST_DONE) ? 0 : -1; + } } @@ -1487,11 +1570,11 @@ void SecureSocketImpl::serverVerifyCertificate() } -LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phContext) +int SecureSocketImpl::clientShutdown(PCredHandle phCreds, CtxtHandle* phContext) { - if (phContext->dwLower == 0 && phContext->dwUpper == 0) + if (!SecIsValidHandle(phContext)) { - return SEC_E_OK; + return 0; } AutoSecBufferDesc<1> tokBuffer(&_securityFunctions, false); @@ -1500,7 +1583,7 @@ LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phConte tokBuffer.setSecBufferToken(0, &tokenType, sizeof(tokenType)); DWORD status = _securityFunctions.ApplyControlToken(phContext, &tokBuffer); - if (FAILED(status)) return status; + if (FAILED(status)) throw SSLException(Utility::formatError(status)); DWORD sspiFlags = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT @@ -1528,13 +1611,21 @@ LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phConte &sspiOutFlags, &expiry); - return status; + if (FAILED(status)) throw SSLException(Utility::formatError(status)); + + if (outBuffer[0].pvBuffer && outBuffer[0].cbBuffer) + { + int sent = sendRawBytes(outBuffer[0].pvBuffer, outBuffer[0].cbBuffer); + return sent; + } + + return 0; } -LONG SecureSocketImpl::serverDisconnect(PCredHandle phCreds, CtxtHandle* phContext) +int SecureSocketImpl::serverShutdown(PCredHandle phCreds, CtxtHandle* phContext) { - if (phContext->dwLower == 0 && phContext->dwUpper == 0) + if (!SecIsValidHandle(phContext)) { // handshake has never been done poco_assert_dbg (_needHandshake); @@ -1587,24 +1678,59 @@ LONG SecureSocketImpl::serverDisconnect(PCredHandle phCreds, CtxtHandle* phConte } -void SecureSocketImpl::stateIllegal() +bool SecureSocketImpl::stateMachine() { - throw Poco::IllegalStateException("SSL state machine"); + return StateMachine::instance().execute(this); } -void SecureSocketImpl::stateConnected() +#ifdef ENABLE_PRINT_STATE + + +#define CASE_PRINT_ENUM(enum) case enum: std::cout << #enum; break; + + +void SecureSocketImpl::printState(const std::string& message) { - _peerHostName = _pSocket->peerAddress().host().toString(); - initClientContext(); - performInitialClientHandshake(); + Poco::Thread* pThread = Poco::Thread::current(); + int threadId = 0; + if (pThread) threadId = pThread->id(); + + static Poco::FastMutex mutex; + Poco::FastMutex::ScopedLock lock(mutex); + + std::cout << message << " thread=" << threadId << " state="; + switch (_state) + { + CASE_PRINT_ENUM(ST_INITIAL) + CASE_PRINT_ENUM(ST_CONNECTING) + CASE_PRINT_ENUM(ST_CLIENT_HSK_START) + CASE_PRINT_ENUM(ST_CLIENT_HSK_SEND_TOKEN) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_INIT) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_RECV) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_PROCESS) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_SEND) + CASE_PRINT_ENUM(ST_CLIENT_HSK_LOOP_DONE) + CASE_PRINT_ENUM(ST_CLIENT_HSK_SEND_FINAL) + CASE_PRINT_ENUM(ST_CLIENT_HSK_SEND_ERROR) + CASE_PRINT_ENUM(ST_CLIENT_VERIFY) + CASE_PRINT_ENUM(ST_ACCEPTING) + CASE_PRINT_ENUM(ST_SERVER_HSK_START) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_INIT) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_RECV) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_PROCESS) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_SEND) + CASE_PRINT_ENUM(ST_SERVER_HSK_LOOP_DONE) + CASE_PRINT_ENUM(ST_SERVER_VERIFY) + CASE_PRINT_ENUM(ST_DONE) + CASE_PRINT_ENUM(ST_ERROR) + } + + std::cout << " securityStatus=" << Utility::formatError(_securityStatus) << " hCreds="<< SecIsValidHandle(&_hCreds) << " hContext=" << SecIsValidHandle(&_hContext) << " recvBufOff=" << _recvBufferOffset << std::endl; } -void SecureSocketImpl::stateMachine() -{ - StateMachine::instance().execute(this); -} +#endif StateMachine& StateMachine::instance() @@ -1614,108 +1740,46 @@ StateMachine& StateMachine::instance() } -bool StateMachine::readable(SOCKET sockfd) -{ - fd_set fdRead; - FD_ZERO(&fdRead); - FD_SET(sockfd, &fdRead); - select(&fdRead, 0, sockfd); - return (FD_ISSET(sockfd, &fdRead) != 0); -} - - -bool StateMachine::writable(SOCKET sockfd) -{ - fd_set fdWrite; - FD_ZERO(&fdWrite); - FD_SET(sockfd, &fdWrite); - select(0, &fdWrite, sockfd); - return (FD_ISSET(sockfd, &fdWrite) != 0); -} - - -bool StateMachine::readOrWritable(SOCKET sockfd) -{ - fd_set fdRead, fdWrite; - FD_ZERO(&fdRead); - FD_SET(sockfd, &fdRead); - fdWrite = fdRead; - select(&fdRead, &fdWrite, sockfd); - return (FD_ISSET(sockfd, &fdRead) != 0 || FD_ISSET(sockfd, &fdWrite) != 0); -} - - -bool StateMachine::none(SOCKET sockfd) -{ - return true; -} - - -void StateMachine::select(fd_set* fdRead, fd_set* fdWrite, SOCKET sockfd) -{ - Poco::Timespan remainingTime(((Poco::Timestamp::TimeDiff)SecureSocketImpl::TIMEOUT_MILLISECS)*1000); - int rc(0); - do - { - struct timeval tv; - tv.tv_sec = (long) remainingTime.totalSeconds(); - tv.tv_usec = (long) remainingTime.useconds(); - Poco::Timestamp start; - rc = ::select(int(sockfd) + 1, fdRead, fdWrite, 0, &tv); - if (rc < 0 && SecureSocketImpl::lastError() == POCO_EINTR) - { - Poco::Timestamp end; - Poco::Timespan waited = end - start; - if (waited < remainingTime) - remainingTime -= waited; - else - remainingTime = 0; - } - } - while (rc < 0 && SecureSocketImpl::lastError() == POCO_EINTR); -} - - StateMachine::StateMachine(): - _states() -{ - //ST_INITIAL: 0, -> this one is illegal, you must call connectNB before - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal)); - //ST_CONNECTING: connectNB was called, check if the socket is already available for writing - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::stateConnected)); - //ST_ESTABLISHTUNNELRECEIVED: we got the response, now start the handshake - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performInitialClientHandshake)); - //ST_CLIENTHANDSHAKECONDREAD: condread - _states.push_back(std::make_pair(&StateMachine::readable, &SecureSocketImpl::performClientHandshakeLoopCondReceive)); - //ST_CLIENTHANDSHAKEINCOMPLETE, - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopIncompleteMessage)); - //ST_CLIENTHANDSHAKEOK, - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopOK)); - //ST_CLIENTHANDSHAKEEXTERROR, - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopExtError)); - //ST_CLIENTHANDSHAKECONTINUE, - _states.push_back(std::make_pair(&StateMachine::writable, &SecureSocketImpl::performClientHandshakeLoopContinueNeeded)); - //ST_VERIFY, - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::clientConnectVerify)); - //ST_DONE, - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::stateIllegal)); - //ST_ERROR - _states.push_back(std::make_pair(&StateMachine::none, &SecureSocketImpl::performClientHandshakeLoopError)); + _states(SecureSocketImpl::ST_MAX) +{ + _states[SecureSocketImpl::ST_INITIAL] = &SecureSocketImpl::stateIllegal; + + _states[SecureSocketImpl::ST_CONNECTING] = &SecureSocketImpl::stateClientConnected; + _states[SecureSocketImpl::ST_CLIENT_HSK_START] = &SecureSocketImpl::stateClientHandshakeStart; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_TOKEN] = &SecureSocketImpl::stateClientHandshakeSendToken; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_INIT] = &SecureSocketImpl::stateClientHandshakeLoopInit; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_RECV] = &SecureSocketImpl::stateClientHandshakeLoopRecv; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_PROCESS] = &SecureSocketImpl::stateClientHandshakeLoopProcess; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_SEND] = &SecureSocketImpl::stateClientHandshakeLoopSend; + _states[SecureSocketImpl::ST_CLIENT_HSK_LOOP_DONE] = &SecureSocketImpl::stateClientHandshakeLoopDone; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_FINAL] = &SecureSocketImpl::stateClientHandshakeSendFinal; + _states[SecureSocketImpl::ST_CLIENT_HSK_SEND_ERROR] = &SecureSocketImpl::stateClientHandshakeSendError; + _states[SecureSocketImpl::ST_CLIENT_VERIFY] = &SecureSocketImpl::stateClientVerify; + + _states[SecureSocketImpl::ST_ACCEPTING] = &SecureSocketImpl::stateServerAccepted; + _states[SecureSocketImpl::ST_SERVER_HSK_START] = &SecureSocketImpl::stateServerHandshakeStart; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_INIT] = &SecureSocketImpl::stateServerHandshakeLoopInit; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_RECV] = &SecureSocketImpl::stateServerHandshakeLoopRecv; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_PROCESS] = &SecureSocketImpl::stateServerHandshakeLoopProcess; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_SEND] = &SecureSocketImpl::stateServerHandshakeLoopSend; + _states[SecureSocketImpl::ST_SERVER_HSK_LOOP_DONE] = &SecureSocketImpl::stateServerHandshakeLoopDone; + _states[SecureSocketImpl::ST_SERVER_VERIFY] = &SecureSocketImpl::stateServerHandshakeVerify; + + _states[SecureSocketImpl::ST_DONE] = &SecureSocketImpl::stateIllegal; + _states[SecureSocketImpl::ST_ERROR] = &SecureSocketImpl::stateError; } -void StateMachine::execute(SecureSocketImpl* pSock) +bool StateMachine::execute(SecureSocketImpl* pSock) { try { poco_assert_dbg (pSock); - ConditionState& state = _states[pSock->getState()]; - ConditionMethod& meth = state.first; - if ((this->*state.first)(pSock->sockfd())) - { - (pSock->*(state.second))(); - (pSock->getState() == SecureSocketImpl::ST_DONE); - } + auto curState = pSock->getState(); + StateImpl& stateImpl = _states[curState]; + (pSock->*(stateImpl))(); + return pSock->getState() != curState; } catch (...) { diff --git a/NetSSL_Win/src/SecureStreamSocketImpl.cpp b/NetSSL_Win/src/SecureStreamSocketImpl.cpp index e90a03059..876162a61 100644 --- a/NetSSL_Win/src/SecureStreamSocketImpl.cpp +++ b/NetSSL_Win/src/SecureStreamSocketImpl.cpp @@ -149,14 +149,15 @@ void SecureStreamSocketImpl::shutdownReceive() } -void SecureStreamSocketImpl::shutdownSend() +int SecureStreamSocketImpl::shutdownSend() { + return _impl.shutdown(); } -void SecureStreamSocketImpl::shutdown() +int SecureStreamSocketImpl::shutdown() { - _impl.shutdown(); + return _impl.shutdown(); } @@ -208,8 +209,7 @@ void SecureStreamSocketImpl::verifyPeerCertificate(const std::string& hostName) int SecureStreamSocketImpl::completeHandshake() { - _impl.completeHandshake(); - return 0; + return _impl.completeHandshake(); } diff --git a/NetSSL_Win/src/Utility.cpp b/NetSSL_Win/src/Utility.cpp index 5970bf8a2..d8ac4eb9c 100644 --- a/NetSSL_Win/src/Utility.cpp +++ b/NetSSL_Win/src/Utility.cpp @@ -24,9 +24,6 @@ namespace Poco { namespace Net { -Poco::FastMutex Utility::_mutex; - - Context::VerificationMode Utility::convertVerificationMode(const std::string& vMode) { std::string mode = Poco::toLower(vMode); @@ -54,6 +51,7 @@ inline void add(std::map& messageMap, long key, const s std::map Utility::initSSPIErr() { std::map messageMap; + add(messageMap, SEC_E_OK, "OK"); add(messageMap, NTE_BAD_UID, "Bad UID"); add(messageMap, NTE_BAD_HASH, "Bad Hash"); add(messageMap, NTE_BAD_KEY, "Bad Key"); @@ -185,18 +183,15 @@ std::map Utility::initSSPIErr() } -const std::string& Utility::formatError(long errCode) +std::string Utility::formatError(long errCode) { - Poco::FastMutex::ScopedLock lock(_mutex); - - static const std::string def("Internal SSPI error"); static const std::map errs(initSSPIErr()); const std::map::const_iterator it = errs.find(errCode); if (it != errs.end()) return it->second; else - return def; + return "0x" + Poco::NumberFormatter::formatHex(errCode, 8); }