fix(NetSSL_Win): shutdown behavior

This commit is contained in:
Günter Obiltschnig 2024-11-23 15:38:13 +01:00
parent 25ce194273
commit e911884753
4 changed files with 48 additions and 24 deletions

View File

@ -236,6 +236,12 @@ protected:
ST_ERROR ST_ERROR
}; };
enum TLSShutdown
{
TLS_SHUTDOWN_SENT = 1,
TLS_SHUTDOWN_RECEIVED = 2
};
int sendRawBytes(const void* buffer, int length, int flags = 0); int sendRawBytes(const void* buffer, int length, int flags = 0);
int receiveRawBytes(void* buffer, int length, int flags = 0); int receiveRawBytes(void* buffer, int length, int flags = 0);
void clientConnectVerify(); void clientConnectVerify();
@ -245,8 +251,8 @@ protected:
void clientVerifyCertificate(const std::string& hostName); void clientVerifyCertificate(const std::string& hostName);
void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert); void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert);
void serverVerifyCertificate(); void serverVerifyCertificate();
LONG serverDisconnect(PCredHandle phCreds, CtxtHandle* phContext); int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext);
LONG clientDisconnect(PCredHandle phCreds, CtxtHandle* phContext); int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext);
bool loadSecurityLibrary(); bool loadSecurityLibrary();
void initClientContext(); void initClientContext();
void initServerContext(); void initServerContext();
@ -286,6 +292,7 @@ private:
Poco::AutoPtr<SocketImpl> _pSocket; Poco::AutoPtr<SocketImpl> _pSocket;
Context::Ptr _pContext; Context::Ptr _pContext;
Mode _mode; Mode _mode;
int _shutdownFlags;
std::string _peerHostName; std::string _peerHostName;
bool _useMachineStore; bool _useMachineStore;
bool _clientAuthRequired; bool _clientAuthRequired;

View File

@ -138,7 +138,7 @@ std::istream* HTTPSStreamFactory::open(const URI& uri)
{ {
return new HTTPResponseStream(rs, pSession); return new HTTPResponseStream(rs, pSession);
} }
else if (res.getStatus() == HTTPResponse::HTTP_USEPROXY && !retry) else if (res.getStatus() == HTTPResponse::HTTP_USE_PROXY && !retry)
{ {
// The requested resource MUST be accessed through the proxy // The requested resource MUST be accessed through the proxy
// given by the Location field. The Location field gives the // given by the Location field. The Location field gives the

View File

@ -62,6 +62,7 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr<SocketImpl> pSocketImpl, Contex
_pSocket(pSocketImpl), _pSocket(pSocketImpl),
_pContext(pContext), _pContext(pContext),
_mode(pContext->isForServerUse() ? MODE_SERVER : MODE_CLIENT), _mode(pContext->isForServerUse() ? MODE_SERVER : MODE_CLIENT),
_shutdownFlags(0),
_clientAuthRequired(pContext->verificationMode() >= Context::VERIFY_STRICT), _clientAuthRequired(pContext->verificationMode() >= Context::VERIFY_STRICT),
_securityFunctions(SSLManager::instance().securityFunctions()), _securityFunctions(SSLManager::instance().securityFunctions()),
_pOwnCertificate(0), _pOwnCertificate(0),
@ -267,13 +268,24 @@ void SecureSocketImpl::listen(int backlog)
int SecureSocketImpl::shutdown() int SecureSocketImpl::shutdown()
{ {
if (_mode == MODE_SERVER) int rc = 0;
serverDisconnect(&_hCreds, &_hContext); if ((_shutdownFlags & TLS_SHUTDOWN_SENT) == 0)
else {
clientDisconnect(&_hCreds, &_hContext); if (_mode == MODE_SERVER)
{
_pSocket->shutdown(); rc = serverShutdown(&_hCreds, &_hContext);
return 0; }
else
{
rc = clientShutdown(&_hCreds, &_hContext);
}
if (rc >= 0)
{
_pSocket->shutdownSend();
_shutdownFlags |= TLS_SHUTDOWN_SENT;
}
}
return (_shutdownFlags & TLS_SHUTDOWN_RECEIVED) ? 1 : 0;
} }
@ -281,10 +293,7 @@ void SecureSocketImpl::close()
{ {
try try
{ {
if (_mode == MODE_SERVER) shutdown();
serverDisconnect(&_hCreds, &_hContext);
else
clientDisconnect(&_hCreds, &_hContext);
} }
catch (Poco::Exception&) catch (Poco::Exception&)
{ {
@ -577,13 +586,12 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags)
if (securityStatus == SEC_I_CONTEXT_EXPIRED) if (securityStatus == SEC_I_CONTEXT_EXPIRED)
{ {
SetLastError(securityStatus); _shutdownFlags |= TLS_SHUTDOWN_RECEIVED;
break; break;
} }
if (securityStatus != SEC_E_OK && securityStatus != SEC_I_RENEGOTIATE && securityStatus != SEC_I_CONTEXT_EXPIRED) if (securityStatus != SEC_E_OK && securityStatus != SEC_I_RENEGOTIATE)
{ {
SetLastError(securityStatus);
break; break;
} }
@ -626,6 +634,7 @@ SECURITY_STATUS SecureSocketImpl::decodeMessage(BYTE* pBuffer, DWORD bufSize, Au
pExtraBuffer = 0; pExtraBuffer = 0;
SECURITY_STATUS securityStatus = _securityFunctions.DecryptMessage(&_hContext, &msg, 0, 0); SECURITY_STATUS securityStatus = _securityFunctions.DecryptMessage(&_hContext, &msg, 0, 0);
// TODO: when decrypting the close_notify alert, returns SEC_E_DECRYPT_FAILURE
if (securityStatus == SEC_E_OK || securityStatus == SEC_I_RENEGOTIATE) if (securityStatus == SEC_E_OK || securityStatus == SEC_I_RENEGOTIATE)
{ {
@ -1493,11 +1502,11 @@ void SecureSocketImpl::serverVerifyCertificate()
} }
LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phContext) int SecureSocketImpl::clientShutdown(PCredHandle phCreds, CtxtHandle* phContext)
{ {
if (phContext->dwLower == 0 && phContext->dwUpper == 0) if (phContext->dwLower == 0 && phContext->dwUpper == 0)
{ {
return SEC_E_OK; return 0;
} }
AutoSecBufferDesc<1> tokBuffer(&_securityFunctions, false); AutoSecBufferDesc<1> tokBuffer(&_securityFunctions, false);
@ -1506,7 +1515,7 @@ LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phConte
tokBuffer.setSecBufferToken(0, &tokenType, sizeof(tokenType)); tokBuffer.setSecBufferToken(0, &tokenType, sizeof(tokenType));
DWORD status = _securityFunctions.ApplyControlToken(phContext, &tokBuffer); DWORD status = _securityFunctions.ApplyControlToken(phContext, &tokBuffer);
if (FAILED(status)) return status; if (FAILED(status)) throw SSLException(Utility::formatError(status));
DWORD sspiFlags = ISC_REQ_SEQUENCE_DETECT DWORD sspiFlags = ISC_REQ_SEQUENCE_DETECT
| ISC_REQ_REPLAY_DETECT | ISC_REQ_REPLAY_DETECT
@ -1534,11 +1543,19 @@ LONG SecureSocketImpl::clientDisconnect(PCredHandle phCreds, CtxtHandle* phConte
&sspiOutFlags, &sspiOutFlags,
&expiry); &expiry);
return status; if (FAILED(status)) throw SSLException(Utility::formatError(status));
if (outBuffer[0].pvBuffer && outBuffer[0].cbBuffer)
{
int sent = sendRawBytes(outBuffer[0].pvBuffer, outBuffer[0].cbBuffer);
return sent;
}
return 0;
} }
LONG SecureSocketImpl::serverDisconnect(PCredHandle phCreds, CtxtHandle* phContext) int SecureSocketImpl::serverShutdown(PCredHandle phCreds, CtxtHandle* phContext)
{ {
if (phContext->dwLower == 0 && phContext->dwUpper == 0) if (phContext->dwLower == 0 && phContext->dwUpper == 0)
{ {

View File

@ -149,13 +149,13 @@ void SecureStreamSocketImpl::shutdownReceive()
} }
void SecureStreamSocketImpl::shutdownSend() int SecureStreamSocketImpl::shutdownSend()
{ {
return _impl.shutdown(); return _impl.shutdown();
} }
void SecureStreamSocketImpl::shutdown() int SecureStreamSocketImpl::shutdown()
{ {
return _impl.shutdown(); return _impl.shutdown();
} }