diff --git a/Net/include/Poco/Net/WebSocket.h b/Net/include/Poco/Net/WebSocket.h index 34e08dd21..e5247f5c6 100644 --- a/Net/include/Poco/Net/WebSocket.h +++ b/Net/include/Poco/Net/WebSocket.h @@ -56,7 +56,7 @@ public: WS_SERVER, /// Server-side WebSocket. WS_CLIENT /// Client-side WebSocket. }; - + enum FrameFlags /// Frame header flags. { @@ -75,10 +75,10 @@ public: FRAME_OP_CLOSE = 0x08, /// Close connection. FRAME_OP_PING = 0x09, /// Ping frame. FRAME_OP_PONG = 0x0a, /// Pong frame. - FRAME_OP_BITMASK = 0x0f, /// Bit mask for opcodes. + FRAME_OP_BITMASK = 0x0f, /// Bit mask for opcodes. FRAME_OP_SETRAW = 0x100 /// Set raw flags (for use with sendBytes() and FRAME_OP_CONT). }; - + enum SendFlags /// Combined header flags and opcodes for use with sendFrame(). { @@ -87,7 +87,7 @@ public: FRAME_BINARY = FRAME_FLAG_FIN | FRAME_OP_BINARY /// Use this for sending a single binary payload frame. }; - + enum StatusCodes /// StatusCodes for CLOSE frames sent with shutdown(). { @@ -105,7 +105,7 @@ public: WS_UNEXPECTED_CONDITION = 1011, WS_RESERVED_TLS_FAILURE = 1015 }; - + enum ErrorCodes /// These error codes can be obtained from a WebSocketException /// to determine the exact cause of the error. @@ -127,21 +127,21 @@ public: WS_ERR_INCOMPLETE_FRAME = 11 /// Incomplete frame received. }; - + WebSocket(HTTPServerRequest& request, HTTPServerResponse& response); /// Creates a server-side WebSocket from within a /// HTTPRequestHandler. - /// + /// /// First verifies that the request is a valid WebSocket upgrade /// request. If so, completes the handshake by sending /// a proper 101 response. /// /// Throws an exception if the request is not a proper WebSocket /// upgrade request. - + WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response); /// Creates a client-side WebSocket, using the given - /// HTTPClientSession and HTTPRequest for the initial handshake + /// HTTPClientSession and HTTPRequest for the initial handshake /// (HTTP Upgrade request). /// /// Additional HTTP headers for the initial handshake request @@ -153,7 +153,7 @@ public: WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials); /// Creates a client-side WebSocket, using the given - /// HTTPClientSession and HTTPRequest for the initial handshake + /// HTTPClientSession and HTTPRequest for the initial handshake /// (HTTP Upgrade request). /// /// The given credentials are used for authentication @@ -165,11 +165,11 @@ public: /// /// The result of the handshake can be obtained from the response /// object. - + WebSocket(const Socket& socket); /// Creates a WebSocket from another Socket, which must be a WebSocket, /// otherwise a Poco::InvalidArgumentException will be thrown. - + virtual ~WebSocket(); /// Destroys the StreamSocket. @@ -188,11 +188,11 @@ public: /// Sends a Close control frame to the server end of /// the connection to initiate an orderly shutdown /// of the connection. - + int sendFrame(const void* buffer, int length, int flags = FRAME_TEXT); /// Sends the contents of the given buffer through /// the socket as a single frame. - /// + /// /// Values from the FrameFlags, FrameOpcodes and SendFlags enumerations /// can be specified in flags. /// @@ -206,10 +206,19 @@ public: /// Receives a frame from the socket and stores it /// in buffer. Up to length bytes are received. If /// the frame's payload is larger, a WebSocketException - /// is thrown and the WebSocket connection must be + /// is thrown and the WebSocket connection must be /// terminated. /// - /// Returns the number of bytes received. + /// The frame's payload size must not exceed the + /// maximum payload size set with setMaxPayloadSize(). + /// If it does, a WebSocketException (WS_ERR_PAYLOAD_TOO_BIG) + /// is thrown and the WebSocket connection must be + /// terminated. + /// + /// A WebSocketException will also be thrown if a malformed + /// or incomplete frame is received. + /// + /// Returns the number of bytes received. /// A return value of 0 means that the peer has /// shut down or closed the connection. /// @@ -219,10 +228,20 @@ public: /// /// The frame flags and opcode (FrameFlags and FrameOpcodes) /// is stored in flags. - + int receiveFrame(Poco::Buffer& buffer, int& flags); /// Receives a frame from the socket and stores it /// after any previous content in buffer. + /// The buffer will be grown as necessary. + /// + /// The frame's payload size must not exceed the + /// maximum payload size set with setMaxPayloadSize(). + /// If it does, a WebSocketException (WS_ERR_PAYLOAD_TOO_BIG) + /// is thrown and the WebSocket connection must be + /// terminated. + /// + /// A WebSocketException will also be thrown if a malformed + /// or incomplete frame is received. /// /// Returns the number of bytes received. /// A return value of 0 means that the peer has @@ -239,16 +258,26 @@ public: /// Returns WS_SERVER if the WebSocket is a server-side /// WebSocket, or WS_CLIENT otherwise. + void setMaxPayloadSize(int maxPayloadSize); + /// Sets the maximum payload size for receiveFrame(). + /// + /// The default is std::numeric_limits::max(). + + int getMaxPayloadSize() const; + /// Returns the maximum payload size for receiveFrame(). + /// + /// The default is std::numeric_limits::max(). + static const std::string WEBSOCKET_VERSION; /// The WebSocket protocol version supported (13). - + protected: static WebSocketImpl* accept(HTTPServerRequest& request, HTTPServerResponse& response); static WebSocketImpl* connect(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials); static WebSocketImpl* completeHandshake(HTTPClientSession& cs, HTTPResponse& response, const std::string& key); static std::string computeAccept(const std::string& key); static std::string createKey(); - + private: WebSocket(); diff --git a/Net/include/Poco/Net/WebSocketImpl.h b/Net/include/Poco/Net/WebSocketImpl.h index f52822bd7..fa1284d5b 100644 --- a/Net/include/Poco/Net/WebSocketImpl.h +++ b/Net/include/Poco/Net/WebSocketImpl.h @@ -21,6 +21,7 @@ #include "Poco/Net/StreamSocketImpl.h" #include "Poco/Buffer.h" #include "Poco/Random.h" +#include "Poco/Buffer.h" namespace Poco { @@ -37,11 +38,11 @@ class Net_API WebSocketImpl: public StreamSocketImpl public: WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& session, bool mustMaskPayload); /// Creates a WebSocketImpl. - + // StreamSocketImpl virtual int sendBytes(const void* buffer, int length, int flags); /// Sends a WebSocket protocol frame. - + virtual int receiveBytes(void* buffer, int length, int flags); /// Receives a WebSocket protocol frame. @@ -66,7 +67,7 @@ public: virtual void sendUrgent(unsigned char data); virtual int available(); virtual bool secure() const; - virtual void setSendTimeout(const Poco::Timespan& timeout); + virtual void setSendTimeout(const Poco::Timespan& timeout); virtual Poco::Timespan getSendTimeout(); virtual void setReceiveTimeout(const Poco::Timespan& timeout); virtual Poco::Timespan getReceiveTimeout(); @@ -74,17 +75,27 @@ public: // Internal int frameFlags() const; /// Returns the frame flags of the most recently received frame. - + bool mustMaskPayload() const; /// Returns true if the payload must be masked. + void setMaxPayloadSize(int maxPayloadSize); + /// Sets the maximum payload size for receiveFrame(). + /// + /// The default is std::numeric_limits::max(). + + int getMaxPayloadSize() const; + /// Returns the maximum payload size for receiveFrame(). + /// + /// The default is std::numeric_limits::max(). + protected: enum { FRAME_FLAG_MASK = 0x80, MAX_HEADER_LENGTH = 14 }; - + int receiveHeader(char mask[4], bool& useMask); int receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask); int receiveNBytes(void* buffer, int bytes); @@ -93,8 +104,9 @@ protected: private: WebSocketImpl(); - + StreamSocketImpl* _pStreamSocketImpl; + int _maxPayloadSize; Poco::Buffer _buffer; int _bufferOffset; int _frameFlags; @@ -118,6 +130,12 @@ inline bool WebSocketImpl::mustMaskPayload() const } +inline int WebSocketImpl::getMaxPayloadSize() const +{ + return _maxPayloadSize; +} + + } } // namespace Poco::Net diff --git a/Net/src/WebSocket.cpp b/Net/src/WebSocket.cpp index 992e44ca1..6ed3ffd4e 100644 --- a/Net/src/WebSocket.cpp +++ b/Net/src/WebSocket.cpp @@ -44,7 +44,7 @@ WebSocket::WebSocket(HTTPServerRequest& request, HTTPServerResponse& response): { } - + WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response): StreamSocket(connect(cs, request, response, _defaultCreds)) { @@ -57,7 +57,7 @@ WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& } -WebSocket::WebSocket(const Socket& socket): +WebSocket::WebSocket(const Socket& socket): StreamSocket(socket) { if (!dynamic_cast(impl())) @@ -111,7 +111,7 @@ int WebSocket::receiveFrame(void* buffer, int length, int& flags) return n; } - + int WebSocket::receiveFrame(Poco::Buffer& buffer, int& flags) { int n = static_cast(impl())->receiveBytes(buffer, 0); @@ -126,6 +126,18 @@ WebSocket::Mode WebSocket::mode() const } +void WebSocket::setMaxPayloadSize(int maxPayloadSize) +{ + static_cast(impl())->setMaxPayloadSize(maxPayloadSize); +} + + +int WebSocket::getMaxPayloadSize() const +{ + return static_cast(impl())->getMaxPayloadSize(); +} + + WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse& response) { if (request.hasToken("Connection", "upgrade") && icompare(request.get("Upgrade", ""), "websocket") == 0) @@ -136,14 +148,14 @@ WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse& std::string key = request.get("Sec-WebSocket-Key", ""); Poco::trimInPlace(key); if (key.empty()) throw WebSocketException("Missing Sec-WebSocket-Key in handshake request", WS_ERR_HANDSHAKE_NO_KEY); - + response.setStatusAndReason(HTTPResponse::HTTP_SWITCHING_PROTOCOLS); response.set("Upgrade", "websocket"); response.set("Connection", "Upgrade"); response.set("Sec-WebSocket-Accept", computeAccept(key)); response.setContentLength(0); response.send().flush(); - + HTTPServerRequestImpl& requestImpl = static_cast(request); return new WebSocketImpl(static_cast(requestImpl.detachSocket().impl()), requestImpl.session(), false); } @@ -205,7 +217,7 @@ WebSocketImpl* WebSocket::connect(HTTPClientSession& cs, HTTPRequest& request, H WebSocketImpl* WebSocket::completeHandshake(HTTPClientSession& cs, HTTPResponse& response, const std::string& key) { std::string connection = response.get("Connection", ""); - if (Poco::icompare(connection, "Upgrade") != 0) + if (Poco::icompare(connection, "Upgrade") != 0) throw WebSocketException("No Connection: Upgrade header in handshake response", WS_ERR_NO_HANDSHAKE); std::string upgrade = response.get("Upgrade", ""); if (Poco::icompare(upgrade, "websocket") != 0) diff --git a/Net/src/WebSocketImpl.cpp b/Net/src/WebSocketImpl.cpp index 71a5d4bb5..02707b513 100644 --- a/Net/src/WebSocketImpl.cpp +++ b/Net/src/WebSocketImpl.cpp @@ -21,6 +21,7 @@ #include "Poco/BinaryReader.h" #include "Poco/MemoryStream.h" #include "Poco/Format.h" +#include #include @@ -31,6 +32,7 @@ namespace Net { WebSocketImpl::WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& session, bool mustMaskPayload): StreamSocketImpl(pStreamSocketImpl->sockfd()), _pStreamSocketImpl(pStreamSocketImpl), + _maxPayloadSize(std::numeric_limits::max()), _buffer(0), _bufferOffset(0), _frameFlags(0), @@ -134,6 +136,7 @@ int WebSocketImpl::receiveHeader(char mask[4], bool& useMask) Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::UInt64 l; reader >> l; + if (l > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG); payloadLength = static_cast(l); } else if (lengthByte == 126) @@ -148,10 +151,12 @@ int WebSocketImpl::receiveHeader(char mask[4], bool& useMask) Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::UInt16 l; reader >> l; + if (l > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG); payloadLength = static_cast(l); } else { + if (lengthByte > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG); payloadLength = lengthByte; } @@ -169,6 +174,14 @@ int WebSocketImpl::receiveHeader(char mask[4], bool& useMask) } +void WebSocketImpl::setMaxPayloadSize(int maxPayloadSize) +{ + poco_assert (maxPayloadSize > 0); + + _maxPayloadSize = maxPayloadSize; +} + + int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask) { int received = receiveNBytes(reinterpret_cast(buffer), payloadLength); @@ -205,7 +218,7 @@ int WebSocketImpl::receiveBytes(Poco::Buffer& buffer, int, const Poco::Tim int payloadLength = receiveHeader(mask, useMask); if (payloadLength <= 0) return payloadLength; - int oldSize = static_cast(buffer.size()); + std::size_t oldSize = buffer.size(); buffer.resize(oldSize + payloadLength); return receivePayload(buffer.begin() + oldSize, payloadLength, mask, useMask); }