diff --git a/tests/recv_multipart.cpp b/tests/recv_multipart.cpp index 1daa867..8e98788 100644 --- a/tests/recv_multipart.cpp +++ b/tests/recv_multipart.cpp @@ -56,4 +56,89 @@ TEST_CASE("recv_multipart test", "[recv_multipart]") CHECK_THROWS_AS(zmq::recv_multipart(zmq::socket_ref(), std::back_inserter(msgs)), const zmq::error_t &); } } + +TEST_CASE("recv_multipart_n test", "[recv_multipart]") +{ + zmq::context_t context(1); + zmq::socket_t output(context, ZMQ_PAIR); + zmq::socket_t input(context, ZMQ_PAIR); + output.bind("inproc://multipart.test"); + input.connect("inproc://multipart.test"); + + SECTION("send 1 message") + { + input.send(zmq::str_buffer("hello")); + + std::array msgs; + auto ret = zmq::recv_multipart_n(output, msgs.data(), msgs.size()); + REQUIRE(ret); + CHECK(*ret == 1); + CHECK(msgs[0].size() == 5); + } + SECTION("send 1 message 2") + { + input.send(zmq::str_buffer("hello")); + + std::array msgs; + auto ret = zmq::recv_multipart_n(output, msgs.data(), msgs.size()); + REQUIRE(ret); + CHECK(*ret == 1); + CHECK(msgs[0].size() == 5); + CHECK(msgs[1].size() == 0); + } + SECTION("send 2 messages, recv 1") + { + input.send(zmq::str_buffer("hello"), zmq::send_flags::sndmore); + input.send(zmq::str_buffer("world!")); + + std::array msgs; + CHECK_THROWS_AS( + zmq::recv_multipart_n(output, msgs.data(), msgs.size()), + const std::runtime_error&); + } + SECTION("recv 0") + { + input.send(zmq::str_buffer("hello"), zmq::send_flags::sndmore); + input.send(zmq::str_buffer("world!")); + + std::array msgs; + CHECK_THROWS_AS( + zmq::recv_multipart_n(output, msgs.data(), 0), + const std::runtime_error&); + } + SECTION("send 2 messages") + { + input.send(zmq::str_buffer("hello"), zmq::send_flags::sndmore); + input.send(zmq::str_buffer("world!")); + + std::array msgs; + auto ret = zmq::recv_multipart_n(output, msgs.data(), msgs.size()); + REQUIRE(ret); + CHECK(*ret == 2); + CHECK(msgs[0].size() == 5); + CHECK(msgs[1].size() == 6); + } + SECTION("send no messages, dontwait") + { + std::array msgs; + auto ret = zmq::recv_multipart_n(output, msgs.data(), msgs.size(), zmq::recv_flags::dontwait); + CHECK_FALSE(ret); + REQUIRE(msgs[0].size() == 0); + } + SECTION("send 1 partial message, dontwait") + { + input.send(zmq::str_buffer("hello"), zmq::send_flags::sndmore); + + std::array msgs; + auto ret = zmq::recv_multipart_n(output, msgs.data(), msgs.size(), zmq::recv_flags::dontwait); + CHECK_FALSE(ret); + REQUIRE(msgs[0].size() == 0); + } + SECTION("recv with invalid socket") + { + std::array msgs; + CHECK_THROWS_AS(zmq::recv_multipart_n(zmq::socket_ref(), msgs.data(), msgs.size()), const zmq::error_t &); + } +} + #endif diff --git a/zmq_addon.hpp b/zmq_addon.hpp index ebe9d10..fcabda8 100644 --- a/zmq_addon.hpp +++ b/zmq_addon.hpp @@ -40,6 +40,41 @@ namespace zmq #ifdef ZMQ_CPP11 +namespace detail +{ +template +recv_result_t recv_multipart_n(socket_ref s, OutputIt out, size_t n, + recv_flags flags) +{ + size_t msg_count = 0; + message_t msg; + while (true) + { + #ifdef ZMQ_CPP17 + if constexpr (CheckN) + #else + if (CheckN) + #endif + { + if (msg_count >= n) + throw std::runtime_error("Too many message parts in recv_multipart_n"); + } + if (!s.recv(msg, flags)) + { + // zmq ensures atomic delivery of messages + assert(msg_count == 0); + return {}; + } + ++msg_count; + const bool more = msg.more(); + *out++ = std::move(msg); + if (!more) + break; + } + return msg_count; +} +} // namespace detail + /* Receive a multipart message. Writes the zmq::message_t objects to OutputIterator out. @@ -57,23 +92,29 @@ ZMQ_NODISCARD recv_result_t recv_multipart(socket_ref s, OutputIt out, recv_flags flags = recv_flags::none) { - size_t msg_count = 0; - message_t msg; - while (true) - { - if (!s.recv(msg, flags)) - { - // zmq ensures atomic delivery of messages - assert(msg_count == 0); - return {}; - } - ++msg_count; - const bool more = msg.more(); - *out++ = std::move(msg); - if (!more) - break; - } - return msg_count; + return detail::recv_multipart_n(s, std::move(out), 0, flags); +} + +/* Receive a multipart message. + + Writes at most n zmq::message_t objects to OutputIterator out. + If the number of message parts of the incoming message exceeds n + then an exception will be thrown. + + Returns: the number of messages received or nullopt (on EAGAIN). + Throws: if recv throws. Throws std::runtime_error if the number + of message parts exceeds n (exactly n messages will have been written + to out). Any exceptions thrown + by the out iterator will be propagated and the message + may have been only partially received with pending + message parts. It is adviced to close this socket in that event. +*/ +template +ZMQ_NODISCARD +recv_result_t recv_multipart_n(socket_ref s, OutputIt out, size_t n, + recv_flags flags = recv_flags::none) +{ + return detail::recv_multipart_n(s, std::move(out), n, flags); } /* Send a multipart message.