chore(NetSSL_Win): rewrite handshake logic to support non-blocking sockets

This commit is contained in:
Günter Obiltschnig 2024-11-26 19:20:03 +01:00
parent 0eb2e1b3b3
commit 663232b8b8
5 changed files with 565 additions and 419 deletions

View File

@ -18,6 +18,10 @@
#define NetSSL_SecureSocketImpl_INCLUDED #define NetSSL_SecureSocketImpl_INCLUDED
// Temporary debugging aid, to be removed
// #define ENABLE_PRINT_STATE
#include "Poco/Net/SocketImpl.h" #include "Poco/Net/SocketImpl.h"
#include "Poco/Net/NetSSL.h" #include "Poco/Net/NetSSL.h"
#include "Poco/Net/Context.h" #include "Poco/Net/Context.h"
@ -35,6 +39,14 @@
#include <sspi.h> #include <sspi.h>
#ifdef ENABLE_PRINT_STATE
#define PRINT_STATE(m) printState(m)
#else
#define PRINT_STATE(m)
#endif
namespace Poco { namespace Poco {
namespace Net { namespace Net {
@ -224,20 +236,25 @@ protected:
enum State enum State
{ {
ST_INITIAL = 0, ST_INITIAL = 0,
// Client
ST_CONNECTING, ST_CONNECTING,
ST_CLIENT_HSK_START, ST_CLIENT_HSK_START,
ST_CLIENT_HSK_SEND_TOKEN, ST_CLIENT_HSK_SEND_TOKEN,
ST_CLIENT_HSK_LOOP_INIT, ST_CLIENT_HSK_LOOP_INIT,
ST_CLIENT_HSK_LOOP_RECV, ST_CLIENT_HSK_LOOP_RECV,
ST_CLIENT_HSK_LOOP_PROCESS, ST_CLIENT_HSK_LOOP_PROCESS,
ST_CLIENT_HSK_LOOP_CONTINUE,
ST_CLIENT_HSK_LOOP_SEND, ST_CLIENT_HSK_LOOP_SEND,
ST_CLIENT_HSK_LOOP_INCOMPLETE,
ST_CLIENT_HSK_LOOP_DONE, ST_CLIENT_HSK_LOOP_DONE,
ST_CLIENT_HSK_SEND_FINAL, ST_CLIENT_HSK_SEND_FINAL,
ST_CLIENT_HSK_SEND_ERROR, ST_CLIENT_HSK_SEND_ERROR,
ST_CLIENT_VERIFY, ST_CLIENT_VERIFY,
ST_ACCEPTING,
ST_SERVER_HSK_START,
ST_SERVER_HSK_LOOP_INIT,
ST_SERVER_HSK_LOOP_RECV,
ST_SERVER_HSK_LOOP_PROCESS,
ST_SERVER_HSK_LOOP_SEND,
ST_SERVER_HSK_LOOP_DONE,
ST_SERVER_VERIFY,
ST_DONE, ST_DONE,
ST_ERROR, ST_ERROR,
ST_MAX ST_MAX
@ -251,51 +268,64 @@ protected:
int sendRawBytes(const void* buffer, int length, int flags = 0); int sendRawBytes(const void* buffer, int length, int flags = 0);
int receiveRawBytes(void* buffer, int length, int flags = 0); int receiveRawBytes(void* buffer, int length, int flags = 0);
void clientConnectVerify();
void performServerHandshake();
bool serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext);
void clientVerifyCertificate(const std::string& hostName); void clientVerifyCertificate(const std::string& hostName);
void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert); void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert);
void serverVerifyCertificate(); void serverVerifyCertificate();
int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext); int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext);
int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext); int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext);
void initClientContext();
void initServerContext();
PCCERT_CONTEXT loadCertificate(bool mustFindCertificate); PCCERT_CONTEXT loadCertificate(bool mustFindCertificate);
void initCommon(); void initCommon();
void cleanup(); void cleanup();
void performClientHandshakeStart(); void stateIllegal();
void performInitialClientHandshake(); void stateError();
void performClientHandshakeSendToken();
void performClientHandshakeLoopIncompleteMessage();
void performClientHandshakeLoopInit();
void performClientHandshakeLoopRecv();
void performClientHandshakeLoopProcess();
void performClientHandshakeLoopSend();
void performClientHandshakeLoopDone();
void performClientHandshakeSendFinal();
void performClientHandshakeExtraBuffer();
void performClientHandshakeLoopContinue();
void performClientHandshakeError();
void performClientHandshakeSendError();
void sendOutSecBufferAndAdvanceState(State state);
SECURITY_STATUS performClientHandshake(); void stateClientConnected();
void stateClientHandshakeStart();
void stateClientHandshakeSendToken();
void stateClientHandshakeLoopInit();
void stateClientHandshakeLoopRecv();
void stateClientHandshakeLoopProcess();
void stateClientHandshakeLoopSend();
void stateClientHandshakeLoopDone();
void stateClientHandshakeSendFinal();
void stateClientHandshakeSendError();
void stateClientVerify();
void stateServerAccepted();
void stateServerHandshakeStart();
void stateServerHandshakeLoopInit();
void stateServerHandshakeLoopRecv();
void stateServerHandshakeLoopProcess();
void stateServerHandshakeLoopSend();
void stateServerHandshakeLoopDone();
void stateServerHandshakeVerify();
void sendOutSecBufferAndAdvanceState(State state);
void drainExtraBuffer();
static int getRecordLength(const BYTE* pBuffer, int length);
static bool bufferHasCompleteRecords(const BYTE* pBuffer, int length);
void initClientCredentials();
void initServerCredentials();
SECURITY_STATUS doHandshake();
int completeHandshake();
SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra); SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra);
SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded); SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded);
void stateIllegal();
void stateConnected();
void acceptSSL(); void acceptSSL();
void connectSSL(bool completeHandshake); void connectSSL(bool completeHandshake);
void completeHandshake();
static int lastError(); static int lastError();
bool stateMachine(); bool stateMachine();
State getState() const; State getState() const;
void setState(State st); void setState(State st);
static bool isLocalHost(const std::string& hostName); static bool isLocalHost(const std::string& hostName);
#ifdef ENABLE_PRINT_STATE
void printState(const std::string& msg);
#endif
private: private:
SecureSocketImpl(const SecureSocketImpl&); SecureSocketImpl(const SecureSocketImpl&);
SecureSocketImpl& operator = (const SecureSocketImpl&); SecureSocketImpl& operator = (const SecureSocketImpl&);
@ -331,9 +361,9 @@ private:
SecBuffer _extraSecBuffer; SecBuffer _extraSecBuffer;
SECURITY_STATUS _securityStatus; SECURITY_STATUS _securityStatus;
State _state; State _state;
DWORD _outFlags;
bool _needData; bool _needData;
bool _needHandshake; bool _needHandshake;
bool _initServerContext = false;
friend class SecureStreamSocketImpl; friend class SecureStreamSocketImpl;
friend class StateMachine; friend class StateMachine;
@ -376,6 +406,7 @@ inline SecureSocketImpl::State SecureSocketImpl::getState() const
inline void SecureSocketImpl::setState(SecureSocketImpl::State st) inline void SecureSocketImpl::setState(SecureSocketImpl::State st)
{ {
_state = st; _state = st;
PRINT_STATE("setState: ");
} }

View File

@ -35,7 +35,7 @@ public:
/// Non-case sensitive conversion of a string to a VerificationMode enum. /// Non-case sensitive conversion of a string to a VerificationMode enum.
/// If verMode is illegal an OptionException is thrown. /// If verMode is illegal an OptionException is thrown.
static const std::string& formatError(long errCode); static std::string formatError(long errCode);
/// Converts an winerror.h code into human readable form. /// Converts an winerror.h code into human readable form.
private: private:

File diff suppressed because it is too large Load Diff

View File

@ -209,8 +209,7 @@ void SecureStreamSocketImpl::verifyPeerCertificate(const std::string& hostName)
int SecureStreamSocketImpl::completeHandshake() int SecureStreamSocketImpl::completeHandshake()
{ {
_impl.completeHandshake(); return _impl.completeHandshake();
return 0;
} }

View File

@ -24,9 +24,6 @@ namespace Poco {
namespace Net { namespace Net {
Poco::FastMutex Utility::_mutex;
Context::VerificationMode Utility::convertVerificationMode(const std::string& vMode) Context::VerificationMode Utility::convertVerificationMode(const std::string& vMode)
{ {
std::string mode = Poco::toLower(vMode); std::string mode = Poco::toLower(vMode);
@ -54,6 +51,7 @@ inline void add(std::map<long, const std::string>& messageMap, long key, const s
std::map<long, const std::string> Utility::initSSPIErr() std::map<long, const std::string> Utility::initSSPIErr()
{ {
std::map<long, const std::string> messageMap; std::map<long, const std::string> messageMap;
add(messageMap, SEC_E_OK, "OK");
add(messageMap, NTE_BAD_UID, "Bad UID"); add(messageMap, NTE_BAD_UID, "Bad UID");
add(messageMap, NTE_BAD_HASH, "Bad Hash"); add(messageMap, NTE_BAD_HASH, "Bad Hash");
add(messageMap, NTE_BAD_KEY, "Bad Key"); add(messageMap, NTE_BAD_KEY, "Bad Key");
@ -185,18 +183,15 @@ std::map<long, const std::string> Utility::initSSPIErr()
} }
const std::string& Utility::formatError(long errCode) std::string Utility::formatError(long errCode)
{ {
Poco::FastMutex::ScopedLock lock(_mutex);
static const std::string def("Internal SSPI error");
static const std::map<long, const std::string> errs(initSSPIErr()); static const std::map<long, const std::string> errs(initSSPIErr());
const std::map<long, const std::string>::const_iterator it = errs.find(errCode); const std::map<long, const std::string>::const_iterator it = errs.find(errCode);
if (it != errs.end()) if (it != errs.end())
return it->second; return it->second;
else else
return def; return "0x" + Poco::NumberFormatter::formatHex(errCode, 8);
} }