diff --git a/Net/src/HTTPNTLMCredentials.cpp b/Net/src/HTTPNTLMCredentials.cpp index e3e611236..b28d9c74c 100644 --- a/Net/src/HTTPNTLMCredentials.cpp +++ b/Net/src/HTTPNTLMCredentials.cpp @@ -17,6 +17,7 @@ #include "Poco/Net/HTTPAuthenticationParams.h" #include "Poco/Net/HTTPRequest.h" #include "Poco/Net/HTTPResponse.h" +#include "Poco/Net/NetException.h" #include "Poco/DateTime.h" #include "Poco/NumberFormatter.h" #include "Poco/Exception.h" @@ -115,25 +116,19 @@ std::string HTTPNTLMCredentials::createNTLMMessage(const std::string& responseAu std::string username; splitUsername(_username, username, negotiateMsg.domain); std::vector negotiateBuf = NTLMCredentials::formatNegotiateMessage(negotiateMsg); - - std::ostringstream ostr; - Poco::Base64Encoder base64(ostr); - base64.rdbuf()->setLineLength(0); - base64.write(reinterpret_cast(&negotiateBuf[0]), negotiateBuf.size()); - base64.close(); - return ostr.str(); + return toBase64(negotiateBuf); } else { - Poco::MemoryInputStream istr(responseAuthParams.data(), responseAuthParams.size()); - Poco::Base64Decoder debase64(istr); - std::vector buffer(responseAuthParams.size()); - debase64.read(reinterpret_cast(&buffer[0]), buffer.size()); - std::size_t size = static_cast(debase64.gcount()); - - Poco::Net::NTLMCredentials::ChallengeMessage challengeMsg; - if (NTLMCredentials::parseChallengeMessage(&buffer[0], size, challengeMsg)) + std::vector buffer = fromBase64(responseAuthParams); + NTLMCredentials::ChallengeMessage challengeMsg; + if (NTLMCredentials::parseChallengeMessage(&buffer[0], buffer.size(), challengeMsg)) { + if ((challengeMsg.flags & NTLMCredentials::NTLM_FLAG_NEGOTIATE_NTLM2_KEY) == 0) + { + throw HTTPException("Proxy does not support NTLMv2 authentication"); + } + std::string username; std::string domain; splitUsername(_username, username, domain); @@ -143,22 +138,18 @@ std::string HTTPNTLMCredentials::createNTLMMessage(const std::string& responseAu authenticateMsg.target = challengeMsg.target; authenticateMsg.username = username; - std::vector nonce = NTLMCredentials::createNonce(); + std::vector lmNonce = NTLMCredentials::createNonce(); + std::vector ntlmNonce = NTLMCredentials::createNonce(); Poco::UInt64 timestamp = NTLMCredentials::createTimestamp(); std::vector ntlm2Hash = NTLMCredentials::createNTLMv2Hash(username, challengeMsg.target, _password); - authenticateMsg.lmResponse = NTLMCredentials::createLMv2Response(ntlm2Hash, challengeMsg.challenge, nonce); - authenticateMsg.ntlmResponse = Poco::Net::NTLMCredentials::createNTLMv2Response(ntlm2Hash, challengeMsg.challenge, nonce, challengeMsg.targetInfo, timestamp); - std::vector authenticateBuf = Poco::Net::NTLMCredentials::formatAuthenticateMessage(authenticateMsg); + authenticateMsg.lmResponse = NTLMCredentials::createLMv2Response(ntlm2Hash, challengeMsg.challenge, lmNonce); + authenticateMsg.ntlmResponse = NTLMCredentials::createNTLMv2Response(ntlm2Hash, challengeMsg.challenge, ntlmNonce, challengeMsg.targetInfo, timestamp); - std::ostringstream ostr; - Poco::Base64Encoder base64(ostr); - base64.rdbuf()->setLineLength(0); - base64.write(reinterpret_cast(&authenticateBuf[0]), authenticateBuf.size()); - base64.close(); - return ostr.str(); + std::vector authenticateBuf = NTLMCredentials::formatAuthenticateMessage(authenticateMsg); + return toBase64(authenticateBuf); } - else throw Poco::InvalidArgumentException("Invalid NTLM challenge"); + else throw HTTPException("Invalid NTLM challenge"); } } @@ -186,5 +177,26 @@ void HTTPNTLMCredentials::splitUsername(const std::string& usernameAndDomain, st } +std::string HTTPNTLMCredentials::toBase64(const std::vector& buffer) +{ + std::ostringstream ostr; + Poco::Base64Encoder base64(ostr); + base64.rdbuf()->setLineLength(0); + base64.write(reinterpret_cast(&buffer[0]), buffer.size()); + base64.close(); + return ostr.str(); +} + + +std::vector HTTPNTLMCredentials::fromBase64(const std::string& base64) +{ + Poco::MemoryInputStream istr(base64.data(), base64.size()); + Poco::Base64Decoder debase64(istr); + std::vector buffer(base64.size()); + debase64.read(reinterpret_cast(&buffer[0]), buffer.size()); + buffer.resize(static_cast(debase64.gcount())); + return buffer; +} + } } // namespace Poco::Net