fixed a crash on WinCE; code cleanup

This commit is contained in:
Günter Obiltschnig
2014-09-17 11:17:04 +02:00
parent ee25a49e9b
commit 0a115be8ac
7 changed files with 136 additions and 100 deletions

View File

@@ -73,7 +73,7 @@ public:
const_cast<AutoSecBufferDesc*>(&desc)->initBuffers(); const_cast<AutoSecBufferDesc*>(&desc)->initBuffers();
} }
AutoSecBufferDesc& operator=(const AutoSecBufferDesc& desc) AutoSecBufferDesc& operator = (const AutoSecBufferDesc& desc)
{ {
if (&desc != this) if (&desc != this)
{ {
@@ -96,13 +96,22 @@ public:
~AutoSecBufferDesc() ~AutoSecBufferDesc()
/// Destroys the AutoSecBufferDesc /// Destroys the AutoSecBufferDesc
{
release();
}
void release()
{ {
if (_autoRelease) if (_autoRelease)
{ {
for (int i = 0; i < numBufs; ++i) for (int i = 0; i < numBufs; ++i)
{ {
_pSec->FreeContextBuffer(_buffers[i].pvBuffer); if (_buffers[i].pvBuffer)
{
_pSec->FreeContextBuffer(_buffers[i].pvBuffer);
}
} }
_autoRelease = false;
} }
} }

View File

@@ -27,6 +27,13 @@
#include "Poco/Net/InvalidCertificateHandler.h" #include "Poco/Net/InvalidCertificateHandler.h"
#include "Poco/BasicEvent.h" #include "Poco/BasicEvent.h"
#include "Poco/SharedPtr.h" #include "Poco/SharedPtr.h"
#include <wincrypt.h>
#include <schannel.h>
#ifndef SECURITY_WIN32
#define SECURITY_WIN32
#endif
#include <security.h>
#include <sspi.h>
namespace Poco { namespace Poco {
@@ -184,6 +191,9 @@ public:
static const std::string CFG_SERVER_PREFIX; static const std::string CFG_SERVER_PREFIX;
static const std::string CFG_CLIENT_PREFIX; static const std::string CFG_CLIENT_PREFIX;
protected:
SecurityFunctionTableW& securityFunctions();
private: private:
SSLManager(); SSLManager();
/// Creates the SSLManager. /// Creates the SSLManager.
@@ -200,6 +210,15 @@ private:
void initCertificateHandler(bool server); void initCertificateHandler(bool server);
/// Inits the certificate handler. /// Inits the certificate handler.
void loadSecurityLibrary();
/// Loads the Windows security DLL.
void unloadSecurityLibrary();
/// Unloads the Windows security DLL.
HMODULE _hSecurityModule;
SecurityFunctionTableW _securityFunctions;
CertificateHandlerFactoryMgr _certHandlerFactoryMgr; CertificateHandlerFactoryMgr _certHandlerFactoryMgr;
Context::Ptr _ptrDefaultServerContext; Context::Ptr _ptrDefaultServerContext;
InvalidCertificateHandlerPtr _ptrServerCertificateHandler; InvalidCertificateHandlerPtr _ptrServerCertificateHandler;
@@ -232,6 +251,7 @@ private:
static const std::string CFG_REQUIRE_TLSV1_2; static const std::string CFG_REQUIRE_TLSV1_2;
friend class Poco::SingletonHolder<SSLManager>; friend class Poco::SingletonHolder<SSLManager>;
friend class SecureSocketImpl;
}; };
@@ -244,6 +264,12 @@ inline CertificateHandlerFactoryMgr& SSLManager::certificateHandlerFactoryMgr()
} }
inline SecurityFunctionTableW& SSLManager::securityFunctions()
{
return _securityFunctions;
}
} } // namespace Poco::Net } } // namespace Poco::Net

View File

@@ -241,8 +241,7 @@ private:
DWORD _clientFlags; DWORD _clientFlags;
DWORD _serverFlags; DWORD _serverFlags;
HMODULE _hSecurityModule; SecurityFunctionTableW& _securityFunctions;
SecurityFunctionTableW _securityFunctions;
BYTE* _pReceiveBuffer; BYTE* _pReceiveBuffer;
DWORD _receiveBufferSize; DWORD _receiveBufferSize;

View File

@@ -51,19 +51,16 @@ const std::string SSLManager::CFG_REQUIRE_TLSV1_1("requireTLSv1_1");
const std::string SSLManager::CFG_REQUIRE_TLSV1_2("requireTLSv1_2"); const std::string SSLManager::CFG_REQUIRE_TLSV1_2("requireTLSv1_2");
SSLManager::SSLManager() SSLManager::SSLManager():
_hSecurityModule(0)
{ {
loadSecurityLibrary();
} }
SSLManager::~SSLManager() SSLManager::~SSLManager()
{ {
ClientVerificationError.clear(); shutdown();
ServerVerificationError.clear();
_ptrServerCertificateHandler = 0;
_ptrDefaultServerContext = 0;
_ptrClientCertificateHandler = 0;
_ptrDefaultClientContext = 0;
} }
@@ -237,8 +234,90 @@ void SSLManager::shutdown()
{ {
ClientVerificationError.clear(); ClientVerificationError.clear();
ServerVerificationError.clear(); ServerVerificationError.clear();
_ptrDefaultServerContext = 0; _ptrServerCertificateHandler = 0;
_ptrDefaultClientContext = 0; _ptrDefaultServerContext = 0;
_ptrClientCertificateHandler = 0;
_ptrDefaultClientContext = 0;
unloadSecurityLibrary();
}
void SSLManager::loadSecurityLibrary()
{
if (_hSecurityModule) return;
OSVERSIONINFO VerInfo;
std::wstring dllPath;
// Find out which security DLL to use, depending on
// whether we are on Win2k, NT or Win9x
VerInfo.dwOSVersionInfoSize = sizeof(OSVERSIONINFO);
if (!GetVersionEx(&VerInfo))
throw Poco::SystemException("Cannot determine OS version");
#if defined(_WIN32_WCE)
dllPath = L"Secur32.dll";
#else
if (VerInfo.dwPlatformId == VER_PLATFORM_WIN32_NT
&& VerInfo.dwMajorVersion == 4)
{
dllPath = L"Security.dll";
}
else if (VerInfo.dwPlatformId == VER_PLATFORM_WIN32_WINDOWS ||
VerInfo.dwPlatformId == VER_PLATFORM_WIN32_NT )
{
dllPath = L"Secur32.dll";
}
else
{
throw Poco::SystemException("Cannot determine which security DLL to use");
}
#endif
//
// Load Security DLL
//
_hSecurityModule = LoadLibraryW(dllPath.c_str());
if(_hSecurityModule == 0)
{
throw Poco::SystemException("Failed to load security DLL");
}
#if defined(_WIN32_WCE)
INIT_SECURITY_INTERFACE pInitSecurityInterface = (INIT_SECURITY_INTERFACE)GetProcAddressW( _hSecurityModule, L"InitSecurityInterfaceW");
#else
INIT_SECURITY_INTERFACE pInitSecurityInterface = (INIT_SECURITY_INTERFACE)GetProcAddress( _hSecurityModule, "InitSecurityInterfaceW");
#endif
if (!pInitSecurityInterface)
{
FreeLibrary(_hSecurityModule);
_hSecurityModule = 0;
throw Poco::SystemException("Failed to initialize security DLL (no init function)");
}
PSecurityFunctionTable pSecurityFunc = pInitSecurityInterface();
if (!pSecurityFunc)
{
FreeLibrary(_hSecurityModule);
_hSecurityModule = 0;
throw Poco::SystemException("Failed to initialize security DLL (no function table)");
}
CopyMemory(&_securityFunctions, pSecurityFunc, sizeof(_securityFunctions));
}
void SSLManager::unloadSecurityLibrary()
{
if (_hSecurityModule)
{
FreeLibrary(_hSecurityModule);
_hSecurityModule = 0;
}
} }

View File

@@ -84,8 +84,7 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr<SocketImpl> pSocketImpl, Contex
_hContext(), _hContext(),
_clientFlags(0), _clientFlags(0),
_serverFlags(0), _serverFlags(0),
_hSecurityModule(0), _securityFunctions(SSLManager::instance().securityFunctions()),
_securityFunctions(),
_pReceiveBuffer(0), _pReceiveBuffer(0),
_receiveBufferSize(0), _receiveBufferSize(0),
_pIOBuffer(0), _pIOBuffer(0),
@@ -122,8 +121,6 @@ SecureSocketImpl::~SecureSocketImpl()
void SecureSocketImpl::initCommon() void SecureSocketImpl::initCommon()
{ {
loadSecurityLibrary();
DWORD commonFlags = ISC_REQ_SEQUENCE_DETECT DWORD commonFlags = ISC_REQ_SEQUENCE_DETECT
| ISC_REQ_REPLAY_DETECT | ISC_REQ_REPLAY_DETECT
| ISC_REQ_CONFIDENTIALITY | ISC_REQ_CONFIDENTIALITY
@@ -188,11 +185,9 @@ void SecureSocketImpl::cleanup()
_hCertificateStore = 0; _hCertificateStore = 0;
} }
if (_hSecurityModule) // must release buffers before unloading library
{ _outSecBuffer.release();
FreeLibrary(_hSecurityModule); _inSecBuffer.release();
_hSecurityModule = 0;
}
delete [] _pReceiveBuffer; delete [] _pReceiveBuffer;
_pReceiveBuffer = 0; _pReceiveBuffer = 0;
@@ -664,6 +659,9 @@ void SecureSocketImpl::setPeerHostName(const std::string& peerHostName)
PCCERT_CONTEXT SecureSocketImpl::loadCertificate(const std::string& certStore, const std::string& certName, bool useMachineStore, bool mustFindCertificate) PCCERT_CONTEXT SecureSocketImpl::loadCertificate(const std::string& certStore, const std::string& certName, bool useMachineStore, bool mustFindCertificate)
{ {
if (mustFindCertificate && certName.empty())
throw SSLException("Certificate required, but no certificate name provided");
PCCERT_CONTEXT pCert = 0; PCCERT_CONTEXT pCert = 0;
std::wstring wcertStore; std::wstring wcertStore;
@@ -768,84 +766,11 @@ void SecureSocketImpl::acquireSchannelContext(Mode mode, PCCERT_CONTEXT pCertCon
if (status != SEC_E_OK) if (status != SEC_E_OK)
{ {
cleanup();
throw SSLException("Failed to acquire Schannel credentials", Utility::formatError(status)); throw SSLException("Failed to acquire Schannel credentials", Utility::formatError(status));
} }
} }
bool SecureSocketImpl::loadSecurityLibrary()
{
if (_hSecurityModule) return true;
OSVERSIONINFO VerInfo;
std::wstring dllPath;
// Find out which security DLL to use, depending on
// whether we are on Win2k, NT or Win9x
VerInfo.dwOSVersionInfoSize = sizeof(OSVERSIONINFO);
if (!GetVersionEx(&VerInfo))
{
return false;
}
#if defined(_WIN32_WCE)
dllPath = L"Secur32.dll";
#else
if (VerInfo.dwPlatformId == VER_PLATFORM_WIN32_NT
&& VerInfo.dwMajorVersion == 4)
{
dllPath = L"Security.dll";
}
else if (VerInfo.dwPlatformId == VER_PLATFORM_WIN32_WINDOWS ||
VerInfo.dwPlatformId == VER_PLATFORM_WIN32_NT )
{
dllPath = L"Secur32.dll";
}
else
{
return false;
}
#endif
//
// Load Security DLL
//
_hSecurityModule = LoadLibraryW(dllPath.c_str());
if(_hSecurityModule == 0)
{
return false;
}
#if defined(_WIN32_WCE)
INIT_SECURITY_INTERFACE pInitSecurityInterface = (INIT_SECURITY_INTERFACE)GetProcAddressW( _hSecurityModule, L"InitSecurityInterfaceW");
#else
INIT_SECURITY_INTERFACE pInitSecurityInterface = (INIT_SECURITY_INTERFACE)GetProcAddress( _hSecurityModule, "InitSecurityInterfaceW");
#endif
if (!pInitSecurityInterface)
{
FreeLibrary(_hSecurityModule);
_hSecurityModule = 0;
return false;
}
PSecurityFunctionTable pSecurityFunc = pInitSecurityInterface();
if (!pSecurityFunc)
{
FreeLibrary(_hSecurityModule);
_hSecurityModule = 0;
return false;
}
CopyMemory(&_securityFunctions, pSecurityFunc, sizeof(_securityFunctions));
return true;
}
void SecureSocketImpl::connectSSL(bool completeHandshake) void SecureSocketImpl::connectSSL(bool completeHandshake)
{ {
poco_assert_dbg(_pPeerCertificate == 0); poco_assert_dbg(_pPeerCertificate == 0);
@@ -1012,7 +937,6 @@ SECURITY_STATUS SecureSocketImpl::performClientHandshakeLoop()
} }
void SecureSocketImpl::performClientHandshakeLoopExtError() void SecureSocketImpl::performClientHandshakeLoopExtError()
{ {
poco_assert_dbg (FAILED(_securityStatus)); poco_assert_dbg (FAILED(_securityStatus));
@@ -1030,6 +954,7 @@ void SecureSocketImpl::performClientHandshakeLoopError()
throw SSLException("Error during handshake", Utility::formatError(_securityStatus)); throw SSLException("Error during handshake", Utility::formatError(_securityStatus));
} }
void SecureSocketImpl::performClientHandshakeSendOutBuffer() void SecureSocketImpl::performClientHandshakeSendOutBuffer()
{ {
if (_outSecBuffer[0].cbBuffer != 0 && _outSecBuffer[0].pvBuffer != NULL) if (_outSecBuffer[0].cbBuffer != 0 && _outSecBuffer[0].pvBuffer != NULL)
@@ -1038,7 +963,6 @@ void SecureSocketImpl::performClientHandshakeSendOutBuffer()
if (numBytes != _outSecBuffer[0].cbBuffer) if (numBytes != _outSecBuffer[0].cbBuffer)
{ {
cleanup();
throw SSLException("Socket error during handshake"); throw SSLException("Socket error during handshake");
} }
_bytesReadSum = 0; _bytesReadSum = 0;
@@ -1349,7 +1273,6 @@ void SecureSocketImpl::clientVerifyCertificate(PCCERT_CONTEXT pServerCert, const
} }
void SecureSocketImpl::verifyCertificateChainClient(PCCERT_CONTEXT pServerCert, PCCERT_CHAIN_CONTEXT pChainContext) void SecureSocketImpl::verifyCertificateChainClient(PCCERT_CONTEXT pServerCert, PCCERT_CHAIN_CONTEXT pChainContext)
{ {
X509Certificate cert(pServerCert, true); X509Certificate cert(pServerCert, true);

View File

@@ -19,8 +19,7 @@
CppUnit::Test* NetSSLTestSuite::suite() CppUnit::Test* NetSSLTestSuite::suite()
{ {
CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("OpenSSLTestSuite"); CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("NetSSLTestSuite");
pSuite->addTest(HTTPSClientTestSuite::suite()); pSuite->addTest(HTTPSClientTestSuite::suite());
pSuite->addTest(TCPServerTestSuite::suite()); pSuite->addTest(TCPServerTestSuite::suite());

View File

@@ -18,6 +18,7 @@
#include "Poco/Net/HTTPStreamFactory.h" #include "Poco/Net/HTTPStreamFactory.h"
#include "Poco/Net/HTTPSStreamFactory.h" #include "Poco/Net/HTTPSStreamFactory.h"
#include <cstdlib> #include <cstdlib>
#include <iostream>
class NetSSLApp: public Poco::Util::Application class NetSSLApp: public Poco::Util::Application