diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 5c91ac4..41879aa 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -25,6 +25,8 @@ add_executable( poller.cpp active_poller.cpp multipart.cpp + recv_multipart.cpp + send_multipart.cpp monitor.cpp utilities.cpp ) diff --git a/tests/recv_multipart.cpp b/tests/recv_multipart.cpp new file mode 100644 index 0000000..8f56132 --- /dev/null +++ b/tests/recv_multipart.cpp @@ -0,0 +1,60 @@ +#include +#include + +#ifdef ZMQ_CPP11 + +TEST_CASE("recv_multipart test", "[recv_multipart]") +{ + zmq::context_t context(1); + zmq::socket_t output(context, ZMQ_PAIR); + zmq::socket_t input(context, ZMQ_PAIR); + output.bind("inproc://multipart.test"); + input.connect("inproc://multipart.test"); + + SECTION("send 1 message") + { + input.send(zmq::str_buffer("hello")); + + std::vector msgs; + auto ret = zmq::recv_multipart(output, std::back_inserter(msgs)); + REQUIRE(ret); + CHECK(*ret == 1); + REQUIRE(msgs.size() == 1); + CHECK(msgs[0].size() == 5); + } + SECTION("send 2 messages") + { + input.send(zmq::str_buffer("hello"), zmq::send_flags::sndmore); + input.send(zmq::str_buffer("world!")); + + std::vector msgs; + auto ret = zmq::recv_multipart(output, std::back_inserter(msgs)); + REQUIRE(ret); + CHECK(*ret == 2); + REQUIRE(msgs.size() == 2); + CHECK(msgs[0].size() == 5); + CHECK(msgs[1].size() == 6); + } + SECTION("send no messages, dontwait") + { + std::vector msgs; + auto ret = zmq::recv_multipart(output, std::back_inserter(msgs), zmq::recv_flags::dontwait); + CHECK_FALSE(ret); + REQUIRE(msgs.size() == 0); + } + SECTION("send 1 partial message, dontwait") + { + input.send(zmq::str_buffer("hello"), zmq::send_flags::sndmore); + + std::vector msgs; + auto ret = zmq::recv_multipart(output, std::back_inserter(msgs), zmq::recv_flags::dontwait); + CHECK_FALSE(ret); + REQUIRE(msgs.size() == 0); + } + SECTION("recv with invalid socket") + { + std::vector msgs; + CHECK_THROWS_AS(zmq::recv_multipart(zmq::socket_ref(), std::back_inserter(msgs)), const zmq::error_t &); + } +} +#endif diff --git a/tests/send_multipart.cpp b/tests/send_multipart.cpp new file mode 100644 index 0000000..e3253cb --- /dev/null +++ b/tests/send_multipart.cpp @@ -0,0 +1,125 @@ +#include +#include +#include + +#ifdef ZMQ_CPP11 + +TEST_CASE("send_multipart test", "[send_multipart]") +{ + zmq::context_t context(1); + zmq::socket_t output(context, ZMQ_PAIR); + zmq::socket_t input(context, ZMQ_PAIR); + output.bind("inproc://multipart.test"); + input.connect("inproc://multipart.test"); + + SECTION("send 0 messages") + { + std::vector imsgs; + auto iret = zmq::send_multipart(input, imsgs); + REQUIRE(iret); + CHECK(*iret == 0); + } + SECTION("send 1 message") + { + std::array imsgs = {zmq::message_t(3)}; + auto iret = zmq::send_multipart(input, imsgs); + REQUIRE(iret); + CHECK(*iret == 1); + + std::vector omsgs; + auto oret = zmq::recv_multipart(output, std::back_inserter(omsgs)); + REQUIRE(oret); + CHECK(*oret == 1); + REQUIRE(omsgs.size() == 1); + CHECK(omsgs[0].size() == 3); + } + SECTION("send 2 messages") + { + std::array imsgs = {zmq::message_t(3), zmq::message_t(4)}; + auto iret = zmq::send_multipart(input, imsgs); + REQUIRE(iret); + CHECK(*iret == 2); + + std::vector omsgs; + auto oret = zmq::recv_multipart(output, std::back_inserter(omsgs)); + REQUIRE(oret); + CHECK(*oret == 2); + REQUIRE(omsgs.size() == 2); + CHECK(omsgs[0].size() == 3); + CHECK(omsgs[1].size() == 4); + } + SECTION("send 2 messages, const_buffer") + { + std::array imsgs = {zmq::str_buffer("foo"), zmq::str_buffer("bar!")}; + auto iret = zmq::send_multipart(input, imsgs); + REQUIRE(iret); + CHECK(*iret == 2); + + std::vector omsgs; + auto oret = zmq::recv_multipart(output, std::back_inserter(omsgs)); + REQUIRE(oret); + CHECK(*oret == 2); + REQUIRE(omsgs.size() == 2); + CHECK(omsgs[0].size() == 3); + CHECK(omsgs[1].size() == 4); + } + SECTION("send 2 messages, mutable_buffer") + { + char buf[4] = {}; + std::array imsgs = {zmq::buffer(buf, 3), zmq::buffer(buf)}; + auto iret = zmq::send_multipart(input, imsgs); + REQUIRE(iret); + CHECK(*iret == 2); + + std::vector omsgs; + auto oret = zmq::recv_multipart(output, std::back_inserter(omsgs)); + REQUIRE(oret); + CHECK(*oret == 2); + REQUIRE(omsgs.size() == 2); + CHECK(omsgs[0].size() == 3); + CHECK(omsgs[1].size() == 4); + } + SECTION("send 2 messages, dontwait") + { + zmq::socket_t push(context, ZMQ_PUSH); + push.bind("inproc://multipart.test.push"); + + std::array imsgs = {zmq::message_t(3), zmq::message_t(4)}; + auto iret = zmq::send_multipart(push, imsgs, zmq::send_flags::dontwait); + REQUIRE_FALSE(iret); + } + // TODO send with EAGAIN + SECTION("send, misc. containers") + { + std::vector msgs_vec; + msgs_vec.emplace_back(3); + msgs_vec.emplace_back(4); + auto iret = zmq::send_multipart(input, msgs_vec); + REQUIRE(iret); + CHECK(*iret == 2); + + std::forward_list msgs_list; + msgs_list.emplace_front(4); + msgs_list.emplace_front(3); + iret = zmq::send_multipart(input, msgs_list); + REQUIRE(iret); + CHECK(*iret == 2); + + // init. list + const auto msgs_il = {zmq::str_buffer("foo"), zmq::str_buffer("bar!")}; + iret = zmq::send_multipart(input, msgs_il); + REQUIRE(iret); + CHECK(*iret == 2); + // rvalue + iret = zmq::send_multipart(input, + std::initializer_list{zmq::str_buffer("foo"), zmq::str_buffer("bar!")}); + REQUIRE(iret); + CHECK(*iret == 2); + } + SECTION("send with invalid socket") + { + std::vector msgs(1); + CHECK_THROWS_AS(zmq::send_multipart(zmq::socket_ref(), msgs), const zmq::error_t &); + } +} +#endif diff --git a/zmq.hpp b/zmq.hpp index f4a3d70..6950238 100644 --- a/zmq.hpp +++ b/zmq.hpp @@ -971,6 +971,15 @@ inline const_buffer buffer(const const_buffer& cb, size_t n) noexcept namespace detail { + +template +struct is_buffer +{ + static constexpr bool value = + std::is_same::value || + std::is_same::value; +}; + template struct is_pod_like { // NOTE: The networking draft N4771 section 16.11 requires diff --git a/zmq_addon.hpp b/zmq_addon.hpp index b265161..9821a83 100644 --- a/zmq_addon.hpp +++ b/zmq_addon.hpp @@ -37,6 +37,78 @@ namespace zmq { + +#ifdef ZMQ_CPP11 + +/* Receive a multipart message. + + Writes the zmq::message_t objects to OutputIterator out. + The out iterator must handle an unspecified amount of write, + e.g. using std::back_inserter. + + Returns: the number of messages received or nullopt (on EAGAIN). + Throws: if recv throws. +*/ +template +ZMQ_NODISCARD detail::recv_result_t recv_multipart(socket_ref s, OutputIt out, + recv_flags flags = recv_flags::none) +{ + size_t msg_count = 0; + message_t msg; + while (true) + { + if (!s.recv(msg, flags)) + { + // zmq ensures atomic delivery of messages + assert(msg_count == 0); + return {}; + } + ++msg_count; + const bool more = msg.more(); + *out++ = std::move(msg); + if (!more) + break; + } + return msg_count; +} + +/* Send a multipart message. + + The range must be a ForwardRange of zmq::message_t, + zmq::const_buffer or zmq::mutable_buffer. + The flags may be zmq::send_flags::sndmore if there are + more message parts to be sent after the call to this function. + + Returns: the number of messages sent or nullopt (on EAGAIN). + Throws: if send throws. +*/ +template::value + && (std::is_same, message_t>::value + || detail::is_buffer>::value) + >::type> +detail::send_result_t send_multipart(socket_ref s, Range&& msgs, + send_flags flags = send_flags::none) +{ + auto it = msgs.begin(); + auto last = msgs.end(); + const size_t msg_count = static_cast(std::distance(it, last)); + for (; it != last; ++it) + { + const auto mf = flags | (std::next(it) == last ? send_flags::none : send_flags::sndmore); + if (!s.send(*it, mf)) + { + // zmq ensures atomic delivery of messages + assert(it == msgs.begin()); + return {}; + } + } + return msg_count; +} + +#endif + #ifdef ZMQ_HAS_RVALUE_REFS /*