WebSocket: allow setting the maximum payload size for receiving frames

This commit is contained in:
Günter Obiltschnig 2019-05-16 08:08:43 +02:00
parent 7d64cc153b
commit 16f63eed0a
4 changed files with 104 additions and 32 deletions

View File

@ -56,7 +56,7 @@ public:
WS_SERVER, /// Server-side WebSocket. WS_SERVER, /// Server-side WebSocket.
WS_CLIENT /// Client-side WebSocket. WS_CLIENT /// Client-side WebSocket.
}; };
enum FrameFlags enum FrameFlags
/// Frame header flags. /// Frame header flags.
{ {
@ -75,10 +75,10 @@ public:
FRAME_OP_CLOSE = 0x08, /// Close connection. FRAME_OP_CLOSE = 0x08, /// Close connection.
FRAME_OP_PING = 0x09, /// Ping frame. FRAME_OP_PING = 0x09, /// Ping frame.
FRAME_OP_PONG = 0x0a, /// Pong frame. FRAME_OP_PONG = 0x0a, /// Pong frame.
FRAME_OP_BITMASK = 0x0f, /// Bit mask for opcodes. FRAME_OP_BITMASK = 0x0f, /// Bit mask for opcodes.
FRAME_OP_SETRAW = 0x100 /// Set raw flags (for use with sendBytes() and FRAME_OP_CONT). FRAME_OP_SETRAW = 0x100 /// Set raw flags (for use with sendBytes() and FRAME_OP_CONT).
}; };
enum SendFlags enum SendFlags
/// Combined header flags and opcodes for use with sendFrame(). /// Combined header flags and opcodes for use with sendFrame().
{ {
@ -87,7 +87,7 @@ public:
FRAME_BINARY = FRAME_FLAG_FIN | FRAME_OP_BINARY FRAME_BINARY = FRAME_FLAG_FIN | FRAME_OP_BINARY
/// Use this for sending a single binary payload frame. /// Use this for sending a single binary payload frame.
}; };
enum StatusCodes enum StatusCodes
/// StatusCodes for CLOSE frames sent with shutdown(). /// StatusCodes for CLOSE frames sent with shutdown().
{ {
@ -105,7 +105,7 @@ public:
WS_UNEXPECTED_CONDITION = 1011, WS_UNEXPECTED_CONDITION = 1011,
WS_RESERVED_TLS_FAILURE = 1015 WS_RESERVED_TLS_FAILURE = 1015
}; };
enum ErrorCodes enum ErrorCodes
/// These error codes can be obtained from a WebSocketException /// These error codes can be obtained from a WebSocketException
/// to determine the exact cause of the error. /// to determine the exact cause of the error.
@ -127,21 +127,21 @@ public:
WS_ERR_INCOMPLETE_FRAME = 11 WS_ERR_INCOMPLETE_FRAME = 11
/// Incomplete frame received. /// Incomplete frame received.
}; };
WebSocket(HTTPServerRequest& request, HTTPServerResponse& response); WebSocket(HTTPServerRequest& request, HTTPServerResponse& response);
/// Creates a server-side WebSocket from within a /// Creates a server-side WebSocket from within a
/// HTTPRequestHandler. /// HTTPRequestHandler.
/// ///
/// First verifies that the request is a valid WebSocket upgrade /// First verifies that the request is a valid WebSocket upgrade
/// request. If so, completes the handshake by sending /// request. If so, completes the handshake by sending
/// a proper 101 response. /// a proper 101 response.
/// ///
/// Throws an exception if the request is not a proper WebSocket /// Throws an exception if the request is not a proper WebSocket
/// upgrade request. /// upgrade request.
WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response); WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response);
/// Creates a client-side WebSocket, using the given /// Creates a client-side WebSocket, using the given
/// HTTPClientSession and HTTPRequest for the initial handshake /// HTTPClientSession and HTTPRequest for the initial handshake
/// (HTTP Upgrade request). /// (HTTP Upgrade request).
/// ///
/// Additional HTTP headers for the initial handshake request /// Additional HTTP headers for the initial handshake request
@ -153,7 +153,7 @@ public:
WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials); WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials);
/// Creates a client-side WebSocket, using the given /// Creates a client-side WebSocket, using the given
/// HTTPClientSession and HTTPRequest for the initial handshake /// HTTPClientSession and HTTPRequest for the initial handshake
/// (HTTP Upgrade request). /// (HTTP Upgrade request).
/// ///
/// The given credentials are used for authentication /// The given credentials are used for authentication
@ -165,11 +165,11 @@ public:
/// ///
/// The result of the handshake can be obtained from the response /// The result of the handshake can be obtained from the response
/// object. /// object.
WebSocket(const Socket& socket); WebSocket(const Socket& socket);
/// Creates a WebSocket from another Socket, which must be a WebSocket, /// Creates a WebSocket from another Socket, which must be a WebSocket,
/// otherwise a Poco::InvalidArgumentException will be thrown. /// otherwise a Poco::InvalidArgumentException will be thrown.
virtual ~WebSocket(); virtual ~WebSocket();
/// Destroys the StreamSocket. /// Destroys the StreamSocket.
@ -188,11 +188,11 @@ public:
/// Sends a Close control frame to the server end of /// Sends a Close control frame to the server end of
/// the connection to initiate an orderly shutdown /// the connection to initiate an orderly shutdown
/// of the connection. /// of the connection.
int sendFrame(const void* buffer, int length, int flags = FRAME_TEXT); int sendFrame(const void* buffer, int length, int flags = FRAME_TEXT);
/// Sends the contents of the given buffer through /// Sends the contents of the given buffer through
/// the socket as a single frame. /// the socket as a single frame.
/// ///
/// Values from the FrameFlags, FrameOpcodes and SendFlags enumerations /// Values from the FrameFlags, FrameOpcodes and SendFlags enumerations
/// can be specified in flags. /// can be specified in flags.
/// ///
@ -206,10 +206,19 @@ public:
/// Receives a frame from the socket and stores it /// Receives a frame from the socket and stores it
/// in buffer. Up to length bytes are received. If /// in buffer. Up to length bytes are received. If
/// the frame's payload is larger, a WebSocketException /// the frame's payload is larger, a WebSocketException
/// is thrown and the WebSocket connection must be /// is thrown and the WebSocket connection must be
/// terminated. /// terminated.
/// ///
/// Returns the number of bytes received. /// The frame's payload size must not exceed the
/// maximum payload size set with setMaxPayloadSize().
/// If it does, a WebSocketException (WS_ERR_PAYLOAD_TOO_BIG)
/// is thrown and the WebSocket connection must be
/// terminated.
///
/// A WebSocketException will also be thrown if a malformed
/// or incomplete frame is received.
///
/// Returns the number of bytes received.
/// A return value of 0 means that the peer has /// A return value of 0 means that the peer has
/// shut down or closed the connection. /// shut down or closed the connection.
/// ///
@ -219,10 +228,20 @@ public:
/// ///
/// The frame flags and opcode (FrameFlags and FrameOpcodes) /// The frame flags and opcode (FrameFlags and FrameOpcodes)
/// is stored in flags. /// is stored in flags.
int receiveFrame(Poco::Buffer<char>& buffer, int& flags); int receiveFrame(Poco::Buffer<char>& buffer, int& flags);
/// Receives a frame from the socket and stores it /// Receives a frame from the socket and stores it
/// after any previous content in buffer. /// after any previous content in buffer.
/// The buffer will be grown as necessary.
///
/// The frame's payload size must not exceed the
/// maximum payload size set with setMaxPayloadSize().
/// If it does, a WebSocketException (WS_ERR_PAYLOAD_TOO_BIG)
/// is thrown and the WebSocket connection must be
/// terminated.
///
/// A WebSocketException will also be thrown if a malformed
/// or incomplete frame is received.
/// ///
/// Returns the number of bytes received. /// Returns the number of bytes received.
/// A return value of 0 means that the peer has /// A return value of 0 means that the peer has
@ -239,16 +258,26 @@ public:
/// Returns WS_SERVER if the WebSocket is a server-side /// Returns WS_SERVER if the WebSocket is a server-side
/// WebSocket, or WS_CLIENT otherwise. /// WebSocket, or WS_CLIENT otherwise.
void setMaxPayloadSize(int maxPayloadSize);
/// Sets the maximum payload size for receiveFrame().
///
/// The default is std::numeric_limits<int>::max().
int getMaxPayloadSize() const;
/// Returns the maximum payload size for receiveFrame().
///
/// The default is std::numeric_limits<int>::max().
static const std::string WEBSOCKET_VERSION; static const std::string WEBSOCKET_VERSION;
/// The WebSocket protocol version supported (13). /// The WebSocket protocol version supported (13).
protected: protected:
static WebSocketImpl* accept(HTTPServerRequest& request, HTTPServerResponse& response); static WebSocketImpl* accept(HTTPServerRequest& request, HTTPServerResponse& response);
static WebSocketImpl* connect(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials); static WebSocketImpl* connect(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials);
static WebSocketImpl* completeHandshake(HTTPClientSession& cs, HTTPResponse& response, const std::string& key); static WebSocketImpl* completeHandshake(HTTPClientSession& cs, HTTPResponse& response, const std::string& key);
static std::string computeAccept(const std::string& key); static std::string computeAccept(const std::string& key);
static std::string createKey(); static std::string createKey();
private: private:
WebSocket(); WebSocket();

View File

@ -21,6 +21,7 @@
#include "Poco/Net/StreamSocketImpl.h" #include "Poco/Net/StreamSocketImpl.h"
#include "Poco/Buffer.h" #include "Poco/Buffer.h"
#include "Poco/Random.h" #include "Poco/Random.h"
#include "Poco/Buffer.h"
namespace Poco { namespace Poco {
@ -37,11 +38,11 @@ class Net_API WebSocketImpl: public StreamSocketImpl
public: public:
WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& session, bool mustMaskPayload); WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& session, bool mustMaskPayload);
/// Creates a WebSocketImpl. /// Creates a WebSocketImpl.
// StreamSocketImpl // StreamSocketImpl
virtual int sendBytes(const void* buffer, int length, int flags); virtual int sendBytes(const void* buffer, int length, int flags);
/// Sends a WebSocket protocol frame. /// Sends a WebSocket protocol frame.
virtual int receiveBytes(void* buffer, int length, int flags); virtual int receiveBytes(void* buffer, int length, int flags);
/// Receives a WebSocket protocol frame. /// Receives a WebSocket protocol frame.
@ -66,7 +67,7 @@ public:
virtual void sendUrgent(unsigned char data); virtual void sendUrgent(unsigned char data);
virtual int available(); virtual int available();
virtual bool secure() const; virtual bool secure() const;
virtual void setSendTimeout(const Poco::Timespan& timeout); virtual void setSendTimeout(const Poco::Timespan& timeout);
virtual Poco::Timespan getSendTimeout(); virtual Poco::Timespan getSendTimeout();
virtual void setReceiveTimeout(const Poco::Timespan& timeout); virtual void setReceiveTimeout(const Poco::Timespan& timeout);
virtual Poco::Timespan getReceiveTimeout(); virtual Poco::Timespan getReceiveTimeout();
@ -74,17 +75,27 @@ public:
// Internal // Internal
int frameFlags() const; int frameFlags() const;
/// Returns the frame flags of the most recently received frame. /// Returns the frame flags of the most recently received frame.
bool mustMaskPayload() const; bool mustMaskPayload() const;
/// Returns true if the payload must be masked. /// Returns true if the payload must be masked.
void setMaxPayloadSize(int maxPayloadSize);
/// Sets the maximum payload size for receiveFrame().
///
/// The default is std::numeric_limits<int>::max().
int getMaxPayloadSize() const;
/// Returns the maximum payload size for receiveFrame().
///
/// The default is std::numeric_limits<int>::max().
protected: protected:
enum enum
{ {
FRAME_FLAG_MASK = 0x80, FRAME_FLAG_MASK = 0x80,
MAX_HEADER_LENGTH = 14 MAX_HEADER_LENGTH = 14
}; };
int receiveHeader(char mask[4], bool& useMask); int receiveHeader(char mask[4], bool& useMask);
int receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask); int receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask);
int receiveNBytes(void* buffer, int bytes); int receiveNBytes(void* buffer, int bytes);
@ -93,8 +104,9 @@ protected:
private: private:
WebSocketImpl(); WebSocketImpl();
StreamSocketImpl* _pStreamSocketImpl; StreamSocketImpl* _pStreamSocketImpl;
int _maxPayloadSize;
Poco::Buffer<char> _buffer; Poco::Buffer<char> _buffer;
int _bufferOffset; int _bufferOffset;
int _frameFlags; int _frameFlags;
@ -118,6 +130,12 @@ inline bool WebSocketImpl::mustMaskPayload() const
} }
inline int WebSocketImpl::getMaxPayloadSize() const
{
return _maxPayloadSize;
}
} } // namespace Poco::Net } } // namespace Poco::Net

View File

@ -44,7 +44,7 @@ WebSocket::WebSocket(HTTPServerRequest& request, HTTPServerResponse& response):
{ {
} }
WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response): WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response):
StreamSocket(connect(cs, request, response, _defaultCreds)) StreamSocket(connect(cs, request, response, _defaultCreds))
{ {
@ -57,7 +57,7 @@ WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse&
} }
WebSocket::WebSocket(const Socket& socket): WebSocket::WebSocket(const Socket& socket):
StreamSocket(socket) StreamSocket(socket)
{ {
if (!dynamic_cast<WebSocketImpl*>(impl())) if (!dynamic_cast<WebSocketImpl*>(impl()))
@ -111,7 +111,7 @@ int WebSocket::receiveFrame(void* buffer, int length, int& flags)
return n; return n;
} }
int WebSocket::receiveFrame(Poco::Buffer<char>& buffer, int& flags) int WebSocket::receiveFrame(Poco::Buffer<char>& buffer, int& flags)
{ {
int n = static_cast<WebSocketImpl*>(impl())->receiveBytes(buffer, 0); int n = static_cast<WebSocketImpl*>(impl())->receiveBytes(buffer, 0);
@ -126,6 +126,18 @@ WebSocket::Mode WebSocket::mode() const
} }
void WebSocket::setMaxPayloadSize(int maxPayloadSize)
{
static_cast<WebSocketImpl*>(impl())->setMaxPayloadSize(maxPayloadSize);
}
int WebSocket::getMaxPayloadSize() const
{
return static_cast<WebSocketImpl*>(impl())->getMaxPayloadSize();
}
WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse& response) WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse& response)
{ {
if (request.hasToken("Connection", "upgrade") && icompare(request.get("Upgrade", ""), "websocket") == 0) if (request.hasToken("Connection", "upgrade") && icompare(request.get("Upgrade", ""), "websocket") == 0)
@ -136,14 +148,14 @@ WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse&
std::string key = request.get("Sec-WebSocket-Key", ""); std::string key = request.get("Sec-WebSocket-Key", "");
Poco::trimInPlace(key); Poco::trimInPlace(key);
if (key.empty()) throw WebSocketException("Missing Sec-WebSocket-Key in handshake request", WS_ERR_HANDSHAKE_NO_KEY); if (key.empty()) throw WebSocketException("Missing Sec-WebSocket-Key in handshake request", WS_ERR_HANDSHAKE_NO_KEY);
response.setStatusAndReason(HTTPResponse::HTTP_SWITCHING_PROTOCOLS); response.setStatusAndReason(HTTPResponse::HTTP_SWITCHING_PROTOCOLS);
response.set("Upgrade", "websocket"); response.set("Upgrade", "websocket");
response.set("Connection", "Upgrade"); response.set("Connection", "Upgrade");
response.set("Sec-WebSocket-Accept", computeAccept(key)); response.set("Sec-WebSocket-Accept", computeAccept(key));
response.setContentLength(0); response.setContentLength(0);
response.send().flush(); response.send().flush();
HTTPServerRequestImpl& requestImpl = static_cast<HTTPServerRequestImpl&>(request); HTTPServerRequestImpl& requestImpl = static_cast<HTTPServerRequestImpl&>(request);
return new WebSocketImpl(static_cast<StreamSocketImpl*>(requestImpl.detachSocket().impl()), requestImpl.session(), false); return new WebSocketImpl(static_cast<StreamSocketImpl*>(requestImpl.detachSocket().impl()), requestImpl.session(), false);
} }
@ -205,7 +217,7 @@ WebSocketImpl* WebSocket::connect(HTTPClientSession& cs, HTTPRequest& request, H
WebSocketImpl* WebSocket::completeHandshake(HTTPClientSession& cs, HTTPResponse& response, const std::string& key) WebSocketImpl* WebSocket::completeHandshake(HTTPClientSession& cs, HTTPResponse& response, const std::string& key)
{ {
std::string connection = response.get("Connection", ""); std::string connection = response.get("Connection", "");
if (Poco::icompare(connection, "Upgrade") != 0) if (Poco::icompare(connection, "Upgrade") != 0)
throw WebSocketException("No Connection: Upgrade header in handshake response", WS_ERR_NO_HANDSHAKE); throw WebSocketException("No Connection: Upgrade header in handshake response", WS_ERR_NO_HANDSHAKE);
std::string upgrade = response.get("Upgrade", ""); std::string upgrade = response.get("Upgrade", "");
if (Poco::icompare(upgrade, "websocket") != 0) if (Poco::icompare(upgrade, "websocket") != 0)

View File

@ -21,6 +21,7 @@
#include "Poco/BinaryReader.h" #include "Poco/BinaryReader.h"
#include "Poco/MemoryStream.h" #include "Poco/MemoryStream.h"
#include "Poco/Format.h" #include "Poco/Format.h"
#include <limits>
#include <cstring> #include <cstring>
@ -31,6 +32,7 @@ namespace Net {
WebSocketImpl::WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& session, bool mustMaskPayload): WebSocketImpl::WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& session, bool mustMaskPayload):
StreamSocketImpl(pStreamSocketImpl->sockfd()), StreamSocketImpl(pStreamSocketImpl->sockfd()),
_pStreamSocketImpl(pStreamSocketImpl), _pStreamSocketImpl(pStreamSocketImpl),
_maxPayloadSize(std::numeric_limits<int>::max()),
_buffer(0), _buffer(0),
_bufferOffset(0), _bufferOffset(0),
_frameFlags(0), _frameFlags(0),
@ -134,6 +136,7 @@ int WebSocketImpl::receiveHeader(char mask[4], bool& useMask)
Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER);
Poco::UInt64 l; Poco::UInt64 l;
reader >> l; reader >> l;
if (l > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
payloadLength = static_cast<int>(l); payloadLength = static_cast<int>(l);
} }
else if (lengthByte == 126) else if (lengthByte == 126)
@ -148,10 +151,12 @@ int WebSocketImpl::receiveHeader(char mask[4], bool& useMask)
Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER);
Poco::UInt16 l; Poco::UInt16 l;
reader >> l; reader >> l;
if (l > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
payloadLength = static_cast<int>(l); payloadLength = static_cast<int>(l);
} }
else else
{ {
if (lengthByte > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
payloadLength = lengthByte; payloadLength = lengthByte;
} }
@ -169,6 +174,14 @@ int WebSocketImpl::receiveHeader(char mask[4], bool& useMask)
} }
void WebSocketImpl::setMaxPayloadSize(int maxPayloadSize)
{
poco_assert (maxPayloadSize > 0);
_maxPayloadSize = maxPayloadSize;
}
int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask) int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask)
{ {
int received = receiveNBytes(reinterpret_cast<char*>(buffer), payloadLength); int received = receiveNBytes(reinterpret_cast<char*>(buffer), payloadLength);
@ -205,7 +218,7 @@ int WebSocketImpl::receiveBytes(Poco::Buffer<char>& buffer, int, const Poco::Tim
int payloadLength = receiveHeader(mask, useMask); int payloadLength = receiveHeader(mask, useMask);
if (payloadLength <= 0) if (payloadLength <= 0)
return payloadLength; return payloadLength;
int oldSize = static_cast<int>(buffer.size()); std::size_t oldSize = buffer.size();
buffer.resize(oldSize + payloadLength); buffer.resize(oldSize + payloadLength);
return receivePayload(buffer.begin() + oldSize, payloadLength, mask, useMask); return receivePayload(buffer.begin() + oldSize, payloadLength, mask, useMask);
} }