From 5d69308bfd5efcc3ee89de34f3217ba776d04441 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnter=20Obiltschnig?= Date: Sat, 16 Nov 2024 16:48:48 +0100 Subject: [PATCH] fix(Net): Non-blocking sockets support #4773 --- Net/src/SocketImpl.cpp | 88 +++++++++++++++++++++++++------- Net/src/StreamSocketImpl.cpp | 2 +- Net/testsuite/src/SocketTest.cpp | 22 ++++++++ Net/testsuite/src/SocketTest.h | 1 + 4 files changed, 93 insertions(+), 20 deletions(-) diff --git a/Net/src/SocketImpl.cpp b/Net/src/SocketImpl.cpp index deaf92f1a..579dfc656 100644 --- a/Net/src/SocketImpl.cpp +++ b/Net/src/SocketImpl.cpp @@ -361,8 +361,10 @@ void SocketImpl::checkBrokenTimeout(SelectMode mode) int SocketImpl::sendBytes(const void* buffer, int length, int flags) { - checkBrokenTimeout(SELECT_WRITE); - + if (_blocking) + { + checkBrokenTimeout(SELECT_WRITE); + } int rc; do { @@ -370,15 +372,26 @@ int SocketImpl::sendBytes(const void* buffer, int length, int flags) rc = ::send(_sockfd, reinterpret_cast(buffer), length, flags); } while (_blocking && rc < 0 && lastError() == POCO_EINTR); - if (rc < 0) error(); + if (rc < 0) + { + int err = lastError(); + if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK)) + ; + else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT) + throw TimeoutException(err); + else + error(err); + } return rc; } int SocketImpl::sendBytes(const SocketBufVec& buffers, int flags) { - checkBrokenTimeout(SELECT_WRITE); - + if (_blocking) + { + checkBrokenTimeout(SELECT_WRITE); + } int rc = 0; do { @@ -395,15 +408,26 @@ int SocketImpl::sendBytes(const SocketBufVec& buffers, int flags) #endif } while (_blocking && rc < 0 && lastError() == POCO_EINTR); - if (rc < 0) error(); + if (rc < 0) + { + int err = lastError(); + if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK)) + ; + else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT) + throw TimeoutException(err); + else + error(err); + } return rc; } int SocketImpl::receiveBytes(void* buffer, int length, int flags) { - checkBrokenTimeout(SELECT_READ); - + if (_blocking) + { + checkBrokenTimeout(SELECT_READ); + } int rc; do { @@ -414,7 +438,7 @@ int SocketImpl::receiveBytes(void* buffer, int length, int flags) if (rc < 0) { int err = lastError(); - if (err == POCO_EAGAIN && !_blocking) + if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK)) ; else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT) throw TimeoutException(err); @@ -427,8 +451,10 @@ int SocketImpl::receiveBytes(void* buffer, int length, int flags) int SocketImpl::receiveBytes(SocketBufVec& buffers, int flags) { - checkBrokenTimeout(SELECT_READ); - + if (_blocking) + { + checkBrokenTimeout(SELECT_READ); + } int rc = 0; do { @@ -448,7 +474,7 @@ int SocketImpl::receiveBytes(SocketBufVec& buffers, int flags) if (rc < 0) { int err = lastError(); - if (err == POCO_EAGAIN && !_blocking) + if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK)) ; else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT) throw TimeoutException(err); @@ -476,7 +502,7 @@ int SocketImpl::receiveBytes(Poco::Buffer& buffer, int flags, const Poco:: if (rc < 0) { int err = lastError(); - if (err == POCO_EAGAIN && !_blocking) + if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK)) ; else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT) throw TimeoutException(err); @@ -502,7 +528,16 @@ int SocketImpl::sendTo(const void* buffer, int length, const SocketAddress& addr #endif } while (_blocking && rc < 0 && lastError() == POCO_EINTR); - if (rc < 0) error(); + if (rc < 0) + { + int err = lastError(); + if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK)) + ; + else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT) + throw TimeoutException(err); + else + error(err); + } return rc; } @@ -534,7 +569,16 @@ int SocketImpl::sendTo(const SocketBufVec& buffers, const SocketAddress& address #endif } while (_blocking && rc < 0 && lastError() == POCO_EINTR); - if (rc < 0) error(); + if (rc < 0) + { + int err = lastError(); + if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK)) + ; + else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT) + throw TimeoutException(err); + else + error(err); + } return rc; } @@ -556,7 +600,10 @@ int SocketImpl::receiveFrom(void* buffer, int length, SocketAddress& address, in int SocketImpl::receiveFrom(void* buffer, int length, struct sockaddr** ppSA, poco_socklen_t** ppSALen, int flags) { - checkBrokenTimeout(SELECT_READ); + if (_blocking) + { + checkBrokenTimeout(SELECT_READ); + } int rc; do { @@ -567,7 +614,7 @@ int SocketImpl::receiveFrom(void* buffer, int length, struct sockaddr** ppSA, po if (rc < 0) { int err = lastError(); - if (err == POCO_EAGAIN && !_blocking) + if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK)) ; else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT) throw TimeoutException(err); @@ -595,7 +642,10 @@ int SocketImpl::receiveFrom(SocketBufVec& buffers, SocketAddress& address, int f int SocketImpl::receiveFrom(SocketBufVec& buffers, struct sockaddr** pSA, poco_socklen_t** ppSALen, int flags) { - checkBrokenTimeout(SELECT_READ); + if (_blocking) + { + checkBrokenTimeout(SELECT_READ); + } int rc = 0; do { @@ -624,7 +674,7 @@ int SocketImpl::receiveFrom(SocketBufVec& buffers, struct sockaddr** pSA, poco_s if (rc < 0) { int err = lastError(); - if (err == POCO_EAGAIN && !_blocking) + if (!_blocking && (err == POCO_EAGAIN || err == POCO_EWOULDBLOCK)) ; else if (err == POCO_EAGAIN || err == POCO_ETIMEDOUT) throw TimeoutException(err); diff --git a/Net/src/StreamSocketImpl.cpp b/Net/src/StreamSocketImpl.cpp index 427c440f0..2174f832d 100644 --- a/Net/src/StreamSocketImpl.cpp +++ b/Net/src/StreamSocketImpl.cpp @@ -61,7 +61,7 @@ int StreamSocketImpl::sendBytes(const void* buffer, int length, int flags) while (remaining > 0) { int n = SocketImpl::sendBytes(p, remaining, flags); - poco_assert_dbg (n >= 0); + poco_assert_dbg (!blocking || n >= 0); p += n; sent += n; remaining -= n; diff --git a/Net/testsuite/src/SocketTest.cpp b/Net/testsuite/src/SocketTest.cpp index 3a232a611..b45bd3dde 100644 --- a/Net/testsuite/src/SocketTest.cpp +++ b/Net/testsuite/src/SocketTest.cpp @@ -67,6 +67,27 @@ void SocketTest::testEcho() } +void SocketTest::testPeek() +{ + EchoServer echoServer; + StreamSocket ss; + ss.connect(SocketAddress("127.0.0.1", echoServer.port())); + int n = ss.sendBytes("hello, world!", 13); + assertTrue (n == 13); + char buffer[256]; + n = ss.receiveBytes(buffer, 5, MSG_PEEK); + assertTrue (n == 5); + assertTrue (std::string(buffer, n) == "hello"); + n = ss.receiveBytes(buffer, sizeof(buffer), MSG_PEEK); + assertTrue (n == 13); + assertTrue (std::string(buffer, n) == "hello, world!"); + n = ss.receiveBytes(buffer, sizeof(buffer)); + assertTrue (n == 13); + assertTrue (std::string(buffer, n) == "hello, world!"); + ss.close(); +} + + void SocketTest::testMoveStreamSocket() { EchoServer echoServer; @@ -679,6 +700,7 @@ CppUnit::Test* SocketTest::suite() CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("SocketTest"); CppUnit_addTest(pSuite, SocketTest, testEcho); + CppUnit_addTest(pSuite, SocketTest, testPeek); CppUnit_addTest(pSuite, SocketTest, testMoveStreamSocket); CppUnit_addTest(pSuite, SocketTest, testPoll); CppUnit_addTest(pSuite, SocketTest, testAvailable); diff --git a/Net/testsuite/src/SocketTest.h b/Net/testsuite/src/SocketTest.h index 8ac15e850..2bfe8526b 100644 --- a/Net/testsuite/src/SocketTest.h +++ b/Net/testsuite/src/SocketTest.h @@ -25,6 +25,7 @@ public: ~SocketTest() override; void testEcho(); + void testPeek(); void testMoveStreamSocket(); void testPoll(); void testAvailable();