diff --git a/Net/include/Poco/Net/HTTPServerRequestImpl.h b/Net/include/Poco/Net/HTTPServerRequestImpl.h index 38ea80594..01136bf67 100644 --- a/Net/include/Poco/Net/HTTPServerRequestImpl.h +++ b/Net/include/Poco/Net/HTTPServerRequestImpl.h @@ -81,6 +81,9 @@ public: StreamSocket detachSocket(); /// Returns the underlying socket after detaching /// it from the server session. + + HTTPServerSession& session(); + /// Returns the underlying HTTPServerSession. protected: static const std::string EXPECT; @@ -130,6 +133,12 @@ inline HTTPServerResponse& HTTPServerRequestImpl::response() const } +inline HTTPServerSession& HTTPServerRequestImpl::session() +{ + return _session; +} + + } } // namespace Poco::Net diff --git a/Net/include/Poco/Net/HTTPSession.h b/Net/include/Poco/Net/HTTPSession.h index 0b01c2ca8..2ba04b893 100644 --- a/Net/include/Poco/Net/HTTPSession.h +++ b/Net/include/Poco/Net/HTTPSession.h @@ -25,6 +25,7 @@ #include "Poco/Timespan.h" #include "Poco/Exception.h" #include "Poco/Any.h" +#include "Poco/Buffer.h" #include @@ -100,6 +101,14 @@ public: StreamSocket& socket(); /// Returns a reference to the underlying socket. + + void drainBuffer(Poco::Buffer& buffer); + /// Copies all bytes remaining in the internal buffer to the + /// given Poco::Buffer, resizing it as necessary. + /// + /// This is usually used together with detachSocket() to + /// obtain any data already read from the socket, but not + /// yet processed. protected: HTTPSession(); diff --git a/Net/include/Poco/Net/WebSocketImpl.h b/Net/include/Poco/Net/WebSocketImpl.h index 31857c162..68676faf1 100644 --- a/Net/include/Poco/Net/WebSocketImpl.h +++ b/Net/include/Poco/Net/WebSocketImpl.h @@ -22,19 +22,23 @@ #include "Poco/Net/StreamSocketImpl.h" #include "Poco/Random.h" +#include "Poco/Buffer.h" namespace Poco { namespace Net { +class HTTPSession; + + class Net_API WebSocketImpl: public StreamSocketImpl /// This class implements a WebSocket, according /// to the WebSocket protocol described in RFC 6455. { public: - WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, bool mustMaskPayload); - /// Creates a StreamSocketImpl using the given native socket. + WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& session, bool mustMaskPayload); + /// Creates a WebSocketImpl. // StreamSocketImpl virtual int sendBytes(const void* buffer, int length, int flags); @@ -79,12 +83,15 @@ protected: }; int receiveNBytes(void* buffer, int bytes); + int receiveSomeBytes(char* buffer, int bytes); virtual ~WebSocketImpl(); private: WebSocketImpl(); StreamSocketImpl* _pStreamSocketImpl; + Poco::Buffer _buffer; + int _bufferOffset; int _frameFlags; bool _mustMaskPayload; Poco::Random _rnd; diff --git a/Net/src/HTTPSession.cpp b/Net/src/HTTPSession.cpp index 030ec7246..c9610922d 100644 --- a/Net/src/HTTPSession.cpp +++ b/Net/src/HTTPSession.cpp @@ -238,4 +238,11 @@ void HTTPSession::attachSessionData(const Poco::Any& data) } +void HTTPSession::drainBuffer(Poco::Buffer& buffer) +{ + buffer.assign(_pCurrent, static_cast(_pEnd - _pCurrent)); + _pCurrent = _pEnd; +} + + } } // namespace Poco::Net diff --git a/Net/src/WebSocket.cpp b/Net/src/WebSocket.cpp index 9418dd8d4..9ff32d1f3 100644 --- a/Net/src/WebSocket.cpp +++ b/Net/src/WebSocket.cpp @@ -19,6 +19,7 @@ #include "Poco/Net/HTTPServerRequestImpl.h" #include "Poco/Net/HTTPServerResponse.h" #include "Poco/Net/HTTPClientSession.h" +#include "Poco/Net/HTTPServerSession.h" #include "Poco/Net/NetException.h" #include "Poco/Buffer.h" #include "Poco/MemoryStream.h" @@ -137,7 +138,9 @@ WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse& response.set("Sec-WebSocket-Accept", computeAccept(key)); response.setContentLength(0); response.send().flush(); - return new WebSocketImpl(static_cast(static_cast(request).detachSocket().impl()), false); + + HTTPServerRequestImpl& requestImpl = static_cast(request); + return new WebSocketImpl(static_cast(requestImpl.detachSocket().impl()), requestImpl.session(), false); } else throw WebSocketException("No WebSocket handshake", WS_ERR_NO_HANDSHAKE); } @@ -205,7 +208,7 @@ WebSocketImpl* WebSocket::completeHandshake(HTTPClientSession& cs, HTTPResponse& std::string accept = response.get("Sec-WebSocket-Accept", ""); if (accept != computeAccept(key)) throw WebSocketException("Invalid or missing Sec-WebSocket-Accept header in handshake response", WS_ERR_HANDSHAKE_ACCEPT); - return new WebSocketImpl(static_cast(cs.detachSocket().impl()), true); + return new WebSocketImpl(static_cast(cs.detachSocket().impl()), cs, true); } diff --git a/Net/src/WebSocketImpl.cpp b/Net/src/WebSocketImpl.cpp index a9e75817f..45ebec242 100644 --- a/Net/src/WebSocketImpl.cpp +++ b/Net/src/WebSocketImpl.cpp @@ -17,6 +17,7 @@ #include "Poco/Net/WebSocketImpl.h" #include "Poco/Net/NetException.h" #include "Poco/Net/WebSocket.h" +#include "Poco/Net/HTTPSession.h" #include "Poco/Buffer.h" #include "Poco/BinaryWriter.h" #include "Poco/BinaryReader.h" @@ -29,14 +30,17 @@ namespace Poco { namespace Net { -WebSocketImpl::WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, bool mustMaskPayload): +WebSocketImpl::WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& session, bool mustMaskPayload): StreamSocketImpl(pStreamSocketImpl->sockfd()), _pStreamSocketImpl(pStreamSocketImpl), + _buffer(0), + _bufferOffset(0), _frameFlags(0), _mustMaskPayload(mustMaskPayload) { poco_check_ptr(pStreamSocketImpl); _pStreamSocketImpl->duplicate(); + session.drainBuffer(_buffer); } @@ -192,12 +196,12 @@ int WebSocketImpl::receiveBytes(void* buffer, int length, int) int WebSocketImpl::receiveNBytes(void* buffer, int bytes) { - int received = _pStreamSocketImpl->receiveBytes(reinterpret_cast(buffer), bytes); + int received = receiveSomeBytes(reinterpret_cast(buffer), bytes); if (received > 0) { while (received < bytes) { - int n = _pStreamSocketImpl->receiveBytes(reinterpret_cast(buffer) + received, bytes - received); + int n = receiveSomeBytes(reinterpret_cast(buffer) + received, bytes - received); if (n > 0) received += n; else @@ -208,6 +212,23 @@ int WebSocketImpl::receiveNBytes(void* buffer, int bytes) } +int WebSocketImpl::receiveSomeBytes(char* buffer, int bytes) +{ + int n = _buffer.size() - _bufferOffset; + if (n > 0) + { + if (bytes < n) n = bytes; + std::memcpy(buffer, _buffer.begin() + _bufferOffset, n); + _bufferOffset += n; + return n; + } + else + { + return _pStreamSocketImpl->receiveBytes(buffer, bytes); + } +} + + SocketImpl* WebSocketImpl::acceptConnection(SocketAddress& clientAddr) { throw Poco::InvalidAccessException("Cannot acceptConnection() on a WebSocketImpl");