diff --git a/Net/include/Poco/Net/WebSocketImpl.h b/Net/include/Poco/Net/WebSocketImpl.h index ee8ede76d..e1ef2c830 100644 --- a/Net/include/Poco/Net/WebSocketImpl.h +++ b/Net/include/Poco/Net/WebSocketImpl.h @@ -121,6 +121,7 @@ protected: int payloadLength = 0; int remainingPayloadLength = 0; Poco::Buffer payload{0}; + int maskOffset = 0; }; struct SendState @@ -133,7 +134,7 @@ protected: int peekHeader(ReceiveState& receiveState); void skipHeader(int headerLength); - int receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask); + int receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask, int maskOffset); int receiveNBytes(void* buffer, int length); int receiveSomeBytes(char* buffer, int length); int peekSomeBytes(char* buffer, int length); diff --git a/Net/src/WebSocketImpl.cpp b/Net/src/WebSocketImpl.cpp index fa2a27fb2..af06d3f66 100644 --- a/Net/src/WebSocketImpl.cpp +++ b/Net/src/WebSocketImpl.cpp @@ -262,14 +262,14 @@ void WebSocketImpl::setMaxPayloadSize(int maxPayloadSize) } -int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask) +int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask, int maskOffset) { int received = receiveNBytes(reinterpret_cast(buffer), payloadLength); if (received > 0 && useMask) { for (int i = 0; i < received; i++) { - buffer[i] ^= mask[i % MASK_LENGTH]; + buffer[i] ^= mask[(i + maskOffset) % MASK_LENGTH]; } } return received; @@ -297,7 +297,7 @@ int WebSocketImpl::receiveBytes(void* buffer, int length, int) skipHeader(_receiveState.headerLength); - if (receivePayload(reinterpret_cast(buffer), payloadLength, _receiveState.mask, _receiveState.useMask) != payloadLength) + if (receivePayload(reinterpret_cast(buffer), payloadLength, _receiveState.mask, _receiveState.useMask, 0) != payloadLength) throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); return payloadLength; @@ -326,17 +326,19 @@ int WebSocketImpl::receiveBytes(void* buffer, int length, int) throw WebSocketException(Poco::format("Insufficient buffer for payload size %d", _receiveState.payloadLength), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); } int payloadOffset = _receiveState.payloadLength - _receiveState.remainingPayloadLength; - int n = receivePayload(_receiveState.payload.begin() + payloadOffset, _receiveState.remainingPayloadLength, _receiveState.mask, _receiveState.useMask); + int n = receivePayload(_receiveState.payload.begin() + payloadOffset, _receiveState.remainingPayloadLength, _receiveState.mask, _receiveState.useMask, _receiveState.maskOffset); if (n > 0) { _receiveState.remainingPayloadLength -= n; if (_receiveState.remainingPayloadLength == 0) { + _receiveState.maskOffset = 0; std::memcpy(buffer, _receiveState.payload.begin(), _receiveState.payloadLength); return _receiveState.payloadLength; } else { + _receiveState.maskOffset += n; return -1; } } @@ -369,7 +371,7 @@ int WebSocketImpl::receiveBytes(Poco::Buffer& buffer, int, const Poco::Tim std::size_t oldSize = buffer.size(); buffer.resize(oldSize + payloadLength); - if (receivePayload(buffer.begin() + oldSize, payloadLength, _receiveState.mask, _receiveState.useMask) != payloadLength) + if (receivePayload(buffer.begin() + oldSize, payloadLength, _receiveState.mask, _receiveState.useMask, 0) != payloadLength) throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); return payloadLength; @@ -387,12 +389,13 @@ int WebSocketImpl::receiveBytes(Poco::Buffer& buffer, int, const Poco::Tim _receiveState.payload.resize(payloadLength, false); } int payloadOffset = _receiveState.payloadLength - _receiveState.remainingPayloadLength; - int n = receivePayload(_receiveState.payload.begin() + payloadOffset, _receiveState.remainingPayloadLength, _receiveState.mask, _receiveState.useMask); + int n = receivePayload(_receiveState.payload.begin() + payloadOffset, _receiveState.remainingPayloadLength, _receiveState.mask, _receiveState.useMask, _receiveState.maskOffset); if (n > 0) { _receiveState.remainingPayloadLength -= n; if (_receiveState.remainingPayloadLength == 0) { + _receiveState.maskOffset = 0; std::size_t oldSize = buffer.size(); buffer.resize(oldSize + _receiveState.payloadLength); @@ -401,6 +404,7 @@ int WebSocketImpl::receiveBytes(Poco::Buffer& buffer, int, const Poco::Tim } else { + _receiveState.maskOffset += n; return -1; } }