diff --git a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h index 9fc60b4b3..b1502cf01 100644 --- a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h @@ -236,6 +236,12 @@ protected: ST_ERROR }; + enum TLSShutdown + { + TLS_SHUTDOWN_SENT = 1, + TLS_SHUTDOWN_RECEIVED = 2 + }; + int sendRawBytes(const void* buffer, int length, int flags = 0); int receiveRawBytes(void* buffer, int length, int flags = 0); void clientConnectVerify(); @@ -245,8 +251,8 @@ protected: void clientVerifyCertificate(const std::string& hostName); void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert); void serverVerifyCertificate(); - LONG serverDisconnect(PCredHandle phCreds, CtxtHandle* phContext); - LONG clientDisconnect(PCredHandle phCreds, CtxtHandle* phContext); + int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext); + int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext); bool loadSecurityLibrary(); void initClientContext(); void initServerContext(); @@ -286,6 +292,7 @@ private: Poco::AutoPtr _pSocket; Context::Ptr _pContext; Mode _mode; + int _shutdownFlags; std::string _peerHostName; bool _useMachineStore; bool _clientAuthRequired; diff --git a/NetSSL_Win/src/HTTPSStreamFactory.cpp b/NetSSL_Win/src/HTTPSStreamFactory.cpp index ec1d43f9e..b2cb86821 100644 --- a/NetSSL_Win/src/HTTPSStreamFactory.cpp +++ b/NetSSL_Win/src/HTTPSStreamFactory.cpp @@ -138,7 +138,7 @@ std::istream* HTTPSStreamFactory::open(const URI& uri) { return new HTTPResponseStream(rs, pSession); } - else if (res.getStatus() == HTTPResponse::HTTP_USEPROXY && !retry) + else if (res.getStatus() == HTTPResponse::HTTP_USE_PROXY && !retry) { // The requested resource MUST be accessed through the proxy // given by the Location field. The Location field gives the diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index c9ea76b83..4e9a3b98a 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -62,6 +62,7 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr pSocketImpl, Contex _pSocket(pSocketImpl), _pContext(pContext), _mode(pContext->isForServerUse() ? MODE_SERVER : MODE_CLIENT), + _shutdownFlags(0), _clientAuthRequired(pContext->verificationMode() >= Context::VERIFY_STRICT), _securityFunctions(SSLManager::instance().securityFunctions()), _pOwnCertificate(0), @@ -267,13 +268,24 @@ void SecureSocketImpl::listen(int backlog) int SecureSocketImpl::shutdown() { - if (_mode == MODE_SERVER) - serverDisconnect(&_hCreds, &_hContext); - else - clientDisconnect(&_hCreds, &_hContext); - - _pSocket->shutdown(); - return 0; + int rc = 0; + if ((_shutdownFlags & TLS_SHUTDOWN_SENT) == 0) + { + if (_mode == MODE_SERVER) + { + rc = serverShutdown(&_hCreds, &_hContext); + } + else + { + rc = clientShutdown(&_hCreds, &_hContext); + } + if (rc >= 0) + { + _pSocket->shutdownSend(); + _shutdownFlags |= TLS_SHUTDOWN_SENT; + } + } + return (_shutdownFlags & TLS_SHUTDOWN_RECEIVED) ? 1 : 0; } @@ -281,10 +293,7 @@ void SecureSocketImpl::close() { try { - if (_mode == MODE_SERVER) - serverDisconnect(&_hCreds, &_hContext); - else - clientDisconnect(&_hCreds, &_hContext); + shutdown(); } catch (Poco::Exception&) { @@ -577,13 +586,12 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) if (securityStatus == SEC_I_CONTEXT_EXPIRED) { - SetLastError(securityStatus); + _shutdownFlags |= TLS_SHUTDOWN_RECEIVED; break; } - if (securityStatus != SEC_E_OK && securityStatus != SEC_I_RENEGOTIATE && securityStatus != SEC_I_CONTEXT_EXPIRED) + if (securityStatus != SEC_E_OK && securityStatus != SEC_I_RENEGOTIATE) { - SetLastError(securityStatus); break; } @@ -626,6 +634,7 @@ SECURITY_STATUS SecureSocketImpl::decodeMessage(BYTE* pBuffer, DWORD bufSize, Au pExtraBuffer = 0; SECURITY_STATUS securityStatus = _securityFunctions.DecryptMessage(&_hContext, &msg, 0, 0); + // TODO: when decrypting the close_notify alert, returns SEC_E_DECRYPT_FAILURE if (securityStatus == SEC_E_OK || securityStatus == SEC_I_RENEGOTIATE) { @@ -1493,11 +1502,11 @@ void SecureSocketImpl::serverVerifyCertificate() } -LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phContext) +int SecureSocketImpl::clientShutdown(PCredHandle phCreds, CtxtHandle* phContext) { if (phContext->dwLower == 0 && phContext->dwUpper == 0) { - return SEC_E_OK; + return 0; } AutoSecBufferDesc<1> tokBuffer(&_securityFunctions, false); @@ -1506,7 +1515,7 @@ LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phConte tokBuffer.setSecBufferToken(0, &tokenType, sizeof(tokenType)); DWORD status = _securityFunctions.ApplyControlToken(phContext, &tokBuffer); - if (FAILED(status)) return status; + if (FAILED(status)) throw SSLException(Utility::formatError(status)); DWORD sspiFlags = ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT @@ -1534,11 +1543,19 @@ LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phConte &sspiOutFlags, &expiry); - return status; + if (FAILED(status)) throw SSLException(Utility::formatError(status)); + + if (outBuffer[0].pvBuffer && outBuffer[0].cbBuffer) + { + int sent = sendRawBytes(outBuffer[0].pvBuffer, outBuffer[0].cbBuffer); + return sent; + } + + return 0; } -LONG SecureSocketImpl::serverDisconnect(PCredHandle phCreds, CtxtHandle* phContext) +int SecureSocketImpl::serverShutdown(PCredHandle phCreds, CtxtHandle* phContext) { if (phContext->dwLower == 0 && phContext->dwUpper == 0) { diff --git a/NetSSL_Win/src/SecureStreamSocketImpl.cpp b/NetSSL_Win/src/SecureStreamSocketImpl.cpp index ca7048c9b..add82ca23 100644 --- a/NetSSL_Win/src/SecureStreamSocketImpl.cpp +++ b/NetSSL_Win/src/SecureStreamSocketImpl.cpp @@ -149,13 +149,13 @@ void SecureStreamSocketImpl::shutdownReceive() } -void SecureStreamSocketImpl::shutdownSend() +int SecureStreamSocketImpl::shutdownSend() { return _impl.shutdown(); } -void SecureStreamSocketImpl::shutdown() +int SecureStreamSocketImpl::shutdown() { return _impl.shutdown(); }