- fixed SF# 594: Websocket fails with small masked payloads

This commit is contained in:
Guenter Obiltschnig
2012-11-10 11:47:26 +01:00
parent 45fa903880
commit e36800c76d
3 changed files with 122 additions and 90 deletions

View File

@@ -1,7 +1,7 @@
// //
// WebSocketImpl.h // WebSocketImpl.h
// //
// $Id: //poco/1.4/Net/include/Poco/Net/WebSocketImpl.h#1 $ // $Id: //poco/1.4/Net/include/Poco/Net/WebSocketImpl.h#5 $
// //
// Library: Net // Library: Net
// Package: WebSocket // Package: WebSocket
@@ -93,6 +93,7 @@ protected:
MAX_HEADER_LENGTH = 14 MAX_HEADER_LENGTH = 14
}; };
int receiveNBytes(void* buffer, int bytes);
virtual ~WebSocketImpl(); virtual ~WebSocketImpl();
private: private:

View File

@@ -1,7 +1,7 @@
// //
// WebSocketImpl.cpp // WebSocketImpl.cpp
// //
// $Id: //poco/1.4/Net/src/WebSocketImpl.cpp#2 $ // $Id: //poco/1.4/Net/src/WebSocketImpl.cpp#6 $
// //
// Library: Net // Library: Net
// Package: WebSocket // Package: WebSocket
@@ -118,32 +118,26 @@ int WebSocketImpl::sendBytes(const void* buffer, int length, int flags)
int WebSocketImpl::receiveBytes(void* buffer, int length, int) int WebSocketImpl::receiveBytes(void* buffer, int length, int)
{ {
char header[MAX_HEADER_LENGTH]; char header[MAX_HEADER_LENGTH];
int n = _pStreamSocketImpl->receiveBytes(header, 2); int n = receiveNBytes(header, 2);
if (n == 1) poco_assert (n == 2);
Poco::UInt8 lengthByte = static_cast<Poco::UInt8>(header[1]);
int maskOffset = 0;
if (lengthByte & FRAME_FLAG_MASK) maskOffset += 4;
lengthByte &= 0x7f;
if (lengthByte + 2 + maskOffset < MAX_HEADER_LENGTH)
{ {
n += _pStreamSocketImpl->receiveBytes(header + 1, 1); n = receiveNBytes(header + 2, lengthByte + maskOffset);
}
if (n == 2)
{
Poco::UInt8 lengthByte = static_cast<Poco::UInt8>(header[1]) & 0x7f;
if (lengthByte + 2 < MAX_HEADER_LENGTH)
{
n = _pStreamSocketImpl->receiveBytes(header + 2, lengthByte);
} }
else else
{ {
n = _pStreamSocketImpl->receiveBytes(header + 2, MAX_HEADER_LENGTH - 2); n = receiveNBytes(header + 2, MAX_HEADER_LENGTH - 2);
} }
}
else throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME);
if (n > 0) poco_assert (n > 0);
{
n += 2; n += 2;
Poco::MemoryInputStream istr(header, n); Poco::MemoryInputStream istr(header, n);
Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER);
Poco::UInt8 flags; Poco::UInt8 flags;
Poco::UInt8 lengthByte;
char mask[4]; char mask[4];
reader >> flags >> lengthByte; reader >> flags >> lengthByte;
_frameFlags = flags; _frameFlags = flags;
@@ -182,13 +176,9 @@ int WebSocketImpl::receiveBytes(void* buffer, int length, int)
std::memcpy(buffer, header + payloadOffset, n - payloadOffset); std::memcpy(buffer, header + payloadOffset, n - payloadOffset);
received = n - payloadOffset; received = n - payloadOffset;
} }
while (received < payloadLength) if (received < payloadLength)
{ {
n = _pStreamSocketImpl->receiveBytes(reinterpret_cast<char*>(buffer) + received, payloadLength - received); received += receiveNBytes(reinterpret_cast<char*>(buffer) + received, payloadLength - received);
if (n > 0)
received += n;
else
throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME);
} }
if (lengthByte & FRAME_FLAG_MASK) if (lengthByte & FRAME_FLAG_MASK)
{ {
@@ -200,7 +190,20 @@ int WebSocketImpl::receiveBytes(void* buffer, int length, int)
} }
return received; return received;
} }
return n;
int WebSocketImpl::receiveNBytes(void* buffer, int bytes)
{
int received = _pStreamSocketImpl->receiveBytes(reinterpret_cast<char*>(buffer), bytes);
while (received < bytes)
{
int n = _pStreamSocketImpl->receiveBytes(reinterpret_cast<char*>(buffer) + received, bytes - received);
if (n > 0)
received += n;
else
throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME);
}
return received;
} }

View File

@@ -1,7 +1,7 @@
// //
// WebSocketTest.cpp // WebSocketTest.cpp
// //
// $Id: //poco/1.4/Net/testsuite/src/WebSocketTest.cpp#2 $ // $Id: //poco/1.4/Net/testsuite/src/WebSocketTest.cpp#3 $
// //
// Copyright (c) 2012, Applied Informatics Software Engineering GmbH. // Copyright (c) 2012, Applied Informatics Software Engineering GmbH.
// and Contributors. // and Contributors.
@@ -106,6 +106,16 @@ namespace
WebSocketTest::WebSocketTest(const std::string& name): CppUnit::TestCase(name) WebSocketTest::WebSocketTest(const std::string& name): CppUnit::TestCase(name)
{
}
WebSocketTest::~WebSocketTest()
{
}
void WebSocketTest::testWebSocket()
{ {
Poco::Net::ServerSocket ss(0); Poco::Net::ServerSocket ss(0);
Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory, ss, new Poco::Net::HTTPServerParams); Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory, ss, new Poco::Net::HTTPServerParams);
@@ -117,7 +127,8 @@ WebSocketTest::WebSocketTest(const std::string& name): CppUnit::TestCase(name)
HTTPRequest request(HTTPRequest::HTTP_GET, "/ws"); HTTPRequest request(HTTPRequest::HTTP_GET, "/ws");
HTTPResponse response; HTTPResponse response;
WebSocket ws(cs, request, response); WebSocket ws(cs, request, response);
std::string payload("Hello, world!");
std::string payload("x");
ws.sendFrame(payload.data(), payload.size()); ws.sendFrame(payload.data(), payload.size());
char buffer[1024]; char buffer[1024];
int flags; int flags;
@@ -126,6 +137,33 @@ WebSocketTest::WebSocketTest(const std::string& name): CppUnit::TestCase(name)
assert (payload.compare(0, payload.size(), buffer, 0, n) == 0); assert (payload.compare(0, payload.size(), buffer, 0, n) == 0);
assert (flags == WebSocket::FRAME_TEXT); assert (flags == WebSocket::FRAME_TEXT);
for (int i = 2; i < 20; i++)
{
payload.assign(i, 'x');
ws.sendFrame(payload.data(), payload.size());
n = ws.receiveFrame(buffer, sizeof(buffer), flags);
assert (n == payload.size());
assert (payload.compare(0, payload.size(), buffer, 0, n) == 0);
assert (flags == WebSocket::FRAME_TEXT);
}
for (int i = 125; i < 129; i++)
{
payload.assign(i, 'x');
ws.sendFrame(payload.data(), payload.size());
n = ws.receiveFrame(buffer, sizeof(buffer), flags);
assert (n == payload.size());
assert (payload.compare(0, payload.size(), buffer, 0, n) == 0);
assert (flags == WebSocket::FRAME_TEXT);
}
payload = "Hello, world!";
ws.sendFrame(payload.data(), payload.size());
n = ws.receiveFrame(buffer, sizeof(buffer), flags);
assert (n == payload.size());
assert (payload.compare(0, payload.size(), buffer, 0, n) == 0);
assert (flags == WebSocket::FRAME_TEXT);
payload = "Hello, universe!"; payload = "Hello, universe!";
ws.sendFrame(payload.data(), payload.size(), WebSocket::FRAME_BINARY); ws.sendFrame(payload.data(), payload.size(), WebSocket::FRAME_BINARY);
n = ws.receiveFrame(buffer, sizeof(buffer), flags); n = ws.receiveFrame(buffer, sizeof(buffer), flags);
@@ -142,16 +180,6 @@ WebSocketTest::WebSocketTest(const std::string& name): CppUnit::TestCase(name)
} }
WebSocketTest::~WebSocketTest()
{
}
void WebSocketTest::testWebSocket()
{
}
void WebSocketTest::setUp() void WebSocketTest::setUp()
{ {
} }