diff --git a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h index 0d08b2d3a..b3fa77c29 100644 --- a/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureSocketImpl.h @@ -207,6 +207,12 @@ public: int available() const; /// Returns the number of bytes available in the buffer. + SocketImpl* socket(); + /// Returns the underlying SocketImpl. + + const SocketImpl* socket() const; + /// Returns the underlying SocketImpl. + protected: enum { @@ -294,6 +300,8 @@ private: Poco::Buffer _overflowBuffer; Poco::Buffer _sendBuffer; + DWORD _sendBufferOffset; + DWORD _sendBufferPending; Poco::Buffer _recvBuffer; DWORD _recvBufferOffset; DWORD _ioBufferSize; @@ -316,6 +324,18 @@ private: // // inlines // +inline SocketImpl* SecureSocketImpl::socket() +{ + return _pSocket.get(); +} + + +inline const SocketImpl* SecureSocketImpl::socket() const +{ + return _pSocket.get(); +} + + inline poco_socket_t SecureSocketImpl::sockfd() { return _pSocket->sockfd(); diff --git a/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h b/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h index 96ef86e41..243f0a28a 100644 --- a/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h +++ b/NetSSL_Win/include/Poco/Net/SecureStreamSocketImpl.h @@ -197,6 +197,12 @@ public: /// Returns true iff a reused session was negotiated during /// the handshake. + // SocketImpl + virtual void setBlocking(bool flag) override; + virtual bool getBlocking() const override; + virtual void setRawOption(int level, int option, const void* value, poco_socklen_t length) override; + virtual void getRawOption(int level, int option, void* value, poco_socklen_t& length) override; + protected: void acceptSSL(); /// Performs a SSL server-side handshake. diff --git a/NetSSL_Win/src/SecureSocketImpl.cpp b/NetSSL_Win/src/SecureSocketImpl.cpp index 044a7a8c3..fb62c3ea4 100644 --- a/NetSSL_Win/src/SecureSocketImpl.cpp +++ b/NetSSL_Win/src/SecureSocketImpl.cpp @@ -71,6 +71,8 @@ SecureSocketImpl::SecureSocketImpl(Poco::AutoPtr pSocketImpl, Contex _contextFlags(0), _overflowBuffer(0), _sendBuffer(0), + _sendBufferOffset(0), + _sendBufferPending(0), _recvBuffer(IO_BUFFER_SIZE), _recvBufferOffset(0), _ioBufferSize(0), @@ -362,6 +364,17 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) if (_state == ST_ERROR) return 0; + if (_sendBufferPending > 0) + { + int sent = sendRawBytes(_sendBuffer.begin() + _sendBufferOffset, _sendBufferPending, flags); + if (sent > 0) + { + _sendBufferOffset += sent; + _sendBufferPending -= sent; + } + return _sendBufferPending == 0 ? length : -1; + } + if (_state != ST_DONE) { bool establish = _pSocket->getBlocking(); @@ -379,14 +392,13 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) } } - int rc = 0; int dataToSend = length; int dataSent = 0; const char* pBuffer = reinterpret_cast(buffer); - if (_sendBuffer.capacity() != _ioBufferSize) - _sendBuffer.setCapacity(_ioBufferSize); - + _sendBuffer.resize(length + 1024); + _sendBufferOffset = 0; + _sendBufferPending = 0; while (dataToSend > 0) { AutoSecBufferDesc<4> msg(&_securityFunctions, false); @@ -396,11 +408,16 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) SecBuffer* pDataBuffer = 0; SecBuffer* pExtraBuffer = 0; - std::memcpy(_sendBuffer.begin() + _streamSizes.cbHeader, pBuffer + dataSent, dataSize); + if (_sendBuffer.size() < _sendBufferPending + dataSize + _streamSizes.cbHeader + _streamSizes.cbTrailer) + { + _sendBuffer.resize(_sendBufferPending + dataSize + _streamSizes.cbHeader + _streamSizes.cbTrailer, true); + } - msg.setSecBufferStreamHeader(0, _sendBuffer.begin(), _streamSizes.cbHeader); - msg.setSecBufferData(1, _sendBuffer.begin() + _streamSizes.cbHeader, dataSize); - msg.setSecBufferStreamTrailer(2, _sendBuffer.begin() + _streamSizes.cbHeader + dataSize, _streamSizes.cbTrailer); + std::memcpy(_sendBuffer.begin() + _sendBufferPending + _streamSizes.cbHeader, pBuffer + dataSent, dataSize); + + msg.setSecBufferStreamHeader(0, _sendBuffer.begin() + _sendBufferPending, _streamSizes.cbHeader); + msg.setSecBufferData(1, _sendBuffer.begin() + _sendBufferPending + _streamSizes.cbHeader, dataSize); + msg.setSecBufferStreamTrailer(2, _sendBuffer.begin() + _sendBufferPending + _streamSizes.cbHeader + dataSize, _streamSizes.cbTrailer); msg.setSecBufferEmpty(3); SECURITY_STATUS securityStatus = _securityFunctions.EncryptMessage(&_hContext, 0, &msg, 0); @@ -408,24 +425,31 @@ int SecureSocketImpl::sendBytes(const void* buffer, int length, int flags) if (FAILED(securityStatus) && securityStatus != SEC_E_CONTEXT_EXPIRED) throw SSLException("Failed to encrypt message", Utility::formatError(securityStatus)); - int outBufferLen = msg[0].cbBuffer + msg[1].cbBuffer + msg[2].cbBuffer; - - int sent = sendRawBytes(_sendBuffer.begin(), outBufferLen, flags); - if (_pSocket->getBlocking() && sent == -1) - { - if (dataSent == 0) - return -1; - else - return dataSent; - } - if (sent != outBufferLen) - throw SSLException("Failed to send encrypted message"); + poco_assert_dbg (_streamSizes.cbHeader == msg[0].cbBuffer); + poco_assert_dbg (_streamSizes.cbTrailer == msg[2].cbBuffer); + poco_assert_dbg (dataSize == msg[1].cbBuffer); + _sendBufferPending += msg[0].cbBuffer + msg[1].cbBuffer + msg[2].cbBuffer; dataToSend -= dataSize; dataSent += dataSize; - rc += sent; + + poco_assert (_sendBufferPending <= _sendBuffer.size()); + } + + int sent = sendRawBytes(_sendBuffer.begin(), _sendBufferPending, flags); + if (sent >= 0) + { + _sendBufferOffset += sent; + _sendBufferPending -= sent; + if (_sendBufferPending > 0) + return -1; + else + return dataSent; + } + else + { + return -1; } - return dataSent; } @@ -463,14 +487,20 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) { rc = length; std::memcpy(buffer, _overflowBuffer.begin(), rc); - std::memmove(_overflowBuffer.begin(), _overflowBuffer.begin() + rc, overflowSize - rc); - _overflowBuffer.resize(overflowSize - rc); + if ((flags & MSG_PEEK) == 0) + { + std::memmove(_overflowBuffer.begin(), _overflowBuffer.begin() + rc, overflowSize - rc); + _overflowBuffer.resize(overflowSize - rc); + } } else { rc = static_cast(overflowSize); std::memcpy(buffer, _overflowBuffer.begin(), rc); - _overflowBuffer.resize(0); + if ((flags & MSG_PEEK) == 0) + { + _overflowBuffer.resize(0); + } } } else @@ -521,13 +551,21 @@ int SecureSocketImpl::receiveBytes(void* buffer, int length, int flags) rc = bytesDecoded; if (rc > length) rc = length; + + if (flags & MSG_PEEK) + { + // prepend returned data to overflow buffer + std::size_t oldSize = _overflowBuffer.size(); + _overflowBuffer.resize(oldSize + rc); + std::memmove(_overflowBuffer.begin() + rc, _overflowBuffer.begin(), oldSize); + std::memcpy(_overflowBuffer.begin(), buffer, rc); + } + return rc; } if (securityStatus == SEC_E_INCOMPLETE_MESSAGE) { - if (!_pSocket->getBlocking()) - return -1; continue; } diff --git a/NetSSL_Win/src/SecureStreamSocketImpl.cpp b/NetSSL_Win/src/SecureStreamSocketImpl.cpp index 04edfbf19..e90a03059 100644 --- a/NetSSL_Win/src/SecureStreamSocketImpl.cpp +++ b/NetSSL_Win/src/SecureStreamSocketImpl.cpp @@ -213,4 +213,28 @@ int SecureStreamSocketImpl::completeHandshake() } +void SecureStreamSocketImpl::setBlocking(bool flag) +{ + _impl.socket()->setBlocking(flag); +} + + +bool SecureStreamSocketImpl::getBlocking() const +{ + return _impl.socket()->getBlocking(); +} + + +void SecureStreamSocketImpl::setRawOption(int level, int option, const void* value, poco_socklen_t length) +{ + _impl.socket()->setRawOption(level, option, value, length); +} + + +void SecureStreamSocketImpl::getRawOption(int level, int option, void* value, poco_socklen_t& length) +{ + _impl.socket()->getRawOption(level, option, value, length); +} + + } } // namespace Poco::Net diff --git a/NetSSL_Win/testsuite/src/NetSSLTestSuite.cpp b/NetSSL_Win/testsuite/src/NetSSLTestSuite.cpp index 49b9eab08..449dfbc6f 100644 --- a/NetSSL_Win/testsuite/src/NetSSLTestSuite.cpp +++ b/NetSSL_Win/testsuite/src/NetSSLTestSuite.cpp @@ -9,19 +9,22 @@ #include "NetSSLTestSuite.h" - +#include "SecureStreamSocketTestSuite.h" #include "HTTPSClientTestSuite.h" #include "TCPServerTestSuite.h" #include "HTTPSServerTestSuite.h" +#include "WebSocketTestSuite.h" CppUnit::Test* NetSSLTestSuite::suite() { CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("NetSSLTestSuite"); + pSuite->addTest(SecureStreamSocketTestSuite::suite()); pSuite->addTest(HTTPSClientTestSuite::suite()); pSuite->addTest(TCPServerTestSuite::suite()); pSuite->addTest(HTTPSServerTestSuite::suite()); + pSuite->addTest(WebSocketTestSuite::suite()); return pSuite; } diff --git a/NetSSL_Win/testsuite/src/SecureStreamSocketTest.cpp b/NetSSL_Win/testsuite/src/SecureStreamSocketTest.cpp new file mode 100644 index 000000000..bf4a6a9b1 --- /dev/null +++ b/NetSSL_Win/testsuite/src/SecureStreamSocketTest.cpp @@ -0,0 +1,222 @@ +// +// SecureStreamSocketTest.cpp +// +// Copyright (c) 2006, Applied Informatics Software Engineering GmbH. +// and Contributors. +// +// SPDX-License-Identifier: BSL-1.0 +// + + +#include "SecureStreamSocketTest.h" +#include "CppUnit/TestCaller.h" +#include "CppUnit/TestSuite.h" +#include "Poco/Net/TCPServer.h" +#include "Poco/Net/TCPServerConnection.h" +#include "Poco/Net/TCPServerConnectionFactory.h" +#include "Poco/Net/TCPServerParams.h" +#include "Poco/Net/SecureStreamSocket.h" +#include "Poco/Net/SecureServerSocket.h" +#include "Poco/Net/Context.h" +#include "Poco/Net/RejectCertificateHandler.h" +#include "Poco/Net/AcceptCertificateHandler.h" +#include "Poco/Net/Session.h" +#include "Poco/Net/SSLManager.h" +#include "Poco/Util/Application.h" +#include "Poco/Util/AbstractConfiguration.h" +#include "Poco/Thread.h" +#include + + +using Poco::Net::TCPServer; +using Poco::Net::TCPServerConnection; +using Poco::Net::TCPServerConnectionFactory; +using Poco::Net::TCPServerConnectionFactoryImpl; +using Poco::Net::TCPServerParams; +using Poco::Net::StreamSocket; +using Poco::Net::SecureStreamSocket; +using Poco::Net::SecureServerSocket; +using Poco::Net::SocketAddress; +using Poco::Net::Context; +using Poco::Net::Session; +using Poco::Net::SSLManager; +using Poco::Thread; +using Poco::Util::Application; + + +namespace +{ + class EchoConnection: public TCPServerConnection + { + public: + EchoConnection(const StreamSocket& s): TCPServerConnection(s) + { + } + + void run() + { + StreamSocket& ss = socket(); + try + { + char buffer[256]; + int n = ss.receiveBytes(buffer, sizeof(buffer)); + while (n > 0) + { + ss.sendBytes(buffer, n); + n = ss.receiveBytes(buffer, sizeof(buffer)); + } + } + catch (Poco::Exception& exc) + { + std::cerr << "EchoConnection: " << exc.displayText() << std::endl; + } + } + }; +} + + +SecureStreamSocketTest::SecureStreamSocketTest(const std::string& name): CppUnit::TestCase(name) +{ +} + + +SecureStreamSocketTest::~SecureStreamSocketTest() +{ +} + + +void SecureStreamSocketTest::testSendReceive() +{ + SecureServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + + SocketAddress sa("127.0.0.1", svs.address().port()); + SecureStreamSocket ss1(sa); + std::string data("hello, world"); + ss1.sendBytes(data.data(), static_cast(data.size())); + char buffer[8192]; + int n = ss1.receiveBytes(buffer, sizeof(buffer)); + assertTrue (n > 0); + assertTrue (std::string(buffer, n) == data); + + const std::vector sizes = {67, 467, 7883, 19937}; + for (const auto n: sizes) + { + data.assign(n, 'X'); + ss1.sendBytes(data.data(), static_cast(data.size())); + std::string received; + while (received.size() < n) + { + int rc = ss1.receiveBytes(buffer, sizeof(buffer)); + if (rc > 0) + { + received.append(buffer, rc); + } + else if (n == 0) + { + break; + } + } + assertTrue (received == data); + } + + ss1.close(); +} + + +void SecureStreamSocketTest::testPeek() +{ + SecureServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + + SocketAddress sa("127.0.0.1", svs.address().port()); + SecureStreamSocket ss(sa); + + int n = ss.sendBytes("hello, world!", 13); + assertTrue (n == 13); + char buffer[256]; + n = ss.receiveBytes(buffer, 5, MSG_PEEK); + assertTrue (n == 5); + assertTrue (std::string(buffer, n) == "hello"); + n = ss.receiveBytes(buffer, sizeof(buffer), MSG_PEEK); + assertTrue (n == 13); + assertTrue (std::string(buffer, n) == "hello, world!"); + n = ss.receiveBytes(buffer, 7); + assertTrue (n == 7); + assertTrue (std::string(buffer, n) == "hello, "); + n = ss.receiveBytes(buffer, 6); + assertTrue (n == 6); + assertTrue (std::string(buffer, n) == "world!"); + ss.close(); +} + + +void SecureStreamSocketTest::testNB() +{ + SecureServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + + SocketAddress sa("127.0.0.1", svs.address().port()); + SecureStreamSocket ss1(sa); + ss1.setBlocking(false); + ss1.setSendBufferSize(32000); + + char buffer[8192]; + const std::vector sizes = {67, 467, 7883, 19937}; + for (const auto n: sizes) + { + std::string data(n, 'X'); + int rc = ss1.sendBytes(data.data(), static_cast(data.size())); + assertTrue (rc == n); + + rc = -1; + while (rc < 0) + { + rc = ss1.receiveBytes(buffer, sizeof(buffer), MSG_PEEK); + } + assertTrue (rc > 0 && rc <= n); + assertTrue (data.compare(0, rc, buffer, rc) == 0); + + std::string received; + while (received.size() < n) + { + int rc = ss1.receiveBytes(buffer, sizeof(buffer)); + if (rc > 0) + { + received.append(buffer, rc); + } + else if (n == 0) + { + break; + } + } + assertTrue (received == data); + } + + ss1.close(); +} + + +void SecureStreamSocketTest::setUp() +{ +} + + +void SecureStreamSocketTest::tearDown() +{ +} + + +CppUnit::Test* SecureStreamSocketTest::suite() +{ + CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("SecureStreamSocketTest"); + + CppUnit_addTest(pSuite, SecureStreamSocketTest, testSendReceive); + CppUnit_addTest(pSuite, SecureStreamSocketTest, testPeek); + CppUnit_addTest(pSuite, SecureStreamSocketTest, testNB); + + return pSuite; +} diff --git a/NetSSL_Win/testsuite/src/SecureStreamSocketTest.h b/NetSSL_Win/testsuite/src/SecureStreamSocketTest.h new file mode 100644 index 000000000..62ac9e1c7 --- /dev/null +++ b/NetSSL_Win/testsuite/src/SecureStreamSocketTest.h @@ -0,0 +1,40 @@ +// +// SecureStreamSocketTest.h +// +// Definition of the SecureStreamSocketTest class. +// +// Copyright (c) 2006, Applied Informatics Software Engineering GmbH. +// and Contributors. +// +// SPDX-License-Identifier: BSL-1.0 +// + + +#ifndef SecureStreamSocketTest_INCLUDED +#define SecureStreamSocketTest_INCLUDED + + +#include "Poco/Net/Net.h" +#include "CppUnit/TestCase.h" + + +class SecureStreamSocketTest: public CppUnit::TestCase +{ +public: + SecureStreamSocketTest(const std::string& name); + ~SecureStreamSocketTest(); + + void testSendReceive(); + void testPeek(); + void testNB(); + + void setUp(); + void tearDown(); + + static CppUnit::Test* suite(); + +private: +}; + + +#endif // SecureStreamSocketTest diff --git a/NetSSL_Win/testsuite/src/SecureStreamSocketTestSuite.cpp b/NetSSL_Win/testsuite/src/SecureStreamSocketTestSuite.cpp new file mode 100644 index 000000000..1aa23be2a --- /dev/null +++ b/NetSSL_Win/testsuite/src/SecureStreamSocketTestSuite.cpp @@ -0,0 +1,22 @@ +// +// SecureStreamSocketTestSuite.cpp +// +// Copyright (c) 2006, Applied Informatics Software Engineering GmbH. +// and Contributors. +// +// SPDX-License-Identifier: BSL-1.0 +// + + +#include "SecureStreamSocketTestSuite.h" +#include "SecureStreamSocketTest.h" + + +CppUnit::Test* SecureStreamSocketTestSuite::suite() +{ + CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("SecureStreamSocketTestSuite"); + + pSuite->addTest(SecureStreamSocketTest::suite()); + + return pSuite; +} diff --git a/NetSSL_Win/testsuite/src/SecureStreamSocketTestSuite.h b/NetSSL_Win/testsuite/src/SecureStreamSocketTestSuite.h new file mode 100644 index 000000000..0bbf4cc3c --- /dev/null +++ b/NetSSL_Win/testsuite/src/SecureStreamSocketTestSuite.h @@ -0,0 +1,27 @@ +// +// SecureStreamSocketTestSuite.h +// +// Definition of the SecureStreamSocketTestSuite class. +// +// Copyright (c) 2006, Applied Informatics Software Engineering GmbH. +// and Contributors. +// +// SPDX-License-Identifier: BSL-1.0 +// + + +#ifndef SecureStreamSocketTestSuite_INCLUDED +#define SecureStreamSocketTestSuite_INCLUDED + + +#include "CppUnit/TestSuite.h" + + +class SecureStreamSocketTestSuite +{ +public: + static CppUnit::Test* suite(); +}; + + +#endif // SecureStreamSocketTestSuite diff --git a/NetSSL_Win/testsuite/src/WebSocketTestSuite.cpp b/NetSSL_Win/testsuite/src/WebSocketTestSuite.cpp new file mode 100644 index 000000000..1043d8b0c --- /dev/null +++ b/NetSSL_Win/testsuite/src/WebSocketTestSuite.cpp @@ -0,0 +1,22 @@ +// +// WebSocketTestSuite.cpp +// +// Copyright (c) 2012, Applied Informatics Software Engineering GmbH. +// and Contributors. +// +// SPDX-License-Identifier: BSL-1.0 +// + + +#include "WebSocketTestSuite.h" +#include "WebSocketTest.h" + + +CppUnit::Test* WebSocketTestSuite::suite() +{ + CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("WebSocketTestSuite"); + + pSuite->addTest(WebSocketTest::suite()); + + return pSuite; +} diff --git a/NetSSL_Win/testsuite/src/WebSocketTestSuite.h b/NetSSL_Win/testsuite/src/WebSocketTestSuite.h new file mode 100644 index 000000000..b3d33b2c8 --- /dev/null +++ b/NetSSL_Win/testsuite/src/WebSocketTestSuite.h @@ -0,0 +1,27 @@ +// +// WebSocketTestSuite.h +// +// Definition of the WebSocketTestSuite class. +// +// Copyright (c) 2012, Applied Informatics Software Engineering GmbH. +// and Contributors. +// +// SPDX-License-Identifier: BSL-1.0 +// + + +#ifndef WebSocketTestSuite_INCLUDED +#define WebSocketTestSuite_INCLUDED + + +#include "CppUnit/TestSuite.h" + + +class WebSocketTestSuite +{ +public: + static CppUnit::Test* suite(); +}; + + +#endif // WebSocketTestSuite_INCLUDED