From 66c1d30bb4dab0b9c3c5258d7409dc94eb8b7382 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Sat, 8 Feb 2025 12:57:24 +0100 Subject: [PATCH] enh(Net): StreamSocket::sendFile() takes additional count parameter --- Net/include/Poco/Net/SocketImpl.h | 9 +- Net/include/Poco/Net/StreamSocket.h | 5 +- Net/src/SocketImpl.cpp | 63 +++++--- Net/src/StreamSocket.cpp | 4 +- Net/testsuite/src/SocketTest.cpp | 150 ++++++++++++++++-- Net/testsuite/src/SocketTest.h | 2 + .../testsuite/src/SecureStreamSocketTest.cpp | 137 ++++++++++++++-- .../testsuite/src/SecureStreamSocketTest.h | 2 + 8 files changed, 309 insertions(+), 63 deletions(-) diff --git a/Net/include/Poco/Net/SocketImpl.h b/Net/include/Poco/Net/SocketImpl.h index 62170ddfd..4537f1965 100644 --- a/Net/include/Poco/Net/SocketImpl.h +++ b/Net/include/Poco/Net/SocketImpl.h @@ -287,9 +287,12 @@ public: /// The preferred way for a socket to receive urgent data /// is by enabling the SO_OOBINLINE option. - virtual std::streamsize sendFile(Poco::FileInputStream& FileInputStream, std::streamoff offset = 0); + virtual std::streamsize sendFile(Poco::FileInputStream& FileInputStream, std::streamoff offset = 0, std::streamsize count = 0); /// Sends the contents of a file in an optimized way, if possible. /// + /// If count is != 0, sends the given number of bytes, otherwise + /// sends all bytes, starting from the given offset. + /// /// On POSIX systems, this means using sendfile() or sendfile64(). /// On Windows, this means using TransmitFile(). /// @@ -544,11 +547,11 @@ protected: void checkBrokenTimeout(SelectMode mode); - std::streamsize sendFileNative(Poco::FileInputStream& FileInputStream, std::streamoff offset); + std::streamsize sendFileNative(Poco::FileInputStream& FileInputStream, std::streamoff offset, std::streamsize count); /// Implements sendFile() using an OS-specific API like /// sendfile() or TransmitFile(). - std::streamsize sendFileBlockwise(Poco::FileInputStream& FileInputStream, std::streamoff offset); + std::streamsize sendFileBlockwise(Poco::FileInputStream& FileInputStream, std::streamoff offset, std::streamsize count); /// Implements sendFile() by reading the file blockwise and /// calling sendBytes() for each block. diff --git a/Net/include/Poco/Net/StreamSocket.h b/Net/include/Poco/Net/StreamSocket.h index 7b007f352..e81c25a4d 100644 --- a/Net/include/Poco/Net/StreamSocket.h +++ b/Net/include/Poco/Net/StreamSocket.h @@ -268,9 +268,12 @@ public: /// The preferred way for a socket to receive urgent data /// is by enabling the SO_OOBINLINE option. - std::streamsize sendFile(Poco::FileInputStream& FileInputStream, std::streamoff offset = 0); + std::streamsize sendFile(Poco::FileInputStream& FileInputStream, std::streamoff offset = 0, std::streamsize count = 0); /// Sends the contents of a file in an optimized way, if possible. /// + /// If count is != 0, sends the given number of bytes, otherwise + /// sends all bytes, starting from the given offset. + /// /// On POSIX systems, this means using sendfile() or sendfile64(). /// On Windows, this means using TransmitFile(). /// diff --git a/Net/src/SocketImpl.cpp b/Net/src/SocketImpl.cpp index c79757999..055ffcee7 100644 --- a/Net/src/SocketImpl.cpp +++ b/Net/src/SocketImpl.cpp @@ -698,21 +698,21 @@ void SocketImpl::sendUrgent(unsigned char data) -std::streamsize SocketImpl::sendFile(FileInputStream& fileInputStream, std::streamoff offset) +std::streamsize SocketImpl::sendFile(FileInputStream& fileInputStream, std::streamoff offset, std::streamsize count) { if (!getBlocking()) throw NetException("sendFile() not supported for non-blocking sockets"); #ifdef POCO_HAVE_SENDFILE if (secure()) { - return sendFileBlockwise(fileInputStream, offset); + return sendFileBlockwise(fileInputStream, offset, count); } else { - return sendFileNative(fileInputStream, offset); + return sendFileNative(fileInputStream, offset, count); } #else - return sendFileBlockwise(fileInputStream, offset); + return sendFileBlockwise(fileInputStream, offset, count); #endif } @@ -1451,11 +1451,10 @@ void SocketImpl::error(int code, const std::string& arg) #ifdef POCO_OS_FAMILY_WINDOWS -std::streamsize SocketImpl::sendFileNative(FileInputStream& fileInputStream, std::streamoff offset) +std::streamsize SocketImpl::sendFileNative(FileInputStream& fileInputStream, std::streamoff offset, std::streamsize count) { FileIOS::NativeHandle fd = fileInputStream.nativeHandle(); - UInt64 fileSize = fileInputStream.size(); - std::streamoff sentSize = fileSize - offset; + if (count == 0) count = fileInputStream.size() - offset; LARGE_INTEGER offsetHelper; offsetHelper.QuadPart = offset; OVERLAPPED overlapped; @@ -1468,7 +1467,7 @@ std::streamsize SocketImpl::sendFileNative(FileInputStream& fileInputStream, std int err = GetLastError(); error(err); } - bool result = TransmitFile(_sockfd, fd, sentSize, 0, &overlapped, nullptr, 0); + bool result = TransmitFile(_sockfd, fd, count, 0, &overlapped, nullptr, 0); if (!result) { int err = WSAGetLastError(); @@ -1480,7 +1479,7 @@ std::streamsize SocketImpl::sendFileNative(FileInputStream& fileInputStream, std WaitForSingleObject(overlapped.hEvent, INFINITE); } CloseHandle(overlapped.hEvent); - return sentSize; + return count; } @@ -1489,23 +1488,33 @@ std::streamsize SocketImpl::sendFileNative(FileInputStream& fileInputStream, std namespace { - std::streamsize sendFilePosix(poco_socket_t sd, FileIOS::NativeHandle fd, std::streamsize offset, std::streamoff sentSize) + std::streamsize sendFileUnix(poco_socket_t sd, FileIOS::NativeHandle fd, std::streamsize offset, std::streamoff count) { Int64 sent = 0; #ifdef __USE_LARGEFILE64 - sent = sendfile64(sd, fd, (off64_t*) &offset, sentSize); + sent = sendfile64(sd, fd, (off64_t*) &offset, count); #else #if POCO_OS == POCO_OS_LINUX && !defined(POCO_EMSCRIPTEN) - sent = sendfile(sd, fd, (off_t*) &offset, sentSize); + sent = sendfile(sd, fd, (off_t*) &offset, count); #elif POCO_OS == POCO_OS_MAC_OS_X - int result = sendfile(fd, sd, offset, &sentSize, nullptr, 0); + int result = sendfile(fd, sd, offset, &count, NULL, 0); if (result < 0) { sent = -1; } else { - sent = sentSize; + sent = count; + } + #elif POCO_OS == POCO_OS_FREE_BSD + int result = sendfile(fd, sd, offset, &count, NULL, NULL, 0); + if (result < 0) + { + sent = -1; + } + else + { + sent = count; } #else throw Poco::NotImplementedException("sendfile not implemented for this platform"); @@ -1520,12 +1529,11 @@ namespace } -std::streamsize SocketImpl::sendFileNative(FileInputStream& fileInputStream, std::streamoff offset) +std::streamsize SocketImpl::sendFileNative(FileInputStream& fileInputStream, std::streamoff offset, std::streamsize count) { FileIOS::NativeHandle fd = fileInputStream.nativeHandle(); - UInt64 fileSize = fileInputStream.size(); - std::streamoff sentSize = fileSize - offset; - Int64 sent = 0; + if (count == 0) count = fileInputStream.size() - offset; + std::streamsize sent = 0; struct sigaction sa, old_sa; sa.sa_handler = SIG_IGN; sigemptyset(&sa.sa_mask); @@ -1534,7 +1542,7 @@ std::streamsize SocketImpl::sendFileNative(FileInputStream& fileInputStream, std while (sent == 0) { errno = 0; - sent = sendFilePosix(_sockfd, fd, offset, sentSize); + sent = sendFileUnix(_sockfd, fd, offset, count); if (sent < 0) { error(errno); @@ -1553,21 +1561,28 @@ std::streamsize SocketImpl::sendFileNative(FileInputStream& fileInputStream, std #endif // POCO_HAVE_SENDFILE -std::streamsize SocketImpl::sendFileBlockwise(FileInputStream& fileInputStream, std::streamoff offset) +std::streamsize SocketImpl::sendFileBlockwise(FileInputStream& fileInputStream, std::streamoff offset, std::streamsize count) { - fileInputStream.seekg(offset); + fileInputStream.seekg(offset, std::ios_base::beg); Poco::Buffer buffer(8192); + std::size_t bufferSize = buffer.size(); + if (count > 0 && bufferSize > count) bufferSize = count; std::streamsize len = 0; - fileInputStream.read(buffer.begin(), buffer.size()); + fileInputStream.read(buffer.begin(), bufferSize); std::streamsize n = fileInputStream.gcount(); - while (n > 0) + while (n > 0 && (count == 0 || len < count)) { len += n; sendBytes(buffer.begin(), n); + if (count > 0 && len < count) + { + const std::size_t remaining = count - len; + if (bufferSize > remaining) bufferSize = remaining; + } if (fileInputStream) { - fileInputStream.read(buffer.begin(), buffer.size()); + fileInputStream.read(buffer.begin(), bufferSize); n = fileInputStream.gcount(); } else n = 0; diff --git a/Net/src/StreamSocket.cpp b/Net/src/StreamSocket.cpp index 20023ed8a..d71de4c23 100644 --- a/Net/src/StreamSocket.cpp +++ b/Net/src/StreamSocket.cpp @@ -214,9 +214,9 @@ void StreamSocket::sendUrgent(unsigned char data) } -std::streamsize StreamSocket::sendFile(Poco::FileInputStream& fileInputStream, std::streamoff offset) +std::streamsize StreamSocket::sendFile(Poco::FileInputStream& fileInputStream, std::streamoff offset, std::streamsize count) { - return impl()->sendFile(fileInputStream, offset); + return impl()->sendFile(fileInputStream, offset, count); } diff --git a/Net/testsuite/src/SocketTest.cpp b/Net/testsuite/src/SocketTest.cpp index 52b076334..31699ff33 100644 --- a/Net/testsuite/src/SocketTest.cpp +++ b/Net/testsuite/src/SocketTest.cpp @@ -15,6 +15,9 @@ #include "Poco/Net/StreamSocket.h" #include "Poco/Net/ServerSocket.h" #include "Poco/Net/SocketAddress.h" +#include "Poco/Net/TCPServerConnection.h" +#include "Poco/Net/TCPServerConnectionFactory.h" +#include "Poco/Net/TCPServer.h" #include "Poco/Net/NetException.h" #include "Poco/Timespan.h" #include "Poco/Stopwatch.h" @@ -25,6 +28,7 @@ #include "Poco/TemporaryFile.h" #include "Poco/FileStream.h" #include "Poco/Path.h" +#include "Poco/Thread.h" #include @@ -33,6 +37,9 @@ using Poco::Net::StreamSocket; using Poco::Net::ServerSocket; using Poco::Net::SocketAddress; using Poco::Net::ConnectionRefusedException; +using Poco::Net::TCPServerConnection; +using Poco::Net::TCPServerConnectionFactoryImpl; +using Poco::Net::TCPServer; using Poco::Timespan; using Poco::Stopwatch; using Poco::TimeoutException; @@ -44,6 +51,49 @@ using Poco::File; using Poco::delegate; +namespace +{ + class CopyToStringConnection: public TCPServerConnection + { + public: + CopyToStringConnection(const StreamSocket& s): + TCPServerConnection(s) + { + } + + void run() + { + _data.clear(); + StreamSocket& ss = socket(); + try + { + char buffer[256]; + int n = ss.receiveBytes(buffer, sizeof(buffer)); + while (n > 0) + { + _data.append(buffer, n); + n = ss.receiveBytes(buffer, sizeof(buffer)); + } + } + catch (Poco::Exception& exc) + { + std::cerr << "CopyToStringConnection: " << exc.displayText() << std::endl; + } + } + + static const std::string& data() + { + return _data; + } + + private: + static std::string _data; + }; + + std::string CopyToStringConnection::_data; +} + + SocketTest::SocketTest(const std::string& name): CppUnit::TestCase(name) { } @@ -671,9 +721,12 @@ void SocketTest::testUseFd() void SocketTest::testSendFile() { - EchoServer echoServer; + ServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + StreamSocket ss; - ss.connect(SocketAddress("127.0.0.1", echoServer.port())); + ss.connect(SocketAddress("127.0.0.1", srv.port())); std::string sentData = "Hello, world!"; @@ -683,28 +736,89 @@ void SocketTest::testSendFile() ostr.close(); Poco::FileInputStream istr(file.path()); - std::streamsize sent = 0; - while (sent < file.getSize()) - { - sent += ss.sendFile(istr, sent); - } + ss.sendFile(istr); istr.close(); + ss.close(); - std::string receivedData; - char buffer[1024]; - int n = ss.receiveBytes(buffer, sizeof(buffer)); - while (n > 0) + while (srv.currentConnections() > 0) { - receivedData.append(buffer, n); - if (receivedData.size() < sentData.size()) - n = ss.receiveBytes(buffer, sizeof(buffer)); - else - n = 0; + Poco::Thread::sleep(100); } - assertTrue (receivedData == sentData); + assertTrue (CopyToStringConnection::data() == sentData); +} + +void SocketTest::testSendFileLarge() +{ + ServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + + StreamSocket ss; + ss.connect(SocketAddress("127.0.0.1", srv.port())); + + std::string sentData; + + Poco::TemporaryFile file; + Poco::FileOutputStream ostr(file.path()); + std::string data("0123456789abcdef"); + for (int i = 0; i < 10000; i++) + { + ostr.write(data.data(), data.size()); + sentData += data; + } + ostr.close(); + + Poco::FileInputStream istr(file.path()); + ss.sendFile(istr); + istr.close(); ss.close(); + + while (srv.currentConnections() > 0) + { + Poco::Thread::sleep(100); + } + + assertTrue (CopyToStringConnection::data() == sentData); +} + + +void SocketTest::testSendFileRange() +{ + ServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + + StreamSocket ss; + ss.connect(SocketAddress("127.0.0.1", srv.port())); + + std::string sentData; + + Poco::TemporaryFile file; + Poco::FileOutputStream ostr(file.path()); + std::string data("0123456789abcdef"); + for (int i = 0; i < 1024; i++) + { + ostr.write(data.data(), data.size()); + sentData += data; + } + ostr.close(); + + const std::streamoff offset = 4000; + const std::streamsize count = 10000; + + Poco::FileInputStream istr(file.path()); + ss.sendFile(istr, offset, count); + istr.close(); + ss.close(); + + while (srv.currentConnections() > 0) + { + Poco::Thread::sleep(100); + } + + assertTrue (CopyToStringConnection::data() == sentData.substr(offset, count)); } @@ -766,6 +880,8 @@ CppUnit::Test* SocketTest::suite() CppUnit_addTest(pSuite, SocketTest, testUnixLocalAbstract); CppUnit_addTest(pSuite, SocketTest, testUseFd); CppUnit_addTest(pSuite, SocketTest, testSendFile); + CppUnit_addTest(pSuite, SocketTest, testSendFileLarge); + CppUnit_addTest(pSuite, SocketTest, testSendFileRange); return pSuite; } diff --git a/Net/testsuite/src/SocketTest.h b/Net/testsuite/src/SocketTest.h index af026e7d8..2cc4d8f57 100644 --- a/Net/testsuite/src/SocketTest.h +++ b/Net/testsuite/src/SocketTest.h @@ -50,6 +50,8 @@ public: void testUnixLocalAbstract(); void testUseFd(); void testSendFile(); + void testSendFileLarge(); + void testSendFileRange(); void setUp() override; void tearDown() override; diff --git a/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.cpp b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.cpp index dd04d621e..ddf6006f1 100644 --- a/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.cpp +++ b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.cpp @@ -75,6 +75,45 @@ namespace } } }; + + class CopyToStringConnection: public TCPServerConnection + { + public: + CopyToStringConnection(const StreamSocket& s): + TCPServerConnection(s) + { + } + + void run() + { + _data.clear(); + StreamSocket& ss = socket(); + try + { + char buffer[256]; + int n = ss.receiveBytes(buffer, sizeof(buffer)); + while (n > 0) + { + _data.append(buffer, n); + n = ss.receiveBytes(buffer, sizeof(buffer)); + } + } + catch (Poco::Exception& exc) + { + std::cerr << "CopyToStringConnection: " << exc.displayText() << std::endl; + } + } + + static const std::string& data() + { + return _data; + } + + private: + static std::string _data; + }; + + std::string CopyToStringConnection::_data; } @@ -206,7 +245,7 @@ void SecureStreamSocketTest::testNB() void SecureStreamSocketTest::testSendFile() { SecureServerSocket svs(0); - TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); srv.start(); SecureStreamSocket ss; @@ -220,28 +259,92 @@ void SecureStreamSocketTest::testSendFile() ostr.close(); Poco::FileInputStream istr(file.path()); - std::streamsize sent = 0; - while (sent < file.getSize()) - { - sent += ss.sendFile(istr, sent); - } + ss.sendFile(istr); istr.close(); + ss.close(); - std::string receivedData; - char buffer[1024]; - int n = ss.receiveBytes(buffer, sizeof(buffer)); - while (n > 0) + while (srv.currentConnections() > 0) { - receivedData.append(buffer, n); - if (receivedData.size() < sentData.size()) - n = ss.receiveBytes(buffer, sizeof(buffer)); - else - n = 0; + Poco::Thread::sleep(100); } - assertTrue (receivedData == sentData); + assertTrue (CopyToStringConnection::data() == sentData); +} + +void SecureStreamSocketTest::testSendFileLarge() +{ + SecureServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + + SecureStreamSocket ss; + ss.connect(SocketAddress("127.0.0.1", srv.port())); + + std::string sentData; + + Poco::TemporaryFile file; + Poco::FileOutputStream ostr(file.path()); + std::string data("0123456789abcdef"); + for (int i = 0; i < 10000; i++) + { + ostr.write(data.data(), data.size()); + sentData += data; + } + ostr.close(); + + Poco::FileInputStream istr(file.path()); + ss.sendFile(istr); + istr.close(); ss.close(); + + while (srv.currentConnections() > 0) + { + Poco::Thread::sleep(100); + } + + assertTrue (CopyToStringConnection::data() == sentData); +} + + +void SecureStreamSocketTest::testSendFileRange() +{ + SecureServerSocket svs(0); + TCPServer srv(new TCPServerConnectionFactoryImpl(), svs); + srv.start(); + + SecureStreamSocket ss; + ss.connect(SocketAddress("127.0.0.1", srv.port())); + + std::string fileData; + + Poco::TemporaryFile file; + Poco::FileOutputStream ostr(file.path()); + std::string data("0123456789abcdef"); + for (int i = 0; i < 1024; i++) + { + ostr.write(data.data(), data.size()); + fileData += data; + } + ostr.close(); + + const std::streamoff offset = 4000; + const std::streamsize count = 10000; + + Poco::FileInputStream istr(file.path()); + ss.sendFile(istr, offset, count); + istr.close(); + ss.close(); + + while (srv.currentConnections() > 0) + { + Poco::Thread::sleep(100); + } + + std::cout << "\n" << fileData.size() << std::endl; + std::cout << "\n" << CopyToStringConnection::data().size() << std::endl; + + assertTrue (CopyToStringConnection::data() == fileData.substr(offset, count)); } @@ -263,6 +366,8 @@ CppUnit::Test* SecureStreamSocketTest::suite() CppUnit_addTest(pSuite, SecureStreamSocketTest, testPeek); CppUnit_addTest(pSuite, SecureStreamSocketTest, testNB); CppUnit_addTest(pSuite, SecureStreamSocketTest, testSendFile); + CppUnit_addTest(pSuite, SecureStreamSocketTest, testSendFileLarge); + CppUnit_addTest(pSuite, SecureStreamSocketTest, testSendFileRange); return pSuite; } diff --git a/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.h b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.h index f072e90fc..eb4486e22 100644 --- a/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.h +++ b/NetSSL_OpenSSL/testsuite/src/SecureStreamSocketTest.h @@ -28,6 +28,8 @@ public: void testPeek(); void testNB(); void testSendFile(); + void testSendFileLarge(); + void testSendFileRange(); void setUp(); void tearDown();