diff --git a/Net/include/Poco/Net/HTTPClientSession.h b/Net/include/Poco/Net/HTTPClientSession.h index e7514c0ea..a1fb388f3 100644 --- a/Net/include/Poco/Net/HTTPClientSession.h +++ b/Net/include/Poco/Net/HTTPClientSession.h @@ -227,6 +227,21 @@ public: /// to ensure a new connection will be set up /// for the next request. + virtual bool canContinue(HTTPResponse& response); + /// If the request contains a "Expect: 100-continue" header, + /// (see HTTPRequest::setExpectContinue()) the method can be + /// used to check whether the server has + /// sent a 100 Continue response before continuing with the + /// request, i.e. sending the request body. + /// + /// Returns true if the server has responded with 100 Continue, + /// otherwise false. The HTTPResponse object contains the + /// response sent by the server. + /// + /// In any case, receiveResponse() must be called as well in + /// order to complete the request. The same HTTPResponse object + /// passed to canContinue() must also be passed to receiveResponse(). + void reset(); /// Resets the session and closes the socket. /// @@ -291,6 +306,7 @@ private: bool _reconnect; bool _mustReconnect; bool _expectResponseBody; + bool _responseReceived; Poco::SharedPtr _pRequestStream; Poco::SharedPtr _pResponseStream; diff --git a/Net/include/Poco/Net/HTTPRequest.h b/Net/include/Poco/Net/HTTPRequest.h index 1cf865507..d21507149 100644 --- a/Net/include/Poco/Net/HTTPRequest.h +++ b/Net/include/Poco/Net/HTTPRequest.h @@ -104,7 +104,15 @@ public: void setCredentials(const std::string& scheme, const std::string& authInfo); /// Sets the authentication scheme and information for /// this request. + + bool getExpectContinue() const; + /// Returns true if the request contains an + /// "Expect: 100-continue" header. + void setExpectContinue(bool expectContinue); + /// Adds a "Expect: 100-continue" header to the request if + /// expectContinue is true, otherwise removes the Expect header. + bool hasProxyCredentials() const; /// Returns true iff the request contains proxy authentication /// information in the form of an Proxy-Authorization header. @@ -143,6 +151,7 @@ public: static const std::string AUTHORIZATION; static const std::string PROXY_AUTHORIZATION; static const std::string UPGRADE; + static const std::string EXPECT; protected: void getCredentials(const std::string& header, std::string& scheme, std::string& authInfo) const; diff --git a/Net/include/Poco/Net/HTTPServerRequest.h b/Net/include/Poco/Net/HTTPServerRequest.h index 44a7e95e9..4ffe4ac1d 100644 --- a/Net/include/Poco/Net/HTTPServerRequest.h +++ b/Net/include/Poco/Net/HTTPServerRequest.h @@ -56,10 +56,6 @@ public: /// The stream must be valid until the HTTPServerRequest /// object is destroyed. - virtual bool expectContinue() const = 0; - /// Returns true if the client expects a - /// 100 Continue response. - virtual const SocketAddress& clientAddress() const = 0; /// Returns the client's address. diff --git a/Net/include/Poco/Net/HTTPServerRequestImpl.h b/Net/include/Poco/Net/HTTPServerRequestImpl.h index d200b6b59..0abf0b53e 100644 --- a/Net/include/Poco/Net/HTTPServerRequestImpl.h +++ b/Net/include/Poco/Net/HTTPServerRequestImpl.h @@ -59,10 +59,6 @@ public: /// The stream is valid until the HTTPServerRequestImpl /// object is destroyed. - bool expectContinue() const; - /// Returns true if the client expects a - /// 100 Continue response. - const SocketAddress& clientAddress() const; /// Returns the client's address. @@ -87,9 +83,6 @@ public: StreamSocket detachSocket(); /// Returns the underlying socket after detaching /// it from the server session. - -protected: - static const std::string EXPECT; private: HTTPServerResponseImpl& _response; diff --git a/Net/src/HTTPClientSession.cpp b/Net/src/HTTPClientSession.cpp index a68ec0528..44c744de9 100644 --- a/Net/src/HTTPClientSession.cpp +++ b/Net/src/HTTPClientSession.cpp @@ -46,7 +46,8 @@ HTTPClientSession::HTTPClientSession(): _keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0), _reconnect(false), _mustReconnect(false), - _expectResponseBody(false) + _expectResponseBody(false), + _responseReceived(false) { } @@ -58,7 +59,8 @@ HTTPClientSession::HTTPClientSession(const StreamSocket& socket): _keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0), _reconnect(false), _mustReconnect(false), - _expectResponseBody(false) + _expectResponseBody(false), + _responseReceived(false) { } @@ -70,7 +72,8 @@ HTTPClientSession::HTTPClientSession(const SocketAddress& address): _keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0), _reconnect(false), _mustReconnect(false), - _expectResponseBody(false) + _expectResponseBody(false), + _responseReceived(false) { } @@ -82,7 +85,8 @@ HTTPClientSession::HTTPClientSession(const std::string& host, Poco::UInt16 port) _keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0), _reconnect(false), _mustReconnect(false), - _expectResponseBody(false) + _expectResponseBody(false), + _responseReceived(false) { } @@ -94,7 +98,8 @@ HTTPClientSession::HTTPClientSession(const std::string& host, Poco::UInt16 port, _keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0), _reconnect(false), _mustReconnect(false), - _expectResponseBody(false) + _expectResponseBody(false), + _responseReceived(false) { } @@ -192,6 +197,7 @@ std::ostream& HTTPClientSession::sendRequest(HTTPRequest& request) { clearException(); _pResponseStream = 0; + _responseReceived = false; bool keepAlive = getKeepAlive(); if (((connected() && !keepAlive) || mustReconnect()) && !_host.empty()) @@ -259,25 +265,28 @@ std::istream& HTTPClientSession::receiveResponse(HTTPResponse& response) _pRequestStream = 0; if (networkException()) networkException()->rethrow(); - do + if (!_responseReceived) { - response.clear(); - HTTPHeaderInputStream his(*this); - try + do { - response.read(his); - } - catch (Exception&) - { - close(); - if (networkException()) - networkException()->rethrow(); - else + response.clear(); + HTTPHeaderInputStream his(*this); + try + { + response.read(his); + } + catch (Exception&) + { + close(); + if (networkException()) + networkException()->rethrow(); + else + throw; throw; - throw; + } } + while (response.getStatus() == HTTPResponse::HTTP_CONTINUE); } - while (response.getStatus() == HTTPResponse::HTTP_CONTINUE); _mustReconnect = getKeepAlive() && !response.getKeepAlive(); @@ -298,6 +307,34 @@ std::istream& HTTPClientSession::receiveResponse(HTTPResponse& response) } +bool HTTPClientSession::canContinue(HTTPResponse& response) +{ + poco_assert (!_responseReceived); + + _pRequestStream->flush(); + + if (networkException()) networkException()->rethrow(); + + response.clear(); + HTTPHeaderInputStream his(*this); + try + { + response.read(his); + } + catch (Exception&) + { + close(); + if (networkException()) + networkException()->rethrow(); + else + throw; + throw; + } + _responseReceived = response.getStatus() != HTTPResponse::HTTP_CONTINUE; + return !_responseReceived; +} + + void HTTPClientSession::reset() { close(); diff --git a/Net/src/HTTPRequest.cpp b/Net/src/HTTPRequest.cpp index 61d7c6651..f1e63ab73 100644 --- a/Net/src/HTTPRequest.cpp +++ b/Net/src/HTTPRequest.cpp @@ -43,6 +43,7 @@ const std::string HTTPRequest::COOKIE = "Cookie"; const std::string HTTPRequest::AUTHORIZATION = "Authorization"; const std::string HTTPRequest::PROXY_AUTHORIZATION = "Proxy-Authorization"; const std::string HTTPRequest::UPGRADE = "Upgrade"; +const std::string HTTPRequest::EXPECT = "Expect"; HTTPRequest::HTTPRequest(): @@ -259,4 +260,20 @@ void HTTPRequest::setCredentials(const std::string& header, const std::string& s } +bool HTTPRequest::getExpectContinue() const +{ + const std::string& expect = get(EXPECT, EMPTY); + return !expect.empty() && icompare(expect, "100-continue") == 0; +} + + +void HTTPRequest::setExpectContinue(bool expectContinue) +{ + if (expectContinue) + set(EXPECT, "100-continue"); + else + erase(EXPECT); +} + + } } // namespace Poco::Net diff --git a/Net/src/HTTPServerConnection.cpp b/Net/src/HTTPServerConnection.cpp index 970e881c5..4166f26c8 100644 --- a/Net/src/HTTPServerConnection.cpp +++ b/Net/src/HTTPServerConnection.cpp @@ -81,7 +81,7 @@ void HTTPServerConnection::run() std::auto_ptr pHandler(_pFactory->createRequestHandler(request)); if (pHandler.get()) { - if (request.expectContinue()) + if (request.getExpectContinue() && response.getStatus() == HTTPResponse::HTTP_OK) response.sendContinue(); pHandler->handleRequest(request, response); diff --git a/Net/src/HTTPServerRequestImpl.cpp b/Net/src/HTTPServerRequestImpl.cpp index 729ab3b72..c363cd7ba 100644 --- a/Net/src/HTTPServerRequestImpl.cpp +++ b/Net/src/HTTPServerRequestImpl.cpp @@ -33,9 +33,6 @@ namespace Poco { namespace Net { -const std::string HTTPServerRequestImpl::EXPECT("Expect"); - - HTTPServerRequestImpl::HTTPServerRequestImpl(HTTPServerResponseImpl& response, HTTPServerSession& session, HTTPServerParams* pParams): _response(response), _session(session), @@ -90,11 +87,4 @@ StreamSocket HTTPServerRequestImpl::detachSocket() } -bool HTTPServerRequestImpl::expectContinue() const -{ - const std::string& expect = get(EXPECT, EMPTY); - return !expect.empty() && icompare(expect, "100-continue") == 0; -} - - } } // namespace Poco::Net diff --git a/Net/testsuite/src/HTTPClientSessionTest.cpp b/Net/testsuite/src/HTTPClientSessionTest.cpp index 484852f39..2e4ecb0dc 100644 --- a/Net/testsuite/src/HTTPClientSessionTest.cpp +++ b/Net/testsuite/src/HTTPClientSessionTest.cpp @@ -300,6 +300,47 @@ void HTTPClientSessionTest::testBypassProxy() } +void HTTPClientSessionTest::testExpectContinue() +{ + HTTPTestServer srv; + HTTPClientSession s("localhost", srv.port()); + HTTPRequest request(HTTPRequest::HTTP_POST, "/expect"); + std::string body("this is a random request body\r\n0\r\n"); + request.setContentLength((int) body.length()); + request.setExpectContinue(true); + s.sendRequest(request) << body; + HTTPResponse response; + assert (s.canContinue(response)); + assert (response.getStatus() == HTTPResponse::HTTP_CONTINUE); + std::istream& rs = s.receiveResponse(response); + assert (response.getStatus() == HTTPResponse::HTTP_OK); + assert (response.getContentLength() == body.length()); + std::ostringstream ostr; + StreamCopier::copyStream(rs, ostr); + assert (ostr.str() == body); +} + + +void HTTPClientSessionTest::testExpectContinueFail() +{ + HTTPTestServer srv; + HTTPClientSession s("localhost", srv.port()); + HTTPRequest request(HTTPRequest::HTTP_POST, "/fail"); + std::string body("this is a random request body\r\n0\r\n"); + request.setContentLength((int) body.length()); + request.setExpectContinue(true); + s.sendRequest(request) << body; + HTTPResponse response; + assert (!s.canContinue(response)); + assert (response.getStatus() == HTTPResponse::HTTP_BAD_REQUEST); + std::istream& rs = s.receiveResponse(response); + assert (response.getStatus() == HTTPResponse::HTTP_BAD_REQUEST); + std::ostringstream ostr; + StreamCopier::copyStream(rs, ostr); + assert (ostr.str().empty()); +} + + void HTTPClientSessionTest::setUp() { } @@ -327,6 +368,8 @@ CppUnit::Test* HTTPClientSessionTest::suite() CppUnit_addTest(pSuite, HTTPClientSessionTest, testProxy); CppUnit_addTest(pSuite, HTTPClientSessionTest, testProxyAuth); CppUnit_addTest(pSuite, HTTPClientSessionTest, testBypassProxy); + CppUnit_addTest(pSuite, HTTPClientSessionTest, testExpectContinue); + CppUnit_addTest(pSuite, HTTPClientSessionTest, testExpectContinueFail); return pSuite; } diff --git a/Net/testsuite/src/HTTPClientSessionTest.h b/Net/testsuite/src/HTTPClientSessionTest.h index 099b23269..09824f7cf 100644 --- a/Net/testsuite/src/HTTPClientSessionTest.h +++ b/Net/testsuite/src/HTTPClientSessionTest.h @@ -39,6 +39,8 @@ public: void testProxy(); void testProxyAuth(); void testBypassProxy(); + void testExpectContinue(); + void testExpectContinueFail(); void setUp(); void tearDown(); diff --git a/Net/testsuite/src/HTTPTestServer.cpp b/Net/testsuite/src/HTTPTestServer.cpp index c7f570727..67f3e0e2b 100644 --- a/Net/testsuite/src/HTTPTestServer.cpp +++ b/Net/testsuite/src/HTTPTestServer.cpp @@ -142,6 +142,37 @@ std::string HTTPTestServer::handleRequest() const if (_lastRequest.substr(0, 3) == "GET") response.append(body); } + else if (_lastRequest.substr(0, 12) == "POST /expect") + { + std::string::size_type pos = _lastRequest.find("\r\n\r\n"); + pos += 4; + std::string body = _lastRequest.substr(pos); + response.append("HTTP/1.1 100 Continue\r\n\r\n"); + response.append("HTTP/1.1 200 OK\r\n"); + response.append("Content-Type: text/plain\r\n"); + if (_lastRequest.find("Content-Length") != std::string::npos) + { + response.append("Content-Length: "); + response.append(NumberFormatter::format((int) body.size())); + response.append("\r\n"); + } + else if (_lastRequest.find("chunked") != std::string::npos) + { + response.append("Transfer-Encoding: chunked\r\n"); + } + response.append("Connection: Close\r\n"); + response.append("\r\n"); + response.append(body); + } + else if (_lastRequest.substr(0, 10) == "POST /fail") + { + std::string::size_type pos = _lastRequest.find("\r\n\r\n"); + pos += 4; + std::string body = _lastRequest.substr(pos); + response.append("HTTP/1.1 400 Bad Request\r\n"); + response.append("Connection: Close\r\n"); + response.append("\r\n"); + } else if (_lastRequest.substr(0, 4) == "POST") { std::string::size_type pos = _lastRequest.find("\r\n\r\n");