added better support for 'Expect: 100-continue' header and '100 Continue' responses

This commit is contained in:
Guenter Obiltschnig 2015-10-06 11:11:16 +02:00
parent 6bb1f4ad62
commit c66849f9ee
11 changed files with 175 additions and 41 deletions

View File

@ -227,6 +227,21 @@ public:
/// to ensure a new connection will be set up /// to ensure a new connection will be set up
/// for the next request. /// for the next request.
virtual bool canContinue(HTTPResponse& response);
/// If the request contains a "Expect: 100-continue" header,
/// (see HTTPRequest::setExpectContinue()) the method can be
/// used to check whether the server has
/// sent a 100 Continue response before continuing with the
/// request, i.e. sending the request body.
///
/// Returns true if the server has responded with 100 Continue,
/// otherwise false. The HTTPResponse object contains the
/// response sent by the server.
///
/// In any case, receiveResponse() must be called as well in
/// order to complete the request. The same HTTPResponse object
/// passed to canContinue() must also be passed to receiveResponse().
void reset(); void reset();
/// Resets the session and closes the socket. /// Resets the session and closes the socket.
/// ///
@ -291,6 +306,7 @@ private:
bool _reconnect; bool _reconnect;
bool _mustReconnect; bool _mustReconnect;
bool _expectResponseBody; bool _expectResponseBody;
bool _responseReceived;
Poco::SharedPtr<std::ostream> _pRequestStream; Poco::SharedPtr<std::ostream> _pRequestStream;
Poco::SharedPtr<std::istream> _pResponseStream; Poco::SharedPtr<std::istream> _pResponseStream;

View File

@ -105,6 +105,14 @@ public:
/// Sets the authentication scheme and information for /// Sets the authentication scheme and information for
/// this request. /// this request.
bool getExpectContinue() const;
/// Returns true if the request contains an
/// "Expect: 100-continue" header.
void setExpectContinue(bool expectContinue);
/// Adds a "Expect: 100-continue" header to the request if
/// expectContinue is true, otherwise removes the Expect header.
bool hasProxyCredentials() const; bool hasProxyCredentials() const;
/// Returns true iff the request contains proxy authentication /// Returns true iff the request contains proxy authentication
/// information in the form of an Proxy-Authorization header. /// information in the form of an Proxy-Authorization header.
@ -143,6 +151,7 @@ public:
static const std::string AUTHORIZATION; static const std::string AUTHORIZATION;
static const std::string PROXY_AUTHORIZATION; static const std::string PROXY_AUTHORIZATION;
static const std::string UPGRADE; static const std::string UPGRADE;
static const std::string EXPECT;
protected: protected:
void getCredentials(const std::string& header, std::string& scheme, std::string& authInfo) const; void getCredentials(const std::string& header, std::string& scheme, std::string& authInfo) const;

View File

@ -56,10 +56,6 @@ public:
/// The stream must be valid until the HTTPServerRequest /// The stream must be valid until the HTTPServerRequest
/// object is destroyed. /// object is destroyed.
virtual bool expectContinue() const = 0;
/// Returns true if the client expects a
/// 100 Continue response.
virtual const SocketAddress& clientAddress() const = 0; virtual const SocketAddress& clientAddress() const = 0;
/// Returns the client's address. /// Returns the client's address.

View File

@ -59,10 +59,6 @@ public:
/// The stream is valid until the HTTPServerRequestImpl /// The stream is valid until the HTTPServerRequestImpl
/// object is destroyed. /// object is destroyed.
bool expectContinue() const;
/// Returns true if the client expects a
/// 100 Continue response.
const SocketAddress& clientAddress() const; const SocketAddress& clientAddress() const;
/// Returns the client's address. /// Returns the client's address.
@ -88,9 +84,6 @@ public:
/// Returns the underlying socket after detaching /// Returns the underlying socket after detaching
/// it from the server session. /// it from the server session.
protected:
static const std::string EXPECT;
private: private:
HTTPServerResponseImpl& _response; HTTPServerResponseImpl& _response;
HTTPServerSession& _session; HTTPServerSession& _session;

View File

@ -46,7 +46,8 @@ HTTPClientSession::HTTPClientSession():
_keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0), _keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0),
_reconnect(false), _reconnect(false),
_mustReconnect(false), _mustReconnect(false),
_expectResponseBody(false) _expectResponseBody(false),
_responseReceived(false)
{ {
} }
@ -58,7 +59,8 @@ HTTPClientSession::HTTPClientSession(const StreamSocket& socket):
_keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0), _keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0),
_reconnect(false), _reconnect(false),
_mustReconnect(false), _mustReconnect(false),
_expectResponseBody(false) _expectResponseBody(false),
_responseReceived(false)
{ {
} }
@ -70,7 +72,8 @@ HTTPClientSession::HTTPClientSession(const SocketAddress& address):
_keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0), _keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0),
_reconnect(false), _reconnect(false),
_mustReconnect(false), _mustReconnect(false),
_expectResponseBody(false) _expectResponseBody(false),
_responseReceived(false)
{ {
} }
@ -82,7 +85,8 @@ HTTPClientSession::HTTPClientSession(const std::string& host, Poco::UInt16 port)
_keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0), _keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0),
_reconnect(false), _reconnect(false),
_mustReconnect(false), _mustReconnect(false),
_expectResponseBody(false) _expectResponseBody(false),
_responseReceived(false)
{ {
} }
@ -94,7 +98,8 @@ HTTPClientSession::HTTPClientSession(const std::string& host, Poco::UInt16 port,
_keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0), _keepAliveTimeout(DEFAULT_KEEP_ALIVE_TIMEOUT, 0),
_reconnect(false), _reconnect(false),
_mustReconnect(false), _mustReconnect(false),
_expectResponseBody(false) _expectResponseBody(false),
_responseReceived(false)
{ {
} }
@ -192,6 +197,7 @@ std::ostream& HTTPClientSession::sendRequest(HTTPRequest& request)
{ {
clearException(); clearException();
_pResponseStream = 0; _pResponseStream = 0;
_responseReceived = false;
bool keepAlive = getKeepAlive(); bool keepAlive = getKeepAlive();
if (((connected() && !keepAlive) || mustReconnect()) && !_host.empty()) if (((connected() && !keepAlive) || mustReconnect()) && !_host.empty())
@ -259,6 +265,8 @@ std::istream& HTTPClientSession::receiveResponse(HTTPResponse& response)
_pRequestStream = 0; _pRequestStream = 0;
if (networkException()) networkException()->rethrow(); if (networkException()) networkException()->rethrow();
if (!_responseReceived)
{
do do
{ {
response.clear(); response.clear();
@ -278,6 +286,7 @@ std::istream& HTTPClientSession::receiveResponse(HTTPResponse& response)
} }
} }
while (response.getStatus() == HTTPResponse::HTTP_CONTINUE); while (response.getStatus() == HTTPResponse::HTTP_CONTINUE);
}
_mustReconnect = getKeepAlive() && !response.getKeepAlive(); _mustReconnect = getKeepAlive() && !response.getKeepAlive();
@ -298,6 +307,34 @@ std::istream& HTTPClientSession::receiveResponse(HTTPResponse& response)
} }
bool HTTPClientSession::canContinue(HTTPResponse& response)
{
poco_assert (!_responseReceived);
_pRequestStream->flush();
if (networkException()) networkException()->rethrow();
response.clear();
HTTPHeaderInputStream his(*this);
try
{
response.read(his);
}
catch (Exception&)
{
close();
if (networkException())
networkException()->rethrow();
else
throw;
throw;
}
_responseReceived = response.getStatus() != HTTPResponse::HTTP_CONTINUE;
return !_responseReceived;
}
void HTTPClientSession::reset() void HTTPClientSession::reset()
{ {
close(); close();

View File

@ -43,6 +43,7 @@ const std::string HTTPRequest::COOKIE = "Cookie";
const std::string HTTPRequest::AUTHORIZATION = "Authorization"; const std::string HTTPRequest::AUTHORIZATION = "Authorization";
const std::string HTTPRequest::PROXY_AUTHORIZATION = "Proxy-Authorization"; const std::string HTTPRequest::PROXY_AUTHORIZATION = "Proxy-Authorization";
const std::string HTTPRequest::UPGRADE = "Upgrade"; const std::string HTTPRequest::UPGRADE = "Upgrade";
const std::string HTTPRequest::EXPECT = "Expect";
HTTPRequest::HTTPRequest(): HTTPRequest::HTTPRequest():
@ -259,4 +260,20 @@ void HTTPRequest::setCredentials(const std::string& header, const std::string& s
} }
bool HTTPRequest::getExpectContinue() const
{
const std::string& expect = get(EXPECT, EMPTY);
return !expect.empty() && icompare(expect, "100-continue") == 0;
}
void HTTPRequest::setExpectContinue(bool expectContinue)
{
if (expectContinue)
set(EXPECT, "100-continue");
else
erase(EXPECT);
}
} } // namespace Poco::Net } } // namespace Poco::Net

View File

@ -81,7 +81,7 @@ void HTTPServerConnection::run()
std::auto_ptr<HTTPRequestHandler> pHandler(_pFactory->createRequestHandler(request)); std::auto_ptr<HTTPRequestHandler> pHandler(_pFactory->createRequestHandler(request));
if (pHandler.get()) if (pHandler.get())
{ {
if (request.expectContinue()) if (request.getExpectContinue() && response.getStatus() == HTTPResponse::HTTP_OK)
response.sendContinue(); response.sendContinue();
pHandler->handleRequest(request, response); pHandler->handleRequest(request, response);

View File

@ -33,9 +33,6 @@ namespace Poco {
namespace Net { namespace Net {
const std::string HTTPServerRequestImpl::EXPECT("Expect");
HTTPServerRequestImpl::HTTPServerRequestImpl(HTTPServerResponseImpl& response, HTTPServerSession& session, HTTPServerParams* pParams): HTTPServerRequestImpl::HTTPServerRequestImpl(HTTPServerResponseImpl& response, HTTPServerSession& session, HTTPServerParams* pParams):
_response(response), _response(response),
_session(session), _session(session),
@ -90,11 +87,4 @@ StreamSocket HTTPServerRequestImpl::detachSocket()
} }
bool HTTPServerRequestImpl::expectContinue() const
{
const std::string& expect = get(EXPECT, EMPTY);
return !expect.empty() && icompare(expect, "100-continue") == 0;
}
} } // namespace Poco::Net } } // namespace Poco::Net

View File

@ -300,6 +300,47 @@ void HTTPClientSessionTest::testBypassProxy()
} }
void HTTPClientSessionTest::testExpectContinue()
{
HTTPTestServer srv;
HTTPClientSession s("localhost", srv.port());
HTTPRequest request(HTTPRequest::HTTP_POST, "/expect");
std::string body("this is a random request body\r\n0\r\n");
request.setContentLength((int) body.length());
request.setExpectContinue(true);
s.sendRequest(request) << body;
HTTPResponse response;
assert (s.canContinue(response));
assert (response.getStatus() == HTTPResponse::HTTP_CONTINUE);
std::istream& rs = s.receiveResponse(response);
assert (response.getStatus() == HTTPResponse::HTTP_OK);
assert (response.getContentLength() == body.length());
std::ostringstream ostr;
StreamCopier::copyStream(rs, ostr);
assert (ostr.str() == body);
}
void HTTPClientSessionTest::testExpectContinueFail()
{
HTTPTestServer srv;
HTTPClientSession s("localhost", srv.port());
HTTPRequest request(HTTPRequest::HTTP_POST, "/fail");
std::string body("this is a random request body\r\n0\r\n");
request.setContentLength((int) body.length());
request.setExpectContinue(true);
s.sendRequest(request) << body;
HTTPResponse response;
assert (!s.canContinue(response));
assert (response.getStatus() == HTTPResponse::HTTP_BAD_REQUEST);
std::istream& rs = s.receiveResponse(response);
assert (response.getStatus() == HTTPResponse::HTTP_BAD_REQUEST);
std::ostringstream ostr;
StreamCopier::copyStream(rs, ostr);
assert (ostr.str().empty());
}
void HTTPClientSessionTest::setUp() void HTTPClientSessionTest::setUp()
{ {
} }
@ -327,6 +368,8 @@ CppUnit::Test* HTTPClientSessionTest::suite()
CppUnit_addTest(pSuite, HTTPClientSessionTest, testProxy); CppUnit_addTest(pSuite, HTTPClientSessionTest, testProxy);
CppUnit_addTest(pSuite, HTTPClientSessionTest, testProxyAuth); CppUnit_addTest(pSuite, HTTPClientSessionTest, testProxyAuth);
CppUnit_addTest(pSuite, HTTPClientSessionTest, testBypassProxy); CppUnit_addTest(pSuite, HTTPClientSessionTest, testBypassProxy);
CppUnit_addTest(pSuite, HTTPClientSessionTest, testExpectContinue);
CppUnit_addTest(pSuite, HTTPClientSessionTest, testExpectContinueFail);
return pSuite; return pSuite;
} }

View File

@ -39,6 +39,8 @@ public:
void testProxy(); void testProxy();
void testProxyAuth(); void testProxyAuth();
void testBypassProxy(); void testBypassProxy();
void testExpectContinue();
void testExpectContinueFail();
void setUp(); void setUp();
void tearDown(); void tearDown();

View File

@ -142,6 +142,37 @@ std::string HTTPTestServer::handleRequest() const
if (_lastRequest.substr(0, 3) == "GET") if (_lastRequest.substr(0, 3) == "GET")
response.append(body); response.append(body);
} }
else if (_lastRequest.substr(0, 12) == "POST /expect")
{
std::string::size_type pos = _lastRequest.find("\r\n\r\n");
pos += 4;
std::string body = _lastRequest.substr(pos);
response.append("HTTP/1.1 100 Continue\r\n\r\n");
response.append("HTTP/1.1 200 OK\r\n");
response.append("Content-Type: text/plain\r\n");
if (_lastRequest.find("Content-Length") != std::string::npos)
{
response.append("Content-Length: ");
response.append(NumberFormatter::format((int) body.size()));
response.append("\r\n");
}
else if (_lastRequest.find("chunked") != std::string::npos)
{
response.append("Transfer-Encoding: chunked\r\n");
}
response.append("Connection: Close\r\n");
response.append("\r\n");
response.append(body);
}
else if (_lastRequest.substr(0, 10) == "POST /fail")
{
std::string::size_type pos = _lastRequest.find("\r\n\r\n");
pos += 4;
std::string body = _lastRequest.substr(pos);
response.append("HTTP/1.1 400 Bad Request\r\n");
response.append("Connection: Close\r\n");
response.append("\r\n");
}
else if (_lastRequest.substr(0, 4) == "POST") else if (_lastRequest.substr(0, 4) == "POST")
{ {
std::string::size_type pos = _lastRequest.find("\r\n\r\n"); std::string::size_type pos = _lastRequest.find("\r\n\r\n");