diff --git a/Net/include/Poco/Net/WebSocketImpl.h b/Net/include/Poco/Net/WebSocketImpl.h index fed918bdb..439c3008c 100644 --- a/Net/include/Poco/Net/WebSocketImpl.h +++ b/Net/include/Poco/Net/WebSocketImpl.h @@ -1,7 +1,7 @@ // // WebSocketImpl.h // -// $Id: //poco/1.4/Net/include/Poco/Net/WebSocketImpl.h#1 $ +// $Id: //poco/1.4/Net/include/Poco/Net/WebSocketImpl.h#5 $ // // Library: Net // Package: WebSocket @@ -93,6 +93,7 @@ protected: MAX_HEADER_LENGTH = 14 }; + int receiveNBytes(void* buffer, int bytes); virtual ~WebSocketImpl(); private: diff --git a/Net/src/WebSocketImpl.cpp b/Net/src/WebSocketImpl.cpp index 262f3b291..1caf31074 100644 --- a/Net/src/WebSocketImpl.cpp +++ b/Net/src/WebSocketImpl.cpp @@ -1,7 +1,7 @@ // // WebSocketImpl.cpp // -// $Id: //poco/1.4/Net/src/WebSocketImpl.cpp#2 $ +// $Id: //poco/1.4/Net/src/WebSocketImpl.cpp#6 $ // // Library: Net // Package: WebSocket @@ -118,89 +118,92 @@ int WebSocketImpl::sendBytes(const void* buffer, int length, int flags) int WebSocketImpl::receiveBytes(void* buffer, int length, int) { char header[MAX_HEADER_LENGTH]; - int n = _pStreamSocketImpl->receiveBytes(header, 2); - if (n == 1) + int n = receiveNBytes(header, 2); + poco_assert (n == 2); + Poco::UInt8 lengthByte = static_cast(header[1]); + int maskOffset = 0; + if (lengthByte & FRAME_FLAG_MASK) maskOffset += 4; + lengthByte &= 0x7f; + if (lengthByte + 2 + maskOffset < MAX_HEADER_LENGTH) { - n += _pStreamSocketImpl->receiveBytes(header + 1, 1); + n = receiveNBytes(header + 2, lengthByte + maskOffset); } - if (n == 2) + else { - Poco::UInt8 lengthByte = static_cast(header[1]) & 0x7f; - if (lengthByte + 2 < MAX_HEADER_LENGTH) - { - n = _pStreamSocketImpl->receiveBytes(header + 2, lengthByte); - } - else - { - n = _pStreamSocketImpl->receiveBytes(header + 2, MAX_HEADER_LENGTH - 2); - } + n = receiveNBytes(header + 2, MAX_HEADER_LENGTH - 2); } - else throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); - if (n > 0) + poco_assert (n > 0); + n += 2; + Poco::MemoryInputStream istr(header, n); + Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); + Poco::UInt8 flags; + char mask[4]; + reader >> flags >> lengthByte; + _frameFlags = flags; + int payloadLength = 0; + int payloadOffset = 2; + if ((lengthByte & 0x7f) == 127) { - n += 2; - Poco::MemoryInputStream istr(header, n); - Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); - Poco::UInt8 flags; - Poco::UInt8 lengthByte; - char mask[4]; - reader >> flags >> lengthByte; - _frameFlags = flags; - int payloadLength = 0; - int payloadOffset = 2; - if ((lengthByte & 0x7f) == 127) - { - Poco::UInt64 l; - reader >> l; - if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %Lu", l), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); - payloadLength = static_cast(l); - payloadOffset += 8; - } - else if ((lengthByte & 0x7f) == 126) - { - Poco::UInt16 l; - reader >> l; - if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %hu", l), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); - payloadLength = static_cast(l); - payloadOffset += 2; - } - else - { - Poco::UInt8 l = lengthByte & 0x7f; - if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %u", unsigned(l)), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); - payloadLength = static_cast(l); - } - if (lengthByte & FRAME_FLAG_MASK) - { - reader.readRaw(mask, 4); - payloadOffset += 4; - } - int received = 0; - if (payloadOffset < n) - { - std::memcpy(buffer, header + payloadOffset, n - payloadOffset); - received = n - payloadOffset; - } - while (received < payloadLength) - { - n = _pStreamSocketImpl->receiveBytes(reinterpret_cast(buffer) + received, payloadLength - received); - if (n > 0) - received += n; - else - throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); - } - if (lengthByte & FRAME_FLAG_MASK) - { - char* p = reinterpret_cast(buffer); - for (int i = 0; i < received; i++) - { - p[i] ^= mask[i % 4]; - } - } - return received; + Poco::UInt64 l; + reader >> l; + if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %Lu", l), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); + payloadLength = static_cast(l); + payloadOffset += 8; } - return n; + else if ((lengthByte & 0x7f) == 126) + { + Poco::UInt16 l; + reader >> l; + if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %hu", l), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); + payloadLength = static_cast(l); + payloadOffset += 2; + } + else + { + Poco::UInt8 l = lengthByte & 0x7f; + if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %u", unsigned(l)), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); + payloadLength = static_cast(l); + } + if (lengthByte & FRAME_FLAG_MASK) + { + reader.readRaw(mask, 4); + payloadOffset += 4; + } + int received = 0; + if (payloadOffset < n) + { + std::memcpy(buffer, header + payloadOffset, n - payloadOffset); + received = n - payloadOffset; + } + if (received < payloadLength) + { + received += receiveNBytes(reinterpret_cast(buffer) + received, payloadLength - received); + } + if (lengthByte & FRAME_FLAG_MASK) + { + char* p = reinterpret_cast(buffer); + for (int i = 0; i < received; i++) + { + p[i] ^= mask[i % 4]; + } + } + return received; +} + + +int WebSocketImpl::receiveNBytes(void* buffer, int bytes) +{ + int received = _pStreamSocketImpl->receiveBytes(reinterpret_cast(buffer), bytes); + while (received < bytes) + { + int n = _pStreamSocketImpl->receiveBytes(reinterpret_cast(buffer) + received, bytes - received); + if (n > 0) + received += n; + else + throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); + } + return received; } diff --git a/Net/testsuite/src/WebSocketTest.cpp b/Net/testsuite/src/WebSocketTest.cpp index 289faa997..568dbfa5b 100644 --- a/Net/testsuite/src/WebSocketTest.cpp +++ b/Net/testsuite/src/WebSocketTest.cpp @@ -1,7 +1,7 @@ // // WebSocketTest.cpp // -// $Id: //poco/1.4/Net/testsuite/src/WebSocketTest.cpp#2 $ +// $Id: //poco/1.4/Net/testsuite/src/WebSocketTest.cpp#3 $ // // Copyright (c) 2012, Applied Informatics Software Engineering GmbH. // and Contributors. @@ -106,6 +106,16 @@ namespace WebSocketTest::WebSocketTest(const std::string& name): CppUnit::TestCase(name) +{ +} + + +WebSocketTest::~WebSocketTest() +{ +} + + +void WebSocketTest::testWebSocket() { Poco::Net::ServerSocket ss(0); Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory, ss, new Poco::Net::HTTPServerParams); @@ -117,7 +127,8 @@ WebSocketTest::WebSocketTest(const std::string& name): CppUnit::TestCase(name) HTTPRequest request(HTTPRequest::HTTP_GET, "/ws"); HTTPResponse response; WebSocket ws(cs, request, response); - std::string payload("Hello, world!"); + + std::string payload("x"); ws.sendFrame(payload.data(), payload.size()); char buffer[1024]; int flags; @@ -125,6 +136,33 @@ WebSocketTest::WebSocketTest(const std::string& name): CppUnit::TestCase(name) assert (n == payload.size()); assert (payload.compare(0, payload.size(), buffer, 0, n) == 0); assert (flags == WebSocket::FRAME_TEXT); + + for (int i = 2; i < 20; i++) + { + payload.assign(i, 'x'); + ws.sendFrame(payload.data(), payload.size()); + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + assert (n == payload.size()); + assert (payload.compare(0, payload.size(), buffer, 0, n) == 0); + assert (flags == WebSocket::FRAME_TEXT); + } + + for (int i = 125; i < 129; i++) + { + payload.assign(i, 'x'); + ws.sendFrame(payload.data(), payload.size()); + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + assert (n == payload.size()); + assert (payload.compare(0, payload.size(), buffer, 0, n) == 0); + assert (flags == WebSocket::FRAME_TEXT); + } + + payload = "Hello, world!"; + ws.sendFrame(payload.data(), payload.size()); + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + assert (n == payload.size()); + assert (payload.compare(0, payload.size(), buffer, 0, n) == 0); + assert (flags == WebSocket::FRAME_TEXT); payload = "Hello, universe!"; ws.sendFrame(payload.data(), payload.size(), WebSocket::FRAME_BINARY); @@ -142,16 +180,6 @@ WebSocketTest::WebSocketTest(const std::string& name): CppUnit::TestCase(name) } -WebSocketTest::~WebSocketTest() -{ -} - - -void WebSocketTest::testWebSocket() -{ -} - - void WebSocketTest::setUp() { }