mirror of
https://github.com/pocoproject/poco.git
synced 2024-12-13 10:32:57 +01:00
chore(NetSSL_Win): rewrite handshake logic to support non-blocking sockets
This commit is contained in:
parent
0eb2e1b3b3
commit
663232b8b8
@ -18,6 +18,10 @@
|
|||||||
#define NetSSL_SecureSocketImpl_INCLUDED
|
#define NetSSL_SecureSocketImpl_INCLUDED
|
||||||
|
|
||||||
|
|
||||||
|
// Temporary debugging aid, to be removed
|
||||||
|
// #define ENABLE_PRINT_STATE
|
||||||
|
|
||||||
|
|
||||||
#include "Poco/Net/SocketImpl.h"
|
#include "Poco/Net/SocketImpl.h"
|
||||||
#include "Poco/Net/NetSSL.h"
|
#include "Poco/Net/NetSSL.h"
|
||||||
#include "Poco/Net/Context.h"
|
#include "Poco/Net/Context.h"
|
||||||
@ -35,6 +39,14 @@
|
|||||||
#include <sspi.h>
|
#include <sspi.h>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef ENABLE_PRINT_STATE
|
||||||
|
#define PRINT_STATE(m) printState(m)
|
||||||
|
#else
|
||||||
|
#define PRINT_STATE(m)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
namespace Poco {
|
namespace Poco {
|
||||||
namespace Net {
|
namespace Net {
|
||||||
|
|
||||||
@ -224,20 +236,25 @@ protected:
|
|||||||
enum State
|
enum State
|
||||||
{
|
{
|
||||||
ST_INITIAL = 0,
|
ST_INITIAL = 0,
|
||||||
// Client
|
|
||||||
ST_CONNECTING,
|
ST_CONNECTING,
|
||||||
ST_CLIENT_HSK_START,
|
ST_CLIENT_HSK_START,
|
||||||
ST_CLIENT_HSK_SEND_TOKEN,
|
ST_CLIENT_HSK_SEND_TOKEN,
|
||||||
ST_CLIENT_HSK_LOOP_INIT,
|
ST_CLIENT_HSK_LOOP_INIT,
|
||||||
ST_CLIENT_HSK_LOOP_RECV,
|
ST_CLIENT_HSK_LOOP_RECV,
|
||||||
ST_CLIENT_HSK_LOOP_PROCESS,
|
ST_CLIENT_HSK_LOOP_PROCESS,
|
||||||
ST_CLIENT_HSK_LOOP_CONTINUE,
|
|
||||||
ST_CLIENT_HSK_LOOP_SEND,
|
ST_CLIENT_HSK_LOOP_SEND,
|
||||||
ST_CLIENT_HSK_LOOP_INCOMPLETE,
|
|
||||||
ST_CLIENT_HSK_LOOP_DONE,
|
ST_CLIENT_HSK_LOOP_DONE,
|
||||||
ST_CLIENT_HSK_SEND_FINAL,
|
ST_CLIENT_HSK_SEND_FINAL,
|
||||||
ST_CLIENT_HSK_SEND_ERROR,
|
ST_CLIENT_HSK_SEND_ERROR,
|
||||||
ST_CLIENT_VERIFY,
|
ST_CLIENT_VERIFY,
|
||||||
|
ST_ACCEPTING,
|
||||||
|
ST_SERVER_HSK_START,
|
||||||
|
ST_SERVER_HSK_LOOP_INIT,
|
||||||
|
ST_SERVER_HSK_LOOP_RECV,
|
||||||
|
ST_SERVER_HSK_LOOP_PROCESS,
|
||||||
|
ST_SERVER_HSK_LOOP_SEND,
|
||||||
|
ST_SERVER_HSK_LOOP_DONE,
|
||||||
|
ST_SERVER_VERIFY,
|
||||||
ST_DONE,
|
ST_DONE,
|
||||||
ST_ERROR,
|
ST_ERROR,
|
||||||
ST_MAX
|
ST_MAX
|
||||||
@ -251,51 +268,64 @@ protected:
|
|||||||
|
|
||||||
int sendRawBytes(const void* buffer, int length, int flags = 0);
|
int sendRawBytes(const void* buffer, int length, int flags = 0);
|
||||||
int receiveRawBytes(void* buffer, int length, int flags = 0);
|
int receiveRawBytes(void* buffer, int length, int flags = 0);
|
||||||
void clientConnectVerify();
|
|
||||||
void performServerHandshake();
|
|
||||||
bool serverHandshakeLoop(PCtxtHandle phContext, PCredHandle phCred, bool requireClientAuth, bool doInitialRead, bool newContext);
|
|
||||||
void clientVerifyCertificate(const std::string& hostName);
|
void clientVerifyCertificate(const std::string& hostName);
|
||||||
void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert);
|
void verifyCertificateChainClient(PCCERT_CONTEXT pServerCert);
|
||||||
void serverVerifyCertificate();
|
void serverVerifyCertificate();
|
||||||
int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext);
|
int serverShutdown(PCredHandle phCreds, CtxtHandle* phContext);
|
||||||
int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext);
|
int clientShutdown(PCredHandle phCreds, CtxtHandle* phContext);
|
||||||
void initClientContext();
|
|
||||||
void initServerContext();
|
|
||||||
PCCERT_CONTEXT loadCertificate(bool mustFindCertificate);
|
PCCERT_CONTEXT loadCertificate(bool mustFindCertificate);
|
||||||
void initCommon();
|
void initCommon();
|
||||||
void cleanup();
|
void cleanup();
|
||||||
|
|
||||||
void performClientHandshakeStart();
|
void stateIllegal();
|
||||||
void performInitialClientHandshake();
|
void stateError();
|
||||||
void performClientHandshakeSendToken();
|
|
||||||
void performClientHandshakeLoopIncompleteMessage();
|
|
||||||
void performClientHandshakeLoopInit();
|
|
||||||
void performClientHandshakeLoopRecv();
|
|
||||||
void performClientHandshakeLoopProcess();
|
|
||||||
void performClientHandshakeLoopSend();
|
|
||||||
void performClientHandshakeLoopDone();
|
|
||||||
void performClientHandshakeSendFinal();
|
|
||||||
void performClientHandshakeExtraBuffer();
|
|
||||||
void performClientHandshakeLoopContinue();
|
|
||||||
void performClientHandshakeError();
|
|
||||||
void performClientHandshakeSendError();
|
|
||||||
void sendOutSecBufferAndAdvanceState(State state);
|
|
||||||
|
|
||||||
SECURITY_STATUS performClientHandshake();
|
void stateClientConnected();
|
||||||
|
void stateClientHandshakeStart();
|
||||||
|
void stateClientHandshakeSendToken();
|
||||||
|
void stateClientHandshakeLoopInit();
|
||||||
|
void stateClientHandshakeLoopRecv();
|
||||||
|
void stateClientHandshakeLoopProcess();
|
||||||
|
void stateClientHandshakeLoopSend();
|
||||||
|
void stateClientHandshakeLoopDone();
|
||||||
|
void stateClientHandshakeSendFinal();
|
||||||
|
void stateClientHandshakeSendError();
|
||||||
|
void stateClientVerify();
|
||||||
|
|
||||||
|
void stateServerAccepted();
|
||||||
|
void stateServerHandshakeStart();
|
||||||
|
void stateServerHandshakeLoopInit();
|
||||||
|
void stateServerHandshakeLoopRecv();
|
||||||
|
void stateServerHandshakeLoopProcess();
|
||||||
|
void stateServerHandshakeLoopSend();
|
||||||
|
void stateServerHandshakeLoopDone();
|
||||||
|
void stateServerHandshakeVerify();
|
||||||
|
|
||||||
|
void sendOutSecBufferAndAdvanceState(State state);
|
||||||
|
void drainExtraBuffer();
|
||||||
|
static int getRecordLength(const BYTE* pBuffer, int length);
|
||||||
|
static bool bufferHasCompleteRecords(const BYTE* pBuffer, int length);
|
||||||
|
|
||||||
|
void initClientCredentials();
|
||||||
|
void initServerCredentials();
|
||||||
|
SECURITY_STATUS doHandshake();
|
||||||
|
int completeHandshake();
|
||||||
|
|
||||||
SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra);
|
SECURITY_STATUS decodeMessage(BYTE* pBuffer, DWORD bufSize, AutoSecBufferDesc<4>& msg, SecBuffer*& pData, SecBuffer*& pExtra);
|
||||||
SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded);
|
SECURITY_STATUS decodeBufferFull(BYTE* pBuffer, DWORD bufSize, char* pOutBuffer, int outLength, int& bytesDecoded);
|
||||||
void stateIllegal();
|
|
||||||
void stateConnected();
|
|
||||||
void acceptSSL();
|
void acceptSSL();
|
||||||
void connectSSL(bool completeHandshake);
|
void connectSSL(bool completeHandshake);
|
||||||
void completeHandshake();
|
|
||||||
static int lastError();
|
static int lastError();
|
||||||
bool stateMachine();
|
bool stateMachine();
|
||||||
State getState() const;
|
State getState() const;
|
||||||
void setState(State st);
|
void setState(State st);
|
||||||
static bool isLocalHost(const std::string& hostName);
|
static bool isLocalHost(const std::string& hostName);
|
||||||
|
|
||||||
|
#ifdef ENABLE_PRINT_STATE
|
||||||
|
void printState(const std::string& msg);
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
SecureSocketImpl(const SecureSocketImpl&);
|
SecureSocketImpl(const SecureSocketImpl&);
|
||||||
SecureSocketImpl& operator = (const SecureSocketImpl&);
|
SecureSocketImpl& operator = (const SecureSocketImpl&);
|
||||||
@ -331,9 +361,9 @@ private:
|
|||||||
SecBuffer _extraSecBuffer;
|
SecBuffer _extraSecBuffer;
|
||||||
SECURITY_STATUS _securityStatus;
|
SECURITY_STATUS _securityStatus;
|
||||||
State _state;
|
State _state;
|
||||||
DWORD _outFlags;
|
|
||||||
bool _needData;
|
bool _needData;
|
||||||
bool _needHandshake;
|
bool _needHandshake;
|
||||||
|
bool _initServerContext = false;
|
||||||
|
|
||||||
friend class SecureStreamSocketImpl;
|
friend class SecureStreamSocketImpl;
|
||||||
friend class StateMachine;
|
friend class StateMachine;
|
||||||
@ -376,6 +406,7 @@ inline SecureSocketImpl::State SecureSocketImpl::getState() const
|
|||||||
inline void SecureSocketImpl::setState(SecureSocketImpl::State st)
|
inline void SecureSocketImpl::setState(SecureSocketImpl::State st)
|
||||||
{
|
{
|
||||||
_state = st;
|
_state = st;
|
||||||
|
PRINT_STATE("setState: ");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ public:
|
|||||||
/// Non-case sensitive conversion of a string to a VerificationMode enum.
|
/// Non-case sensitive conversion of a string to a VerificationMode enum.
|
||||||
/// If verMode is illegal an OptionException is thrown.
|
/// If verMode is illegal an OptionException is thrown.
|
||||||
|
|
||||||
static const std::string& formatError(long errCode);
|
static std::string formatError(long errCode);
|
||||||
/// Converts an winerror.h code into human readable form.
|
/// Converts an winerror.h code into human readable form.
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -209,8 +209,7 @@ void SecureStreamSocketImpl::verifyPeerCertificate(const std::string& hostName)
|
|||||||
|
|
||||||
int SecureStreamSocketImpl::completeHandshake()
|
int SecureStreamSocketImpl::completeHandshake()
|
||||||
{
|
{
|
||||||
_impl.completeHandshake();
|
return _impl.completeHandshake();
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,9 +24,6 @@ namespace Poco {
|
|||||||
namespace Net {
|
namespace Net {
|
||||||
|
|
||||||
|
|
||||||
Poco::FastMutex Utility::_mutex;
|
|
||||||
|
|
||||||
|
|
||||||
Context::VerificationMode Utility::convertVerificationMode(const std::string& vMode)
|
Context::VerificationMode Utility::convertVerificationMode(const std::string& vMode)
|
||||||
{
|
{
|
||||||
std::string mode = Poco::toLower(vMode);
|
std::string mode = Poco::toLower(vMode);
|
||||||
@ -54,6 +51,7 @@ inline void add(std::map<long, const std::string>& messageMap, long key, const s
|
|||||||
std::map<long, const std::string> Utility::initSSPIErr()
|
std::map<long, const std::string> Utility::initSSPIErr()
|
||||||
{
|
{
|
||||||
std::map<long, const std::string> messageMap;
|
std::map<long, const std::string> messageMap;
|
||||||
|
add(messageMap, SEC_E_OK, "OK");
|
||||||
add(messageMap, NTE_BAD_UID, "Bad UID");
|
add(messageMap, NTE_BAD_UID, "Bad UID");
|
||||||
add(messageMap, NTE_BAD_HASH, "Bad Hash");
|
add(messageMap, NTE_BAD_HASH, "Bad Hash");
|
||||||
add(messageMap, NTE_BAD_KEY, "Bad Key");
|
add(messageMap, NTE_BAD_KEY, "Bad Key");
|
||||||
@ -185,18 +183,15 @@ std::map<long, const std::string> Utility::initSSPIErr()
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
const std::string& Utility::formatError(long errCode)
|
std::string Utility::formatError(long errCode)
|
||||||
{
|
{
|
||||||
Poco::FastMutex::ScopedLock lock(_mutex);
|
|
||||||
|
|
||||||
static const std::string def("Internal SSPI error");
|
|
||||||
static const std::map<long, const std::string> errs(initSSPIErr());
|
static const std::map<long, const std::string> errs(initSSPIErr());
|
||||||
|
|
||||||
const std::map<long, const std::string>::const_iterator it = errs.find(errCode);
|
const std::map<long, const std::string>::const_iterator it = errs.find(errCode);
|
||||||
if (it != errs.end())
|
if (it != errs.end())
|
||||||
return it->second;
|
return it->second;
|
||||||
else
|
else
|
||||||
return def;
|
return "0x" + Poco::NumberFormatter::formatHex(errCode, 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user