WebSocket: allow setting the maximum payload size for receiving frames

This commit is contained in:
Günter Obiltschnig 2019-05-16 08:08:43 +02:00
parent 7d64cc153b
commit 16f63eed0a
4 changed files with 104 additions and 32 deletions

View File

@ -209,6 +209,15 @@ public:
/// is thrown and the WebSocket connection must be /// is thrown and the WebSocket connection must be
/// terminated. /// terminated.
/// ///
/// The frame's payload size must not exceed the
/// maximum payload size set with setMaxPayloadSize().
/// If it does, a WebSocketException (WS_ERR_PAYLOAD_TOO_BIG)
/// is thrown and the WebSocket connection must be
/// terminated.
///
/// A WebSocketException will also be thrown if a malformed
/// or incomplete frame is received.
///
/// Returns the number of bytes received. /// Returns the number of bytes received.
/// A return value of 0 means that the peer has /// A return value of 0 means that the peer has
/// shut down or closed the connection. /// shut down or closed the connection.
@ -223,6 +232,16 @@ public:
int receiveFrame(Poco::Buffer<char>& buffer, int& flags); int receiveFrame(Poco::Buffer<char>& buffer, int& flags);
/// Receives a frame from the socket and stores it /// Receives a frame from the socket and stores it
/// after any previous content in buffer. /// after any previous content in buffer.
/// The buffer will be grown as necessary.
///
/// The frame's payload size must not exceed the
/// maximum payload size set with setMaxPayloadSize().
/// If it does, a WebSocketException (WS_ERR_PAYLOAD_TOO_BIG)
/// is thrown and the WebSocket connection must be
/// terminated.
///
/// A WebSocketException will also be thrown if a malformed
/// or incomplete frame is received.
/// ///
/// Returns the number of bytes received. /// Returns the number of bytes received.
/// A return value of 0 means that the peer has /// A return value of 0 means that the peer has
@ -239,6 +258,16 @@ public:
/// Returns WS_SERVER if the WebSocket is a server-side /// Returns WS_SERVER if the WebSocket is a server-side
/// WebSocket, or WS_CLIENT otherwise. /// WebSocket, or WS_CLIENT otherwise.
void setMaxPayloadSize(int maxPayloadSize);
/// Sets the maximum payload size for receiveFrame().
///
/// The default is std::numeric_limits<int>::max().
int getMaxPayloadSize() const;
/// Returns the maximum payload size for receiveFrame().
///
/// The default is std::numeric_limits<int>::max().
static const std::string WEBSOCKET_VERSION; static const std::string WEBSOCKET_VERSION;
/// The WebSocket protocol version supported (13). /// The WebSocket protocol version supported (13).

View File

@ -21,6 +21,7 @@
#include "Poco/Net/StreamSocketImpl.h" #include "Poco/Net/StreamSocketImpl.h"
#include "Poco/Buffer.h" #include "Poco/Buffer.h"
#include "Poco/Random.h" #include "Poco/Random.h"
#include "Poco/Buffer.h"
namespace Poco { namespace Poco {
@ -78,6 +79,16 @@ public:
bool mustMaskPayload() const; bool mustMaskPayload() const;
/// Returns true if the payload must be masked. /// Returns true if the payload must be masked.
void setMaxPayloadSize(int maxPayloadSize);
/// Sets the maximum payload size for receiveFrame().
///
/// The default is std::numeric_limits<int>::max().
int getMaxPayloadSize() const;
/// Returns the maximum payload size for receiveFrame().
///
/// The default is std::numeric_limits<int>::max().
protected: protected:
enum enum
{ {
@ -95,6 +106,7 @@ private:
WebSocketImpl(); WebSocketImpl();
StreamSocketImpl* _pStreamSocketImpl; StreamSocketImpl* _pStreamSocketImpl;
int _maxPayloadSize;
Poco::Buffer<char> _buffer; Poco::Buffer<char> _buffer;
int _bufferOffset; int _bufferOffset;
int _frameFlags; int _frameFlags;
@ -118,6 +130,12 @@ inline bool WebSocketImpl::mustMaskPayload() const
} }
inline int WebSocketImpl::getMaxPayloadSize() const
{
return _maxPayloadSize;
}
} } // namespace Poco::Net } } // namespace Poco::Net

View File

@ -126,6 +126,18 @@ WebSocket::Mode WebSocket::mode() const
} }
void WebSocket::setMaxPayloadSize(int maxPayloadSize)
{
static_cast<WebSocketImpl*>(impl())->setMaxPayloadSize(maxPayloadSize);
}
int WebSocket::getMaxPayloadSize() const
{
return static_cast<WebSocketImpl*>(impl())->getMaxPayloadSize();
}
WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse& response) WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse& response)
{ {
if (request.hasToken("Connection", "upgrade") && icompare(request.get("Upgrade", ""), "websocket") == 0) if (request.hasToken("Connection", "upgrade") && icompare(request.get("Upgrade", ""), "websocket") == 0)

View File

@ -21,6 +21,7 @@
#include "Poco/BinaryReader.h" #include "Poco/BinaryReader.h"
#include "Poco/MemoryStream.h" #include "Poco/MemoryStream.h"
#include "Poco/Format.h" #include "Poco/Format.h"
#include <limits>
#include <cstring> #include <cstring>
@ -31,6 +32,7 @@ namespace Net {
WebSocketImpl::WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& session, bool mustMaskPayload): WebSocketImpl::WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& session, bool mustMaskPayload):
StreamSocketImpl(pStreamSocketImpl->sockfd()), StreamSocketImpl(pStreamSocketImpl->sockfd()),
_pStreamSocketImpl(pStreamSocketImpl), _pStreamSocketImpl(pStreamSocketImpl),
_maxPayloadSize(std::numeric_limits<int>::max()),
_buffer(0), _buffer(0),
_bufferOffset(0), _bufferOffset(0),
_frameFlags(0), _frameFlags(0),
@ -134,6 +136,7 @@ int WebSocketImpl::receiveHeader(char mask[4], bool& useMask)
Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER);
Poco::UInt64 l; Poco::UInt64 l;
reader >> l; reader >> l;
if (l > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
payloadLength = static_cast<int>(l); payloadLength = static_cast<int>(l);
} }
else if (lengthByte == 126) else if (lengthByte == 126)
@ -148,10 +151,12 @@ int WebSocketImpl::receiveHeader(char mask[4], bool& useMask)
Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER);
Poco::UInt16 l; Poco::UInt16 l;
reader >> l; reader >> l;
if (l > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
payloadLength = static_cast<int>(l); payloadLength = static_cast<int>(l);
} }
else else
{ {
if (lengthByte > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
payloadLength = lengthByte; payloadLength = lengthByte;
} }
@ -169,6 +174,14 @@ int WebSocketImpl::receiveHeader(char mask[4], bool& useMask)
} }
void WebSocketImpl::setMaxPayloadSize(int maxPayloadSize)
{
poco_assert (maxPayloadSize > 0);
_maxPayloadSize = maxPayloadSize;
}
int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask) int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask)
{ {
int received = receiveNBytes(reinterpret_cast<char*>(buffer), payloadLength); int received = receiveNBytes(reinterpret_cast<char*>(buffer), payloadLength);
@ -205,7 +218,7 @@ int WebSocketImpl::receiveBytes(Poco::Buffer<char>& buffer, int, const Poco::Tim
int payloadLength = receiveHeader(mask, useMask); int payloadLength = receiveHeader(mask, useMask);
if (payloadLength <= 0) if (payloadLength <= 0)
return payloadLength; return payloadLength;
int oldSize = static_cast<int>(buffer.size()); std::size_t oldSize = buffer.size();
buffer.resize(oldSize + payloadLength); buffer.resize(oldSize + payloadLength);
return receivePayload(buffer.begin() + oldSize, payloadLength, mask, useMask); return receivePayload(buffer.begin() + oldSize, payloadLength, mask, useMask);
} }