// // Context.cpp // // Library: NetSSL_Win // Package: SSLCore // Module: Context // // Copyright (c) 2006-2014, Applied Informatics Software Engineering GmbH. // and Contributors. // // SPDX-License-Identifier: BSL-1.0 // #include "Poco/Net/Context.h" #include "Poco/Net/SSLManager.h" #include "Poco/Net/SSLException.h" #include "Poco/Net/Utility.h" #include "Poco/UnicodeConverter.h" #include "Poco/Format.h" #include "Poco/File.h" #include "Poco/FileStream.h" #include "Poco/MemoryStream.h" #include "Poco/Base64Decoder.h" #include "Poco/Buffer.h" #include #include #include #include namespace Poco { namespace Net { const std::string Context::CERT_STORE_MY("MY"); const std::string Context::CERT_STORE_ROOT("ROOT"); const std::string Context::CERT_STORE_TRUST("TRUST"); const std::string Context::CERT_STORE_CA("CA"); const std::string Context::CERT_STORE_USERDS("USERDS"); Context::Context(Usage usage, const std::string& certificateInfoOrPath, VerificationMode verMode, int options, const std::string& certStore): _usage(usage), _mode(verMode), _options(options), _disabledProtocols(0), _extendedCertificateVerification(true), _certInfoOrPath(certificateInfoOrPath), _certStoreName(certStore), _hMemCertStore(nullptr), _hCollectionCertStore(nullptr), _hRootCertStore(nullptr), _hCertStore(nullptr), _pCert(nullptr), _securityFunctions(SSLManager::instance().securityFunctions()) { init(); } Context::~Context() { if (_pCert) { CertFreeCertificateContext(_pCert); } if (_hCertStore) { CertCloseStore(_hCertStore, 0); } CertCloseStore(_hCollectionCertStore, 0); CertCloseStore(_hMemCertStore, 0); if (_hRootCertStore) { CertCloseStore(_hRootCertStore, 0); } if (_hCreds.dwLower != 0 && _hCreds.dwUpper != 0) { _securityFunctions.FreeCredentialsHandle(&_hCreds); } } void Context::init() { _hCreds.dwLower = 0; _hCreds.dwUpper = 0; _hMemCertStore = CertOpenStore( CERT_STORE_PROV_MEMORY, // The memory provider type 0, // The encoding type is not needed 0, // Use the default provider 0, // Accept the default dwFlags nullptr); // pvPara is not used if (!_hMemCertStore) throw SSLException("Failed to create memory certificate store", GetLastError()); _hCollectionCertStore = CertOpenStore( CERT_STORE_PROV_COLLECTION, // A collection store 0, // Encoding type; not used with a collection store 0, // Use the default provider 0, // No flags nullptr); // Not needed if (!_hCollectionCertStore) throw SSLException("Failed to create collection store", GetLastError()); if (!CertAddStoreToCollection(_hCollectionCertStore, _hMemCertStore, CERT_PHYSICAL_STORE_ADD_ENABLE_FLAG, 1)) throw SSLException("Failed to add memory certificate store to collection store", GetLastError()); if (_options & OPT_TRUST_ROOTS_WIN_CERT_STORE) { // add root certificates std::wstring rootStore; Poco::UnicodeConverter::convert(CERT_STORE_ROOT, rootStore); _hRootCertStore = CertOpenSystemStoreW(0, rootStore.c_str()); if (!_hRootCertStore) throw SSLException("Failed to open root certificate store", GetLastError()); if (!CertAddStoreToCollection(_hCollectionCertStore, _hRootCertStore, CERT_PHYSICAL_STORE_ADD_ENABLE_FLAG, 1)) throw SSLException("Failed to add root certificate store to collection store", GetLastError()); } } void Context::enableExtendedCertificateVerification(bool flag) { _extendedCertificateVerification = flag; } void Context::addTrustedCert(const Poco::Net::X509Certificate& cert) { Poco::FastMutex::ScopedLock lock(_mutex); if (!CertAddCertificateContextToStore(_hMemCertStore, cert.system(), CERT_STORE_ADD_REPLACE_EXISTING, nullptr)) throw CertificateException("Failed to add certificate to store", GetLastError()); } Poco::Net::X509Certificate Context::certificate() { if (_pCert) return Poco::Net::X509Certificate(_pCert, true); if (_certInfoOrPath.empty()) throw NoCertificateException("Certificate requested, but no certificate name or path provided"); if (_options & OPT_LOAD_CERT_FROM_FILE) { importCertificate(); } else { loadCertificate(); } return Poco::Net::X509Certificate(_pCert, true); } void Context::loadCertificate() { std::wstring wcertStoreName; Poco::UnicodeConverter::convert(_certStoreName, wcertStoreName); if (!_hCertStore) { if (_options & OPT_USE_MACHINE_STORE) _hCertStore = CertOpenStore( CERT_STORE_PROV_SYSTEM, 0, 0, CERT_SYSTEM_STORE_LOCAL_MACHINE | CERT_STORE_OPEN_EXISTING_FLAG, wcertStoreName.c_str() ); else _hCertStore = CertOpenSystemStoreW(0, wcertStoreName.c_str()); } if (!_hCertStore) throw CertificateException("Failed to open certificate store", _certStoreName, GetLastError()); // Find the certificate either using name or hash. if(_options & OPT_USE_CERT_HASH) { // Sanity check for the hash value. if(_certInfoOrPath.size() < 40 || _certInfoOrPath.size() % 2) throw CertificateException(Poco::format("Invalid certificate hash %s", _certInfoOrPath)); // Convert hex to binary. BYTE buffer[256] = {}; int bufferSize = 0; int byte = 0; std::string szHex; for(int counter = 0; counter < _certInfoOrPath.size() / 2; counter++) { szHex = _certInfoOrPath.substr(2 * counter, 2); if( sscanf(szHex.c_str(), "%02x", OUT &byte ) != 1) throw CertificateException(Poco::format("Invalid certificate hash %s", _certInfoOrPath)); buffer[counter] = (BYTE) byte; bufferSize += 1; } CRYPT_HASH_BLOB chBlob; chBlob.cbData = bufferSize; chBlob.pbData = buffer; _pCert = CertFindCertificateInStore(_hCertStore, PKCS_7_ASN_ENCODING | X509_ASN_ENCODING, 0, CERT_FIND_HASH, &chBlob, nullptr); if (!_pCert) throw NoCertificateException(Poco::format("Failed to find certificate %s in store %s", _certInfoOrPath, _certStoreName)); } else { CERT_RDN_ATTR cert_rdn_attr; char cmnName[] = szOID_COMMON_NAME; cert_rdn_attr.pszObjId = cmnName; cert_rdn_attr.dwValueType = CERT_RDN_ANY_TYPE; cert_rdn_attr.Value.cbData = (DWORD) _certInfoOrPath.size(); cert_rdn_attr.Value.pbData = (BYTE *) _certInfoOrPath.c_str(); CERT_RDN cert_rdn; cert_rdn.cRDNAttr = 1; cert_rdn.rgRDNAttr = &cert_rdn_attr; _pCert = CertFindCertificateInStore(_hCertStore, X509_ASN_ENCODING, 0, CERT_FIND_SUBJECT_ATTR, &cert_rdn, nullptr); if (!_pCert) throw NoCertificateException(Poco::format("Failed to find certificate %s in store %s", _certInfoOrPath, _certStoreName)); } } void Context::importCertificate() { Poco::File certFile(_certInfoOrPath); if (!certFile.exists()) throw Poco::FileNotFoundException(_certInfoOrPath); Poco::File::FileSize size = certFile.getSize(); if (size > 4096) throw Poco::DataFormatException("PKCS #12 certificate file too large", _certInfoOrPath); Poco::Buffer buffer(static_cast(size)); Poco::FileInputStream istr(_certInfoOrPath); istr.read(buffer.begin(), buffer.size()); if (istr.gcount() != size) throw Poco::IOException("error reading PKCS #12 certificate file"); importCertificate(buffer.begin(), buffer.size()); } void Context::importCertificate(const char* pBuffer, std::size_t size) { std::string password; SSLManager::instance().PrivateKeyPassphraseRequired.notify(&SSLManager::instance(), password); std::wstring wpassword; Poco::UnicodeConverter::toUTF16(password, wpassword); // clear UTF-8 password std::fill(const_cast(password.data()), const_cast(password.data() + password.size()), 'X'); CRYPT_DATA_BLOB blob; blob.cbData = static_cast(size); blob.pbData = reinterpret_cast(const_cast(pBuffer)); HCERTSTORE hTempStore = PFXImportCertStore(&blob, wpassword.data(), PKCS12_ALLOW_OVERWRITE_KEY | PKCS12_INCLUDE_EXTENDED_PROPERTIES); // clear UTF-16 password std::fill(const_cast(wpassword.data()), const_cast(wpassword.data() + password.size()), L'X'); if (hTempStore) { PCCERT_CONTEXT pCert = nullptr; pCert = CertEnumCertificatesInStore(hTempStore, pCert); while (pCert) { PCCERT_CONTEXT pStoreCert = nullptr; BOOL res = CertAddCertificateContextToStore(_hMemCertStore, pCert, CERT_STORE_ADD_REPLACE_EXISTING, &pStoreCert); if (res) { if (!_pCert) { _pCert = pStoreCert; } else { CertFreeCertificateContext(pStoreCert); pStoreCert = nullptr; } } pCert = CertEnumCertificatesInStore(hTempStore, pCert); } CertCloseStore(hTempStore, 0); } else throw CertificateException("failed to import certificate", GetLastError()); } CredHandle& Context::credentials() { Poco::FastMutex::ScopedLock lock(_mutex); if (_hCreds.dwLower == 0 && _hCreds.dwUpper == 0) { acquireSchannelCredentials(_hCreds); } return _hCreds; } void Context::acquireSchannelCredentials(CredHandle& credHandle) const { SCHANNEL_CRED schannelCred; ZeroMemory(&schannelCred, sizeof(schannelCred)); schannelCred.dwVersion = SCHANNEL_CRED_VERSION; if (_pCert) { schannelCred.cCreds = 1; // how many cred are stored in &pCertContext schannelCred.paCred = const_cast(&_pCert); } schannelCred.grbitEnabledProtocols = proto(); // Windows NT and Windows Me/98/95: revocation checking not supported via flags if (_options & Context::OPT_PERFORM_REVOCATION_CHECK) schannelCred.dwFlags |= SCH_CRED_REVOCATION_CHECK_CHAIN; else schannelCred.dwFlags |= SCH_CRED_IGNORE_NO_REVOCATION_CHECK | SCH_CRED_IGNORE_REVOCATION_OFFLINE; if (isForServerUse()) { if (_mode >= Context::VERIFY_STRICT) schannelCred.dwFlags |= SCH_CRED_NO_SYSTEM_MAPPER; if (_mode == Context::VERIFY_NONE) schannelCred.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION; } else { if (_mode >= Context::VERIFY_STRICT) schannelCred.dwFlags |= SCH_CRED_NO_DEFAULT_CREDS; else schannelCred.dwFlags |= SCH_CRED_USE_DEFAULT_CREDS; if (_mode == Context::VERIFY_NONE) schannelCred.dwFlags |= SCH_CRED_MANUAL_CRED_VALIDATION | SCH_CRED_NO_SERVERNAME_CHECK; if (!_extendedCertificateVerification) schannelCred.dwFlags |= SCH_CRED_NO_SERVERNAME_CHECK; } #if defined(SCH_USE_STRONG_CRYPTO) if (_options & Context::OPT_USE_STRONG_CRYPTO) schannelCred.dwFlags |= SCH_USE_STRONG_CRYPTO; #endif schannelCred.hRootStore = schannelCred.hRootStore = isForServerUse() ? _hCollectionCertStore : nullptr; TimeStamp tsExpiry; tsExpiry.LowPart = tsExpiry.HighPart = 0; ::SEC_WCHAR name[] = UNISP_NAME_W; SECURITY_STATUS status = _securityFunctions.AcquireCredentialsHandleW( nullptr, name, isForServerUse() ? SECPKG_CRED_INBOUND : SECPKG_CRED_OUTBOUND, nullptr, &schannelCred, nullptr, nullptr, &credHandle, &tsExpiry); if (status != SEC_E_OK) { throw SSLException("Failed to acquire Schannel credentials", Utility::formatError(status)); } } DWORD Context::proto() const { DWORD result = 0; switch (_usage) { case Context::TLS_CLIENT_USE: case Context::CLIENT_USE: result = SP_PROT_SSL3_CLIENT | SP_PROT_TLS1_CLIENT #if defined(SP_PROT_TLS1_1) | SP_PROT_TLS1_1_CLIENT #endif #if defined(SP_PROT_TLS1_2) | SP_PROT_TLS1_2_CLIENT #endif #if defined(SP_PROT_TLS1_3) | SP_PROT_TLS1_3_CLIENT #endif ; break; case Context::TLS_SERVER_USE: case Context::SERVER_USE: result = SP_PROT_SSL3_SERVER | SP_PROT_TLS1_SERVER #if defined(SP_PROT_TLS1_1) | SP_PROT_TLS1_1_SERVER #endif #if defined(SP_PROT_TLS1_2) | SP_PROT_TLS1_2_SERVER #endif #if defined(SP_PROT_TLS1_3) | SP_PROT_TLS1_3_SERVER #endif ; break; case Context::TLSV1_CLIENT_USE: result = SP_PROT_TLS1_CLIENT #if defined(SP_PROT_TLS1_1) | SP_PROT_TLS1_1_CLIENT #endif #if defined(SP_PROT_TLS1_2) | SP_PROT_TLS1_2_CLIENT #endif #if defined(SP_PROT_TLS1_3) | SP_PROT_TLS1_3_CLIENT #endif ; break; case Context::TLSV1_SERVER_USE: result = SP_PROT_TLS1_SERVER #if defined(SP_PROT_TLS1_1) | SP_PROT_TLS1_1_SERVER #endif #if defined(SP_PROT_TLS1_2) | SP_PROT_TLS1_2_SERVER #endif #if defined(SP_PROT_TLS1_3) | SP_PROT_TLS1_3_SERVER #endif ; break; #if defined(SP_PROT_TLS1_1) case Context::TLSV1_1_CLIENT_USE: result = SP_PROT_TLS1_1_CLIENT #if defined(SP_PROT_TLS1_2) | SP_PROT_TLS1_2_CLIENT #endif #if defined(SP_PROT_TLS1_3) | SP_PROT_TLS1_3_CLIENT #endif ; break; case Context::TLSV1_1_SERVER_USE: result = SP_PROT_TLS1_1_SERVER #if defined(SP_PROT_TLS1_2) | SP_PROT_TLS1_2_SERVER #endif #if defined(SP_PROT_TLS1_3) | SP_PROT_TLS1_3_SERVER #endif ; break; #endif #if defined(SP_PROT_TLS1_2) case Context::TLSV1_2_CLIENT_USE: result = SP_PROT_TLS1_2_CLIENT #if defined(SP_PROT_TLS1_3) | SP_PROT_TLS1_3_CLIENT #endif ; break; case Context::TLSV1_2_SERVER_USE: result = SP_PROT_TLS1_2_SERVER #if defined(SP_PROT_TLS1_3) | SP_PROT_TLS1_3_SERVER #endif ; break; #endif #if defined(SP_PROT_TLS1_3) case Context::TLSV1_3_CLIENT_USE: result = SP_PROT_TLS1_3_CLIENT; break; case Context::TLSV1_3_SERVER_USE: result = SP_PROT_TLS1_3_SERVER; break; #endif default: throw Poco::InvalidArgumentException("Unsupported SSL/TLS protocol version"); } return result & enabledProtocols(); } DWORD Context::enabledProtocols() const { DWORD result = 0; if (!(_disabledProtocols & PROTO_SSLV3)) { result |= SP_PROT_SSL3_CLIENT | SP_PROT_SSL3_SERVER; } if (!(_disabledProtocols & PROTO_TLSV1)) { result |= SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_SERVER; } if (!(_disabledProtocols & PROTO_TLSV1_1)) { #ifdef SP_PROT_TLS1_1_CLIENT result |= SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_1_SERVER; #endif } if (!(_disabledProtocols & PROTO_TLSV1_2)) { #ifdef SP_PROT_TLS1_2_CLIENT result |= SP_PROT_TLS1_2_CLIENT | SP_PROT_TLS1_2_SERVER; #endif } if (!(_disabledProtocols & PROTO_TLSV1_3)) { #ifdef SP_PROT_TLS1_3_CLIENT result |= SP_PROT_TLS1_3_CLIENT | SP_PROT_TLS1_3_SERVER; #endif } return result; } void Context::disableProtocols(int protocols) { _disabledProtocols = protocols; } void Context::requireMinimumProtocol(Protocols protocol) { _disabledProtocols = 0; switch (protocol) { case PROTO_SSLV3: _disabledProtocols = PROTO_SSLV2; break; case PROTO_TLSV1: _disabledProtocols = PROTO_SSLV2 | PROTO_SSLV3; break; case PROTO_TLSV1_1: _disabledProtocols = PROTO_SSLV2 | PROTO_SSLV3 | PROTO_TLSV1; break; case PROTO_TLSV1_2: _disabledProtocols = PROTO_SSLV2 | PROTO_SSLV3 | PROTO_TLSV1 | PROTO_TLSV1_1; break; case PROTO_TLSV1_3: _disabledProtocols = PROTO_SSLV2 | PROTO_SSLV3 | PROTO_TLSV1 | PROTO_TLSV1_1 | PROTO_TLSV1_2; break; } } } } // namespace Poco::Net