From cedd086362111c2b7664e7b35aec151ffac2372c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Sat, 16 Nov 2024 16:49:50 +0100 Subject: [PATCH] feat(Net): Non-blocking WebSocket #4773 --- Net/include/Poco/Net/WebSocket.h | 10 +- Net/include/Poco/Net/WebSocketImpl.h | 43 ++- Net/src/WebSocket.cpp | 36 +-- Net/src/WebSocketImpl.cpp | 373 ++++++++++++++++++++++----- Net/testsuite/src/WebSocketTest.cpp | 80 +++++- Net/testsuite/src/WebSocketTest.h | 1 + 6 files changed, 436 insertions(+), 107 deletions(-) diff --git a/Net/include/Poco/Net/WebSocket.h b/Net/include/Poco/Net/WebSocket.h index e6cf92746..03fce43cf 100644 --- a/Net/include/Poco/Net/WebSocket.h +++ b/Net/include/Poco/Net/WebSocket.h @@ -221,15 +221,21 @@ public: #endif //POCO_NEW_STATE_ON_MOVE - void shutdown(); + int shutdown(); /// Sends a Close control frame to the server end of /// the connection to initiate an orderly shutdown /// of the connection. + /// + /// Returns the number of bytes sent or -1 if the socket + /// is non-blocking and the frame cannot be sent at this time. - void shutdown(Poco::UInt16 statusCode, const std::string& statusMessage = ""); + int shutdown(Poco::UInt16 statusCode, const std::string& statusMessage = ""); /// Sends a Close control frame to the server end of /// the connection to initiate an orderly shutdown /// of the connection. + /// + /// Returns the number of bytes sent or -1 if the socket + /// is non-blocking and the frame cannot be sent at this time. int sendFrame(const void* buffer, int length, int flags = FRAME_TEXT); /// Sends the contents of the given buffer through diff --git a/Net/include/Poco/Net/WebSocketImpl.h b/Net/include/Poco/Net/WebSocketImpl.h index 71f6fd92b..59651dd83 100644 --- a/Net/include/Poco/Net/WebSocketImpl.h +++ b/Net/include/Poco/Net/WebSocketImpl.h @@ -67,10 +67,16 @@ public: virtual void sendUrgent(unsigned char data); virtual int available(); virtual bool secure() const; + virtual void setSendBufferSize(int size); + virtual int getSendBufferSize(); + virtual void setReceiveBufferSize(int size); + virtual int getReceiveBufferSize(); virtual void setSendTimeout(const Poco::Timespan& timeout); virtual Poco::Timespan getSendTimeout(); virtual void setReceiveTimeout(const Poco::Timespan& timeout); virtual Poco::Timespan getReceiveTimeout(); + virtual void setBlocking(bool flag); + virtual bool getBlocking() const; // Internal int frameFlags() const; @@ -93,13 +99,35 @@ protected: enum { FRAME_FLAG_MASK = 0x80, - MAX_HEADER_LENGTH = 14 + MAX_HEADER_LENGTH = 14, + MASK_LENGTH = 4 }; - int receiveHeader(char mask[4], bool& useMask); - int receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask); - int receiveNBytes(void* buffer, int bytes); - int receiveSomeBytes(char* buffer, int bytes); + struct ReceiveState + { + int frameFlags = 0; + bool useMask = false; + char mask[MASK_LENGTH]; + int headerLength = 0; + int payloadLength = 0; + int remainingPayloadLength = 0; + Poco::Buffer payload{0}; + }; + + struct SendState + { + int length = 0; + int remainingPayloadOffset = 0; + int remainingPayloadLength = 0; + Poco::Buffer payload{0}; + }; + + int peekHeader(ReceiveState& receiveState); + void skipHeader(int headerLength); + int receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask); + int receiveNBytes(void* buffer, int length); + int receiveSomeBytes(char* buffer, int length); + int peekSomeBytes(char* buffer, int length); virtual ~WebSocketImpl(); private: @@ -109,8 +137,9 @@ private: int _maxPayloadSize; Poco::Buffer _buffer; int _bufferOffset; - int _frameFlags; bool _mustMaskPayload; + ReceiveState _receiveState; + SendState _sendState; Poco::Random _rnd; }; @@ -120,7 +149,7 @@ private: // inline int WebSocketImpl::frameFlags() const { - return _frameFlags; + return _receiveState.frameFlags; } diff --git a/Net/src/WebSocket.cpp b/Net/src/WebSocket.cpp index 45c9e0e14..80f3dad00 100644 --- a/Net/src/WebSocket.cpp +++ b/Net/src/WebSocket.cpp @@ -107,48 +107,20 @@ WebSocket& WebSocket::operator = (const Socket& socket) } -#ifdef POCO_NEW_STATE_ON_MOVE - -WebSocket& WebSocket::operator = (Socket&& socket) +int WebSocket::shutdown() { - if (dynamic_cast(socket.impl())) - Socket::operator = (std::move(socket)); - else - throw InvalidArgumentException("Cannot assign incompatible socket"); - return *this; + return shutdown(WS_NORMAL_CLOSE); } -WebSocket& WebSocket::operator = (WebSocket&& socket) -{ - Socket::operator = (std::move(socket)); - return *this; -} - -#endif // POCO_NEW_STATE_ON_MOVE - - -WebSocket& WebSocket::operator = (const WebSocket& socket) -{ - Socket::operator = (socket); - return *this; -} - - -void WebSocket::shutdown() -{ - shutdown(WS_NORMAL_CLOSE); -} - - -void WebSocket::shutdown(Poco::UInt16 statusCode, const std::string& statusMessage) +int WebSocket::shutdown(Poco::UInt16 statusCode, const std::string& statusMessage) { Poco::Buffer buffer(statusMessage.size() + 2); Poco::MemoryOutputStream ostr(buffer.begin(), buffer.size()); Poco::BinaryWriter writer(ostr, Poco::BinaryWriter::NETWORK_BYTE_ORDER); writer << statusCode; writer.writeRaw(statusMessage); - sendFrame(buffer.begin(), static_cast(ostr.charsWritten()), FRAME_FLAG_FIN | FRAME_OP_CLOSE); + return sendFrame(buffer.begin(), static_cast(ostr.charsWritten()), FRAME_FLAG_FIN | FRAME_OP_CLOSE); } diff --git a/Net/src/WebSocketImpl.cpp b/Net/src/WebSocketImpl.cpp index f7976cc98..23759d54c 100644 --- a/Net/src/WebSocketImpl.cpp +++ b/Net/src/WebSocketImpl.cpp @@ -37,7 +37,6 @@ WebSocketImpl::WebSocketImpl(StreamSocketImpl* pStreamSocketImpl, HTTPSession& s _maxPayloadSize(std::numeric_limits::max()), _buffer(0), _bufferOffset(0), - _frameFlags(0), _mustMaskPayload(mustMaskPayload) { poco_check_ptr(pStreamSocketImpl); @@ -62,7 +61,34 @@ WebSocketImpl::~WebSocketImpl() int WebSocketImpl::sendBytes(const void* buffer, int length, int flags) { - Poco::Buffer frame(length + MAX_HEADER_LENGTH); + if (_sendState.remainingPayloadLength > 0) + { + if (length != _sendState.length) + { + throw InvalidArgumentException("Pending send buffer length mismatch"); + } + int sent = _pStreamSocketImpl->sendBytes(_sendState.payload.begin() + _sendState.remainingPayloadOffset, _sendState.remainingPayloadLength); + if (sent >= 0) + { + if (sent < _sendState.remainingPayloadLength) + { + _sendState.remainingPayloadOffset += sent; + _sendState.remainingPayloadLength -= sent; + return -1; + } + else + { + _sendState.length = 0; + _sendState.remainingPayloadOffset = 0; + _sendState.remainingPayloadLength = 0; + return length; + } + } + else return -1; + } + + Poco::Buffer& frame(_sendState.payload); + frame.resize(length + MAX_HEADER_LENGTH, false); Poco::MemoryOutputStream ostr(frame.begin(), frame.size()); Poco::BinaryWriter writer(ostr, Poco::BinaryWriter::NETWORK_BYTE_ORDER); @@ -94,85 +120,137 @@ int WebSocketImpl::sendBytes(const void* buffer, int length, int flags) const Poco::UInt32 mask = _rnd.next(); const char* m = reinterpret_cast(&mask); const char* b = reinterpret_cast(buffer); - writer.writeRaw(m, 4); + writer.writeRaw(m, MASK_LENGTH); char* p = frame.begin() + ostr.charsWritten(); for (int i = 0; i < length; i++) { - p[i] = b[i] ^ m[i % 4]; + p[i] = b[i] ^ m[i % MASK_LENGTH]; } } else { std::memcpy(frame.begin() + ostr.charsWritten(), buffer, length); } - _pStreamSocketImpl->sendBytes(frame.begin(), length + static_cast(ostr.charsWritten())); - return length; + + int frameLength = length + static_cast(ostr.charsWritten()); + int sent = _pStreamSocketImpl->sendBytes(frame.begin(), frameLength); + if (sent >= 0) + { + if (sent < frameLength) + { + _sendState.length = length; + _sendState.remainingPayloadOffset = sent; + _sendState.remainingPayloadLength = frameLength - sent; + return -1; + } + else + { + _sendState.length = 0; + _sendState.remainingPayloadOffset = 0; + _sendState.remainingPayloadLength = 0; + return length; + } + } + else + { + _sendState.length = length; + _sendState.remainingPayloadOffset = 0; + _sendState.remainingPayloadLength = frameLength; + return -1; + } } -int WebSocketImpl::receiveHeader(char mask[4], bool& useMask) +int WebSocketImpl::peekHeader(ReceiveState& receiveState) { char header[MAX_HEADER_LENGTH]; - int n = receiveNBytes(header, 2); - if (n <= 0) - { - _frameFlags = 0; - return n; - } - poco_assert (n == 2); + + receiveState.frameFlags = 0; + receiveState.useMask = false; + receiveState.headerLength = 0; + receiveState.payloadLength = 0; + receiveState.remainingPayloadLength = 0; + + int n = peekSomeBytes(header, MAX_HEADER_LENGTH); + if (n == 0) + return 0; + else if (n < 2) + return -1; + Poco::UInt8 flags = static_cast(header[0]); - _frameFlags = flags; + receiveState.frameFlags = flags; Poco::UInt8 lengthByte = static_cast(header[1]); - useMask = ((lengthByte & FRAME_FLAG_MASK) != 0); - int payloadLength; + receiveState.useMask = ((lengthByte & FRAME_FLAG_MASK) != 0); + int maskOffset = 0; lengthByte &= 0x7f; if (lengthByte == 127) { - n = receiveNBytes(header + 2, 8); - if (n <= 0) + if (n < 10) { - _frameFlags = 0; - return n; + receiveState.frameFlags = 0; + return -1; } Poco::MemoryInputStream istr(header + 2, 8); Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::UInt64 l; reader >> l; if (l > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG); - payloadLength = static_cast(l); + receiveState.payloadLength = static_cast(l); + maskOffset = 10; } else if (lengthByte == 126) { - n = receiveNBytes(header + 2, 2); - if (n <= 0) + if (n < 4) { - _frameFlags = 0; - return n; + receiveState.frameFlags = 0; + return -1; } Poco::MemoryInputStream istr(header + 2, 2); Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::UInt16 l; reader >> l; if (l > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG); - payloadLength = static_cast(l); + receiveState.payloadLength = static_cast(l); + maskOffset = 4; } else { if (lengthByte > _maxPayloadSize) throw WebSocketException("Payload too big", WebSocket::WS_ERR_PAYLOAD_TOO_BIG); - payloadLength = lengthByte; + receiveState.payloadLength = lengthByte; + maskOffset = 2; } - if (useMask) + if (receiveState.useMask) { - n = receiveNBytes(mask, 4); - if (n <= 0) + if (n < maskOffset + MASK_LENGTH) { - _frameFlags = 0; - return n; + receiveState.frameFlags = 0; + return -1; } + std::memcpy(receiveState.mask, header + maskOffset, MASK_LENGTH); + receiveState.headerLength = maskOffset + MASK_LENGTH; + } + else + { + receiveState.headerLength = maskOffset; } - return payloadLength; + receiveState.remainingPayloadLength = receiveState.payloadLength; + + return receiveState.payloadLength; +} + + +void WebSocketImpl::skipHeader(int headerLength) +{ + poco_assert_dbg (headerLength <= MAX_HEADER_LENGTH); + + if (headerLength > 0) + { + char header[MAX_HEADER_LENGTH]; + int n = receiveNBytes(header, headerLength); + poco_assert_dbg (n == headerLength); + } } @@ -184,16 +262,14 @@ void WebSocketImpl::setMaxPayloadSize(int maxPayloadSize) } -int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask) +int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[MASK_LENGTH], bool useMask) { int received = receiveNBytes(reinterpret_cast(buffer), payloadLength); - if (received <= 0) throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); - - if (useMask) + if (received > 0 && useMask) { for (int i = 0; i < received; i++) { - buffer[i] ^= mask[i % 4]; + buffer[i] ^= mask[i % MASK_LENGTH]; } } return received; @@ -202,63 +278,196 @@ int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[4], int WebSocketImpl::receiveBytes(void* buffer, int length, int) { - char mask[4]; - bool useMask; - _frameFlags = 0; - int payloadLength = receiveHeader(mask, useMask); - if (payloadLength <= 0) + if (getBlocking()) + { + int payloadLength = -1; + while (payloadLength < 0) + { + payloadLength = peekHeader(_receiveState); + } + if (payloadLength <= 0) + { + skipHeader(_receiveState.headerLength); + return payloadLength; + } + else if (payloadLength > length) + { + throw WebSocketException(Poco::format("Insufficient buffer for payload size %d", payloadLength), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); + } + + skipHeader(_receiveState.headerLength); + + if (receivePayload(reinterpret_cast(buffer), payloadLength, _receiveState.mask, _receiveState.useMask) != payloadLength) + throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); + return payloadLength; - if (payloadLength > length) - throw WebSocketException(Poco::format("Insufficient buffer for payload size %d", payloadLength), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); - return receivePayload(reinterpret_cast(buffer), payloadLength, mask, useMask); + } + else + { + if (_receiveState.remainingPayloadLength == 0) + { + int payloadLength = peekHeader(_receiveState); + if (payloadLength <= 0) + { + skipHeader(_receiveState.headerLength); + return payloadLength; + } + else if (payloadLength > length) + { + throw WebSocketException(Poco::format("Insufficient buffer for payload size %d", payloadLength), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); + } + + skipHeader(_receiveState.headerLength); + + _receiveState.payload.resize(payloadLength, false); + } + else if (_receiveState.payloadLength > length) + { + throw WebSocketException(Poco::format("Insufficient buffer for payload size %d", _receiveState.payloadLength), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); + } + int payloadOffset = _receiveState.payloadLength - _receiveState.remainingPayloadLength; + int n = receivePayload(_receiveState.payload.begin() + payloadOffset, _receiveState.remainingPayloadLength, _receiveState.mask, _receiveState.useMask); + if (n > 0) + { + _receiveState.remainingPayloadLength -= n; + if (_receiveState.remainingPayloadLength == 0) + { + std::memcpy(buffer, _receiveState.payload.begin(), _receiveState.payloadLength); + return _receiveState.payloadLength; + } + else + { + return -1; + } + } + else if (n == 0) + { + throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); + } + else + { + return -1; + } + } } int WebSocketImpl::receiveBytes(Poco::Buffer& buffer, int, const Poco::Timespan&) { - char mask[4]; - bool useMask; - _frameFlags = 0; - int payloadLength = receiveHeader(mask, useMask); - if (payloadLength <= 0) + if (getBlocking()) + { + int payloadLength = -1; + while (payloadLength < 0) + { + payloadLength = peekHeader(_receiveState); + } + if (payloadLength <= 0) + return payloadLength; + + skipHeader(_receiveState.headerLength); + + std::size_t oldSize = buffer.size(); + buffer.resize(oldSize + payloadLength); + + if (receivePayload(buffer.begin() + oldSize, payloadLength, _receiveState.mask, _receiveState.useMask) != payloadLength) + throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); + return payloadLength; - std::size_t oldSize = buffer.size(); - buffer.resize(oldSize + payloadLength); - return receivePayload(buffer.begin() + oldSize, payloadLength, mask, useMask); + } + else + { + if (_receiveState.remainingPayloadLength == 0) + { + int payloadLength = peekHeader(_receiveState); + if (payloadLength <= 0) + return payloadLength; + + skipHeader(_receiveState.headerLength); + + _receiveState.payload.resize(payloadLength, false); + } + int payloadOffset = _receiveState.payloadLength - _receiveState.remainingPayloadLength; + int n = receivePayload(_receiveState.payload.begin() + payloadOffset, _receiveState.remainingPayloadLength, _receiveState.mask, _receiveState.useMask); + if (n > 0) + { + _receiveState.remainingPayloadLength -= n; + if (_receiveState.remainingPayloadLength == 0) + { + std::size_t oldSize = buffer.size(); + buffer.resize(oldSize + _receiveState.payloadLength); + + std::memcpy(buffer.begin() + oldSize, _receiveState.payload.begin(), _receiveState.payloadLength); + return _receiveState.payloadLength; + } + else + { + return -1; + } + } + else if (n == 0) + { + throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); + } + else + { + return -1; + } + } } -int WebSocketImpl::receiveNBytes(void* buffer, int bytes) +int WebSocketImpl::receiveNBytes(void* buffer, int length) { - int received = receiveSomeBytes(reinterpret_cast(buffer), bytes); + int received = receiveSomeBytes(reinterpret_cast(buffer), length); if (received > 0) { - while (received < bytes) + while (received < length) { - int n = receiveSomeBytes(reinterpret_cast(buffer) + received, bytes - received); + int n = receiveSomeBytes(reinterpret_cast(buffer) + received, length - received); if (n > 0) received += n; else - throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); + break; } } return received; } -int WebSocketImpl::receiveSomeBytes(char* buffer, int bytes) +int WebSocketImpl::receiveSomeBytes(char* buffer, int length) { int n = static_cast(_buffer.size()) - _bufferOffset; if (n > 0) { - if (bytes < n) n = bytes; + if (length < n) n = length; std::memcpy(buffer, _buffer.begin() + _bufferOffset, n); - _bufferOffset += n; + _bufferOffset += length; return n; } else { - return _pStreamSocketImpl->receiveBytes(buffer, bytes); + return _pStreamSocketImpl->receiveBytes(buffer, length); + } +} + + +int WebSocketImpl::peekSomeBytes(char* buffer, int length) +{ + int n = static_cast(_buffer.size()) - _bufferOffset; + if (n > 0) + { + if (length < n) n = length; + std::memcpy(buffer, _buffer.begin() + _bufferOffset, n); + if (length > n) + { + int rc = _pStreamSocketImpl->receiveBytes(buffer + n, length - n, MSG_PEEK); + if (rc > 0) n += rc; + } + return n; + } + else + { + return _pStreamSocketImpl->receiveBytes(buffer, length, MSG_PEEK); } } @@ -366,6 +575,30 @@ bool WebSocketImpl::secure() const } +void WebSocketImpl::setSendBufferSize(int size) +{ + _pStreamSocketImpl->setSendBufferSize(size); +} + + +int WebSocketImpl::getSendBufferSize() +{ + return _pStreamSocketImpl->getSendBufferSize(); +} + + +void WebSocketImpl::setReceiveBufferSize(int size) +{ + _pStreamSocketImpl->setReceiveBufferSize(size); +} + + +int WebSocketImpl::getReceiveBufferSize() +{ + return _pStreamSocketImpl->getReceiveBufferSize(); +} + + void WebSocketImpl::setSendTimeout(const Poco::Timespan& timeout) { _pStreamSocketImpl->setSendTimeout(timeout); @@ -390,6 +623,18 @@ Poco::Timespan WebSocketImpl::getReceiveTimeout() } +void WebSocketImpl::setBlocking(bool flag) +{ + _pStreamSocketImpl->setBlocking(flag); +} + + +bool WebSocketImpl::getBlocking() const +{ + return _pStreamSocketImpl->getBlocking(); +} + + int WebSocketImpl::available() { int n = static_cast(_buffer.size()) - _bufferOffset; diff --git a/Net/testsuite/src/WebSocketTest.cpp b/Net/testsuite/src/WebSocketTest.cpp index 3cd312ed1..d5ab638bd 100644 --- a/Net/testsuite/src/WebSocketTest.cpp +++ b/Net/testsuite/src/WebSocketTest.cpp @@ -58,9 +58,9 @@ namespace do { n = ws.receiveFrame(buffer.begin(), static_cast(buffer.size()), flags); - ws.sendFrame(buffer.begin(), n, flags); + if (n > 0) ws.sendFrame(buffer.begin(), n, flags); } - while (n > 0 || (flags & WebSocket::FRAME_OP_BITMASK) != WebSocket::FRAME_OP_CLOSE); + while (n > 0 && (flags & WebSocket::FRAME_OP_BITMASK) != WebSocket::FRAME_OP_CLOSE); } catch (WebSocketException& exc) { @@ -199,6 +199,7 @@ void WebSocketTest::testWebSocket() assertTrue (n == 2); assertTrue ((flags & WebSocket::FRAME_OP_BITMASK) == WebSocket::FRAME_OP_CLOSE); + ws.close(); server.stop(); } @@ -234,6 +235,9 @@ void WebSocketTest::testWebSocketLarge() assertTrue (n == payload.size()); assertTrue (payload.compare(0, payload.size(), buffer, n) == 0); + + ws.close(); + server.stop(); } @@ -270,6 +274,9 @@ void WebSocketTest::testOneLargeFrame(int msgSize) n = ws.receiveFrame(pocobuffer, flags); assertTrue (n == payload.size()); assertTrue (payload.compare(0, payload.size(), pocobuffer.begin(), n) == 0); + + ws.close(); + server.stop(); } @@ -280,6 +287,74 @@ void WebSocketTest::testWebSocketLargeInOneFrame() } +void WebSocketTest::testWebSocketNB() +{ + Poco::Net::ServerSocket ss(0); + Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory(256*1024), ss, new Poco::Net::HTTPServerParams); + server.start(); + + Poco::Thread::sleep(200); + + HTTPClientSession cs("127.0.0.1", ss.address().port()); + HTTPRequest request(HTTPRequest::HTTP_GET, "/ws", HTTPRequest::HTTP_1_1); + HTTPResponse response; + WebSocket ws(cs, request, response); + ws.setBlocking(false); + + int flags; + char buffer[256*1024] = {}; + int n = ws.receiveFrame(buffer, sizeof(buffer), flags); + assertTrue (n < 0); + + std::string payload("x"); + n = ws.sendFrame(payload.data(), (int) payload.size()); + assertTrue (n > 0); + if (ws.poll(1000000, Poco::Net::Socket::SELECT_READ)) + { + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + while (n < 0) + { + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + } + } + assertTrue (n == payload.size()); + assertTrue (payload.compare(0, payload.size(), buffer, n) == 0); + assertTrue (flags == WebSocket::FRAME_TEXT); + + ws.setSendBufferSize(256*1024); + ws.setReceiveBufferSize(256*1024); + + payload.assign(256000, 'z'); + n = ws.sendFrame(payload.data(), (int) payload.size()); + assertTrue (n > 0); + if (ws.poll(1000000, Poco::Net::Socket::SELECT_READ)) + { + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + while (n < 0) + { + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + } + } + assertTrue (n == payload.size()); + assertTrue (payload.compare(0, payload.size(), buffer, n) == 0); + assertTrue (flags == WebSocket::FRAME_TEXT); + + n = ws.shutdown(); + assertTrue (n > 0); + + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + while (n < 0) + { + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + } + assertTrue (n == 2); + assertTrue ((flags & WebSocket::FRAME_OP_BITMASK) == WebSocket::FRAME_OP_CLOSE); + + ws.close(); + server.stop(); +} + + void WebSocketTest::setUp() { } @@ -297,6 +372,7 @@ CppUnit::Test* WebSocketTest::suite() CppUnit_addTest(pSuite, WebSocketTest, testWebSocket); CppUnit_addTest(pSuite, WebSocketTest, testWebSocketLarge); CppUnit_addTest(pSuite, WebSocketTest, testWebSocketLargeInOneFrame); + CppUnit_addTest(pSuite, WebSocketTest, testWebSocketNB); return pSuite; } diff --git a/Net/testsuite/src/WebSocketTest.h b/Net/testsuite/src/WebSocketTest.h index a32a26633..4c453ab54 100644 --- a/Net/testsuite/src/WebSocketTest.h +++ b/Net/testsuite/src/WebSocketTest.h @@ -27,6 +27,7 @@ public: void testWebSocket(); void testWebSocketLarge(); void testWebSocketLargeInOneFrame(); + void testWebSocketNB(); void setUp(); void tearDown();