From 3d4be814e8a6cb44e39ddc7c272bdcc1ffad296f Mon Sep 17 00:00:00 2001 From: Gudmundur Adalsteinsson Date: Mon, 22 Apr 2019 21:15:51 +0000 Subject: [PATCH] Problem: send/recv functions lack type-safety Solution: Add functions taking buffers and enum class flags --- README.md | 17 +- tests/CMakeLists.txt | 1 + tests/active_poller.cpp | 22 +-- tests/buffer.cpp | 253 ++++++++++++++++++++++++ tests/poller.cpp | 12 +- tests/socket.cpp | 149 +++++++++++++- zmq.hpp | 425 +++++++++++++++++++++++++++++++++++++++- zmq_addon.hpp | 11 ++ 8 files changed, 866 insertions(+), 24 deletions(-) create mode 100644 tests/buffer.cpp diff --git a/README.md b/README.md index 9d3c253..0489c65 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,21 @@ Supported platforms - Any platform supported by libzmq that provides a sufficiently recent gcc (4.8.1 or newer) or clang (3.3 or newer) - Visual Studio 2012+ x86/x64 +Examples +======== +```c++ +#include +#include +int main() +{ + zmq::context_t ctx; + zmq::socket_t sock(ctx, zmq::socket_type::push); + sock.bind("inproc://test"); + const std::string_view m = "Hello, world"; + sock.send(zmq::buffer(m), zmq::send_flags::dontwait); +} +``` + Contribution policy =================== @@ -74,5 +89,3 @@ cpp zmq (which will also include libzmq for you). find_package(cppzmq) target_link_libraries(*Your Project Name* cppzmq) ``` - - diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 402c0d2..59c20a9 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -17,6 +17,7 @@ find_package(Threads) add_executable( unit_tests + buffer.cpp message.cpp context.cpp socket.cpp diff --git a/tests/active_poller.cpp b/tests/active_poller.cpp index 866d7bb..f7db4a2 100644 --- a/tests/active_poller.cpp +++ b/tests/active_poller.cpp @@ -157,7 +157,7 @@ TEST_CASE("poll basic", "[active_poller]") { server_client_setup s; - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); zmq::active_poller_t active_poller; bool message_received = false; @@ -184,7 +184,7 @@ TEST_CASE("client server", "[active_poller]") zmq::active_poller_t::handler_t handler = [&](short e) { if (0 != (e & ZMQ_POLLIN)) { zmq::message_t zmq_msg; - CHECK_NOTHROW(s.server.recv(&zmq_msg)); // get message + CHECK_NOTHROW(s.server.recv(zmq_msg)); // get message std::string recv_msg(zmq_msg.data(), zmq_msg.size()); CHECK(send_msg == recv_msg); } else if (0 != (e & ~ZMQ_POLLOUT)) { @@ -197,7 +197,7 @@ TEST_CASE("client server", "[active_poller]") CHECK_NOTHROW(active_poller.add(s.server, ZMQ_POLLIN, handler)); // client sends message - CHECK_NOTHROW(s.client.send(zmq::message_t{send_msg})); + CHECK_NOTHROW(s.client.send(zmq::message_t{send_msg}, zmq::send_flags::none)); CHECK(1 == active_poller.wait(std::chrono::milliseconds{-1})); CHECK(events == ZMQ_POLLIN); @@ -236,7 +236,7 @@ TEST_CASE("remove invalid socket throws", "[active_poller]") TEST_CASE("wait on added empty handler", "[active_poller]") { server_client_setup s; - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); zmq::active_poller_t active_poller; zmq::active_poller_t::handler_t handler; CHECK_NOTHROW(active_poller.add(s.server, ZMQ_POLLIN, handler)); @@ -291,7 +291,7 @@ TEST_CASE("poll client server", "[active_poller]") CHECK_NOTHROW(active_poller.add(s.server, ZMQ_POLLIN, s.handler)); // client sends message - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); // wait for message and verify events CHECK_NOTHROW(active_poller.wait(std::chrono::milliseconds{500})); @@ -316,7 +316,7 @@ TEST_CASE("wait one return", "[active_poller]") active_poller.add(s.server, ZMQ_POLLIN, [&count](short) { ++count; })); // client sends message - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); // wait for message and verify events CHECK(1 == active_poller.wait(std::chrono::milliseconds{500})); @@ -326,7 +326,7 @@ TEST_CASE("wait one return", "[active_poller]") TEST_CASE("wait on move constructed active_poller", "[active_poller]") { server_client_setup s; - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); zmq::active_poller_t a; zmq::active_poller_t::handler_t handler; CHECK_NOTHROW(a.add(s.server, ZMQ_POLLIN, handler)); @@ -340,7 +340,7 @@ TEST_CASE("wait on move constructed active_poller", "[active_poller]") TEST_CASE("wait on move assigned active_poller", "[active_poller]") { server_client_setup s; - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); zmq::active_poller_t a; zmq::active_poller_t::handler_t handler; CHECK_NOTHROW(a.add(s.server, ZMQ_POLLIN, handler)); @@ -361,14 +361,14 @@ TEST_CASE("received on move constructed active_poller", "[active_poller]") zmq::active_poller_t a; CHECK_NOTHROW(a.add(s.server, ZMQ_POLLIN, [&count](short) { ++count; })); // client sends message - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); // wait for message and verify it is received CHECK(1 == a.wait(std::chrono::milliseconds{500})); CHECK(1u == count); // Move construct active_poller b zmq::active_poller_t b{std::move(a)}; // client sends message again - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); // wait for message and verify it is received CHECK(1 == b.wait(std::chrono::milliseconds{500})); CHECK(2u == count); @@ -399,7 +399,7 @@ TEST_CASE("remove from handler", "[active_poller]") CHECK(ITER_NO == active_poller.size()); // Clients send messages for (auto &s : setup_list) { - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); } // Wait for all servers to receive a message diff --git a/tests/buffer.cpp b/tests/buffer.cpp new file mode 100644 index 0000000..fce558a --- /dev/null +++ b/tests/buffer.cpp @@ -0,0 +1,253 @@ +#include +#include + +#ifdef ZMQ_CPP17 +static_assert(std::is_nothrow_swappable_v); +static_assert(std::is_nothrow_swappable_v); +static_assert(std::is_trivially_copyable_v); +static_assert(std::is_trivially_copyable_v); +#endif + +#ifdef ZMQ_CPP11 + +using BT = int16_t; + +TEST_CASE("buffer default ctor", "[buffer]") +{ + zmq::mutable_buffer mb; + zmq::const_buffer cb; + CHECK(mb.size() == 0); + CHECK(mb.data() == nullptr); + CHECK(cb.size() == 0); + CHECK(cb.data() == nullptr); +} + +TEST_CASE("buffer data ctor", "[buffer]") +{ + std::vector v(10); + zmq::const_buffer cb(v.data(), v.size() * sizeof(BT)); + CHECK(cb.size() == v.size() * sizeof(BT)); + CHECK(cb.data() == v.data()); + zmq::mutable_buffer mb(v.data(), v.size() * sizeof(BT)); + CHECK(mb.size() == v.size() * sizeof(BT)); + CHECK(mb.data() == v.data()); + zmq::const_buffer from_mut = mb; + CHECK(mb.size() == from_mut.size()); + CHECK(mb.data() == from_mut.data()); + const auto cmb = mb; + static_assert(std::is_same::value, ""); +} + +TEST_CASE("const_buffer operator+", "[buffer]") +{ + std::vector v(10); + zmq::const_buffer cb(v.data(), v.size() * sizeof(BT)); + const size_t shift = 4; + auto shifted = cb + shift; + CHECK(shifted.size() == v.size() * sizeof(BT) - shift); + CHECK(shifted.data() == v.data() + shift / sizeof(BT)); + auto shifted2 = shift + cb; + CHECK(shifted.size() == shifted2.size()); + CHECK(shifted.data() == shifted2.data()); + auto cbinp = cb; + cbinp += shift; + CHECK(shifted.size() == cbinp.size()); + CHECK(shifted.data() == cbinp.data()); +} + +TEST_CASE("mutable_buffer operator+", "[buffer]") +{ + std::vector v(10); + zmq::mutable_buffer mb(v.data(), v.size() * sizeof(BT)); + const size_t shift = 4; + auto shifted = mb + shift; + CHECK(shifted.size() == v.size() * sizeof(BT) - shift); + CHECK(shifted.data() == v.data() + shift / sizeof(BT)); + auto shifted2 = shift + mb; + CHECK(shifted.size() == shifted2.size()); + CHECK(shifted.data() == shifted2.data()); + auto mbinp = mb; + mbinp += shift; + CHECK(shifted.size() == mbinp.size()); + CHECK(shifted.data() == mbinp.data()); +} + +TEST_CASE("mutable_buffer creation basic", "[buffer]") +{ + std::vector v(10); + zmq::mutable_buffer mb(v.data(), v.size() * sizeof(BT)); + zmq::mutable_buffer mb2 = zmq::buffer(v.data(), v.size() * sizeof(BT)); + CHECK(mb.data() == mb2.data()); + CHECK(mb.size() == mb2.size()); + zmq::mutable_buffer mb3 = zmq::buffer(mb); + CHECK(mb.data() == mb3.data()); + CHECK(mb.size() == mb3.size()); + zmq::mutable_buffer mb4 = zmq::buffer(mb, 10 * v.size() * sizeof(BT)); + CHECK(mb.data() == mb4.data()); + CHECK(mb.size() == mb4.size()); + zmq::mutable_buffer mb5 = zmq::buffer(mb, 4); + CHECK(mb.data() == mb5.data()); + CHECK(4 == mb5.size()); +} + +TEST_CASE("const_buffer creation basic", "[buffer]") +{ + const std::vector v(10); + zmq::const_buffer cb(v.data(), v.size() * sizeof(BT)); + zmq::const_buffer cb2 = zmq::buffer(v.data(), v.size() * sizeof(BT)); + CHECK(cb.data() == cb2.data()); + CHECK(cb.size() == cb2.size()); + zmq::const_buffer cb3 = zmq::buffer(cb); + CHECK(cb.data() == cb3.data()); + CHECK(cb.size() == cb3.size()); + zmq::const_buffer cb4 = zmq::buffer(cb, 10 * v.size() * sizeof(BT)); + CHECK(cb.data() == cb4.data()); + CHECK(cb.size() == cb4.size()); + zmq::const_buffer cb5 = zmq::buffer(cb, 4); + CHECK(cb.data() == cb5.data()); + CHECK(4 == cb5.size()); +} + +TEST_CASE("mutable_buffer creation C array", "[buffer]") +{ + BT d[10] = {}; + zmq::mutable_buffer b = zmq::buffer(d); + CHECK(b.size() == 10 * sizeof(BT)); + CHECK(b.data() == static_cast(d)); + zmq::const_buffer b2 = zmq::buffer(d, 4); + CHECK(b2.size() == 4); + CHECK(b2.data() == static_cast(d)); +} + +TEST_CASE("const_buffer creation C array", "[buffer]") +{ + const BT d[10] = {}; + zmq::const_buffer b = zmq::buffer(d); + CHECK(b.size() == 10 * sizeof(BT)); + CHECK(b.data() == static_cast(d)); + zmq::const_buffer b2 = zmq::buffer(d, 4); + CHECK(b2.size() == 4); + CHECK(b2.data() == static_cast(d)); +} + +TEST_CASE("mutable_buffer creation array", "[buffer]") +{ + std::array d = {}; + zmq::mutable_buffer b = zmq::buffer(d); + CHECK(b.size() == d.size() * sizeof(BT)); + CHECK(b.data() == d.data()); + zmq::mutable_buffer b2 = zmq::buffer(d, 4); + CHECK(b2.size() == 4); + CHECK(b2.data() == d.data()); +} + +TEST_CASE("const_buffer creation array", "[buffer]") +{ + const std::array d = {}; + zmq::const_buffer b = zmq::buffer(d); + CHECK(b.size() == d.size() * sizeof(BT)); + CHECK(b.data() == d.data()); + zmq::const_buffer b2 = zmq::buffer(d, 4); + CHECK(b2.size() == 4); + CHECK(b2.data() == d.data()); +} + +TEST_CASE("const_buffer creation array 2", "[buffer]") +{ + std::array d = {{}}; + zmq::const_buffer b = zmq::buffer(d); + CHECK(b.size() == d.size() * sizeof(BT)); + CHECK(b.data() == d.data()); + zmq::const_buffer b2 = zmq::buffer(d, 4); + CHECK(b2.size() == 4); + CHECK(b2.data() == d.data()); +} + +TEST_CASE("mutable_buffer creation vector", "[buffer]") +{ + std::vector d(10); + zmq::mutable_buffer b = zmq::buffer(d); + CHECK(b.size() == d.size() * sizeof(BT)); + CHECK(b.data() == d.data()); + zmq::mutable_buffer b2 = zmq::buffer(d, 4); + CHECK(b2.size() == 4); + CHECK(b2.data() == d.data()); + d.clear(); + b = zmq::buffer(d); + CHECK(b.size() == 0); + CHECK(b.data() == nullptr); +} + +TEST_CASE("const_buffer creation vector", "[buffer]") +{ + std::vector d(10); + zmq::const_buffer b = zmq::buffer(static_cast&>(d)); + CHECK(b.size() == d.size() * sizeof(BT)); + CHECK(b.data() == d.data()); + zmq::const_buffer b2 = zmq::buffer(static_cast&>(d), 4); + CHECK(b2.size() == 4); + CHECK(b2.data() == d.data()); + d.clear(); + b = zmq::buffer(static_cast&>(d)); + CHECK(b.size() == 0); + CHECK(b.data() == nullptr); +} + +TEST_CASE("const_buffer creation string", "[buffer]") +{ + const std::wstring d(10, L'a'); + zmq::const_buffer b = zmq::buffer(d); + CHECK(b.size() == d.size() * sizeof(wchar_t)); + CHECK(b.data() == d.data()); + zmq::const_buffer b2 = zmq::buffer(d, 4); + CHECK(b2.size() == 4); + CHECK(b2.data() == d.data()); +} + +TEST_CASE("mutable_buffer creation string", "[buffer]") +{ + std::wstring d(10, L'a'); + zmq::mutable_buffer b = zmq::buffer(d); + CHECK(b.size() == d.size() * sizeof(wchar_t)); + CHECK(b.data() == d.data()); + zmq::mutable_buffer b2 = zmq::buffer(d, 4); + CHECK(b2.size() == 4); + CHECK(b2.data() == d.data()); +} + +#ifdef ZMQ_CPP17 +TEST_CASE("const_buffer creation string_view", "[buffer]") +{ + std::wstring dstr(10, L'a'); + std::wstring_view d = dstr; + zmq::const_buffer b = zmq::buffer(d); + CHECK(b.size() == d.size() * sizeof(wchar_t)); + CHECK(b.data() == d.data()); + zmq::const_buffer b2 = zmq::buffer(d, 4); + CHECK(b2.size() == 4); + CHECK(b2.data() == d.data()); +} +#endif + +TEST_CASE("buffer of structs", "[buffer]") +{ + struct some_pod + { + int64_t val; + char arr[8]; + }; + struct some_non_pod + { + int64_t val; + char arr[8]; + std::vector s; // not trivially copyable + }; + static_assert(zmq::detail::is_pod_like::value, ""); + static_assert(!zmq::detail::is_pod_like::value, ""); + std::array d; + zmq::mutable_buffer b = zmq::buffer(d); + CHECK(b.size() == d.size() * sizeof(some_pod)); + CHECK(b.data() == d.data()); +} + +#endif diff --git a/tests/poller.cpp b/tests/poller.cpp index dfb8a47..dfe422b 100644 --- a/tests/poller.cpp +++ b/tests/poller.cpp @@ -142,7 +142,7 @@ TEST_CASE("poller poll basic", "[poller]") { common_server_client_setup s; - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); zmq::poller_t poller; std::vector events{1}; @@ -220,7 +220,7 @@ TEST_CASE("poller poll client server", "[poller]") CHECK_NOTHROW(poller.add(s.server, ZMQ_POLLIN, s.server)); // client sends message - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); // wait for message and verify events std::vector events(1); @@ -243,7 +243,7 @@ TEST_CASE("poller wait one return", "[poller]") CHECK_NOTHROW(poller.add(s.server, ZMQ_POLLIN, nullptr)); // client sends message - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); // wait for message and verify events std::vector events(1); @@ -253,7 +253,7 @@ TEST_CASE("poller wait one return", "[poller]") TEST_CASE("poller wait on move constructed", "[poller]") { common_server_client_setup s; - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); zmq::poller_t<> a; CHECK_NOTHROW(a.add(s.server, ZMQ_POLLIN, nullptr)); zmq::poller_t<> b{std::move(a)}; @@ -266,7 +266,7 @@ TEST_CASE("poller wait on move constructed", "[poller]") TEST_CASE("poller wait on move assigned", "[poller]") { common_server_client_setup s; - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); zmq::poller_t<> a; CHECK_NOTHROW(a.add(s.server, ZMQ_POLLIN, nullptr)); zmq::poller_t<> b; @@ -293,7 +293,7 @@ TEST_CASE("poller remove from handler", "[poller]") } // Clients send messages for (auto &s : setup_list) { - CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"})); + CHECK_NOTHROW(s.client.send(zmq::message_t{"Hi"}, zmq::send_flags::none)); } // Wait for all servers to receive a message diff --git a/tests/socket.cpp b/tests/socket.cpp index 4c9cbd1..ed00739 100644 --- a/tests/socket.cpp +++ b/tests/socket.cpp @@ -46,6 +46,14 @@ TEST_CASE("socket swap", "[socket]") using std::swap; swap(socket1, socket2); } +TEST_CASE("rass", "[socket]") +{ +zmq::context_t ctx; + zmq::socket_t sock(ctx, zmq::socket_type::push); + sock.bind("inproc://test"); + const std::string m = "Hello, world"; + sock.send(zmq::buffer(m), zmq::send_flags::dontwait); +} #endif TEST_CASE("socket sends and receives const buffer", "[socket]") @@ -55,13 +63,150 @@ TEST_CASE("socket sends and receives const buffer", "[socket]") zmq::socket_t receiver(context, ZMQ_PAIR); receiver.bind("inproc://test"); sender.connect("inproc://test"); - CHECK(2 == sender.send("Hi", 2)); + const char* str = "Hi"; + + #ifdef ZMQ_CPP11 + CHECK(2 == sender.send(zmq::buffer(str, 2)).size); + char buf[2]; + const auto res = receiver.recv(zmq::buffer(buf)); + CHECK(!res.truncated()); + CHECK(2 == res.size); + #else + CHECK(2 == sender.send(str, 2)); char buf[2]; CHECK(2 == receiver.recv(buf, 2)); - CHECK(0 == memcmp(buf, "Hi", 2)); + #endif + CHECK(0 == memcmp(buf, str, 2)); } #ifdef ZMQ_CPP11 + +TEST_CASE("socket send none sndmore", "[socket]") +{ + zmq::context_t context; + zmq::socket_t s(context, zmq::socket_type::router); + s.bind("inproc://test"); + + std::vector buf(4); + auto res = s.send(zmq::buffer(buf), zmq::send_flags::sndmore); + CHECK(res.size == buf.size()); + CHECK(res.success); + res = s.send(zmq::buffer(buf)); + CHECK(res.size == buf.size()); + CHECK(res.success); +} + +TEST_CASE("socket send dontwait", "[socket]") +{ + zmq::context_t context; + zmq::socket_t s(context, zmq::socket_type::push); + s.bind("inproc://test"); + + std::vector buf(4); + auto res = s.send(zmq::buffer(buf), zmq::send_flags::dontwait); + CHECK(!res.success); + CHECK(res.size == 0); + res = s.send(zmq::buffer(buf), + zmq::send_flags::dontwait | zmq::send_flags::sndmore); + CHECK(!res.success); + CHECK(res.size == 0); + + zmq::message_t msg; + auto resm = s.send(msg, zmq::send_flags::dontwait); + CHECK(!resm.success); + CHECK(resm.size == 0); + CHECK(msg.size() == 0); +} + +TEST_CASE("socket send exception", "[socket]") +{ + zmq::context_t context; + zmq::socket_t s(context, zmq::socket_type::pull); + s.bind("inproc://test"); + + std::vector buf(4); + CHECK_THROWS_AS(s.send(zmq::buffer(buf)), const zmq::error_t &); +} + +TEST_CASE("socket recv none", "[socket]") +{ + zmq::context_t context; + zmq::socket_t s(context, zmq::socket_type::pair); + zmq::socket_t s2(context, zmq::socket_type::pair); + s2.bind("inproc://test"); + s.connect("inproc://test"); + + std::vector sbuf(4); + const auto res_send = s2.send(zmq::buffer(sbuf)); + CHECK(res_send.success); + + std::vector buf(2); + const auto res = s.recv(zmq::buffer(buf)); + CHECK(res.success); + CHECK(res.truncated()); + CHECK(res.untruncated_size == sbuf.size()); + CHECK(res.size == buf.size()); + + const auto res_send2 = s2.send(zmq::buffer(sbuf)); + CHECK(res_send2.success); + std::vector buf2(10); + const auto res2 = s.recv(zmq::buffer(buf2)); + CHECK(res2.success); + CHECK(!res2.truncated()); + CHECK(res2.untruncated_size == sbuf.size()); + CHECK(res2.size == sbuf.size()); +} + +TEST_CASE("socket send recv message_t", "[socket]") +{ + zmq::context_t context; + zmq::socket_t s(context, zmq::socket_type::pair); + zmq::socket_t s2(context, zmq::socket_type::pair); + s2.bind("inproc://test"); + s.connect("inproc://test"); + + zmq::message_t smsg(size_t{10}); + const auto res_send = s2.send(smsg, zmq::send_flags::none); + CHECK(res_send.success); + CHECK(res_send.size == 10); + CHECK(smsg.size() == 0); + + zmq::message_t rmsg; + const auto res = s.recv(rmsg); + CHECK(res.success); + CHECK(res.size == 10); + CHECK(rmsg.size() == res.size); +} + +TEST_CASE("socket recv dontwait", "[socket]") +{ + zmq::context_t context; + zmq::socket_t s(context, zmq::socket_type::pull); + s.bind("inproc://test"); + + std::vector buf(4); + constexpr auto flags = zmq::recv_flags::none | zmq::recv_flags::dontwait; + auto res = s.recv(zmq::buffer(buf), flags); + CHECK(!res.success); + CHECK(res.size == 0); + + zmq::message_t msg; + auto resm = s.recv(msg, flags); + CHECK(!resm.success); + CHECK(resm.size == 0); + CHECK(msg.size() == 0); +} + +TEST_CASE("socket recv exception", "[socket]") +{ + zmq::context_t context; + zmq::socket_t s(context, zmq::socket_type::push); + s.bind("inproc://test"); + + std::vector buf(4); + CHECK_THROWS_AS(s.recv(zmq::buffer(buf)), const zmq::error_t &); +} + TEST_CASE("socket proxy", "[socket]") { zmq::context_t context; diff --git a/zmq.hpp b/zmq.hpp index 3d27975..d6c6170 100644 --- a/zmq.hpp +++ b/zmq.hpp @@ -72,6 +72,9 @@ #include #include +#ifdef ZMQ_CPP11 +#include +#endif #include #include #include @@ -277,6 +280,8 @@ class message_t } #if defined(ZMQ_BUILD_DRAFT_API) && defined(ZMQ_CPP11) + // this function is too greedy, must add + // SFINAE for begin and end support. template explicit message_t(const T &msg_) : message_t(std::begin(msg_), std::end(msg_)) { @@ -615,8 +620,340 @@ inline void swap(context_t &a, context_t &b) ZMQ_NOTHROW { a.swap(b); } +#ifdef ZMQ_CPP11 +struct send_result +{ + size_t size; // message size in bytes + bool success; +}; + +struct recv_result +{ + size_t size; // message size in bytes + bool success; +}; + +struct recv_buffer_result +{ + size_t size; // number of bytes written to buffer + size_t untruncated_size; // untruncated message size in bytes + bool success; + + ZMQ_NODISCARD bool truncated() const noexcept + { + return size != untruncated_size; + } +}; + +enum class send_flags : int +{ + none = 0, + dontwait = ZMQ_DONTWAIT, + sndmore = ZMQ_SNDMORE +}; + +constexpr send_flags operator|(send_flags a, send_flags b) noexcept +{ + return static_cast(static_cast(a) | static_cast(b)); +} + +enum class recv_flags : int +{ + none = 0, + dontwait = ZMQ_DONTWAIT +}; + +constexpr recv_flags operator|(recv_flags a, recv_flags b) noexcept +{ + return static_cast(static_cast(a) | static_cast(b)); +} + +// mutable_buffer, const_buffer and buffer are based on +// the Networking TS specification, draft: +// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2018/n4771.pdf + +class mutable_buffer +{ + public: + mutable_buffer() noexcept : _data(nullptr), _size(0) {} + mutable_buffer(void *p, size_t n) noexcept : _data(p), _size(n) + { + assert(p != nullptr || n == 0); + } + + void *data() const noexcept { return _data; } + size_t size() const noexcept { return _size; } + mutable_buffer &operator+=(size_t n) noexcept + { + // (std::min) is a workaround for when a min macro is defined + const auto shift = (std::min)(n, _size); + _data = static_cast(_data) + shift; + _size -= shift; + return *this; + } + + private: + void *_data; + size_t _size; +}; + +inline mutable_buffer operator+(const mutable_buffer &mb, size_t n) noexcept +{ + return mutable_buffer(static_cast(mb.data()) + (std::min)(n, mb.size()), + mb.size() - (std::min)(n, mb.size())); +} +inline mutable_buffer operator+(size_t n, const mutable_buffer &mb) noexcept +{ + return mb + n; +} + +class const_buffer +{ + public: + const_buffer() noexcept : _data(nullptr), _size(0) {} + const_buffer(const void *p, size_t n) noexcept : _data(p), _size(n) {} + const_buffer(const mutable_buffer &mb) noexcept : + _data(mb.data()), + _size(mb.size()) + { + } + + const void *data() const noexcept { return _data; } + size_t size() const noexcept { return _size; } + const_buffer &operator+=(size_t n) noexcept + { + const auto shift = (std::min)(n, _size); + _data = static_cast(_data) + shift; + _size -= shift; + return *this; + } + + private: + const void *_data; + size_t _size; +}; + +inline const_buffer operator+(const const_buffer &cb, size_t n) noexcept +{ + return const_buffer(static_cast(cb.data()) + + (std::min)(n, cb.size()), + cb.size() - (std::min)(n, cb.size())); +} +inline const_buffer operator+(size_t n, const const_buffer &cb) noexcept +{ + return cb + n; +} + +// buffer creation + +inline mutable_buffer buffer(void* p, size_t n) noexcept +{ + return mutable_buffer(p, n); +} +inline const_buffer buffer(const void* p, size_t n) noexcept +{ + return const_buffer(p, n); +} +inline mutable_buffer buffer(const mutable_buffer& mb) noexcept +{ + return mb; +} +inline mutable_buffer buffer(const mutable_buffer& mb, size_t n) noexcept +{ + return mutable_buffer(mb.data(), (std::min)(mb.size(), n)); +} +inline const_buffer buffer(const const_buffer& cb) noexcept +{ + return cb; +} +inline const_buffer buffer(const const_buffer& cb, size_t n) noexcept +{ + return const_buffer(cb.data(), (std::min)(cb.size(), n)); +} + namespace detail { +// utility functions for containers with data and size +// data is nullptr if the container is empty +template mutable_buffer buffar_mut_ds(T &data) noexcept +{ + return mutable_buffer(data.size() != 0u ? data.data() : nullptr, + data.size() * sizeof(*data.data())); +} +template mutable_buffer buffar_mut_ds(T &data, size_t n_bytes) noexcept +{ + return mutable_buffer(data.size() != 0u ? data.data() : nullptr, + (std::min)(data.size() * sizeof(*data.data()), n_bytes)); +} +template const_buffer buffar_const_ds(const T &data) noexcept +{ + return const_buffer(data.size() != 0u ? data.data() : nullptr, + data.size() * sizeof(*data.data())); +} +template const_buffer buffar_const_ds(const T &data, size_t n_bytes) noexcept +{ + return const_buffer(data.size() != 0u ? data.data() : nullptr, + (std::min)(data.size() * sizeof(*data.data()), n_bytes)); +} +template struct is_pod_like +{ + // NOTE: The networking draft N4771 section 16.11 requires + // T in the buffer functions below to be + // trivially copyable OR standard layout. + // Here we decide to be conservative and require both. + static constexpr bool value = + std::is_trivially_copyable::value && std::is_standard_layout::value; +}; +} // namespace detail + +// C array +template mutable_buffer buffer(T (&data)[N]) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + static_assert(N > 0, "N > 0"); + return mutable_buffer(static_cast(data), N * sizeof(T)); +} +template +mutable_buffer buffer(T (&data)[N], size_t n_bytes) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + static_assert(N > 0, "N > 0"); + return mutable_buffer(static_cast(data), (std::min)(N * sizeof(T), n_bytes)); +} +template const_buffer buffer(const T (&data)[N]) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + static_assert(N > 0, "N > 0"); + return const_buffer(static_cast(data), N * sizeof(T)); +} +template +const_buffer buffer(const T (&data)[N], size_t n_bytes) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + static_assert(N > 0, "N > 0"); + return const_buffer(static_cast(data), + (std::min)(N * sizeof(T), n_bytes)); +} +// std::array +template mutable_buffer buffer(std::array &data) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + static_assert(N > 0, "N > 0"); + return mutable_buffer(data.data(), N * sizeof(T)); +} +template +mutable_buffer buffer(std::array &data, size_t n_bytes) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + static_assert(N > 0, "N > 0"); + return mutable_buffer(data.data(), (std::min)(N * sizeof(T), n_bytes)); +} +template +const_buffer buffer(std::array &data) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + static_assert(N > 0, "N > 0"); + return const_buffer(data.data(), N * sizeof(T)); +} +template +const_buffer buffer(std::array &data, size_t n_bytes) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + static_assert(N > 0, "N > 0"); + return const_buffer(data.data(), (std::min)(N * sizeof(T), n_bytes)); +} +template +const_buffer buffer(const std::array &data) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + static_assert(N > 0, "N > 0"); + return const_buffer(data.data(), N * sizeof(T)); +} +template +const_buffer buffer(const std::array &data, size_t n_bytes) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + static_assert(N > 0, "N > 0"); + return const_buffer(data.data(), (std::min)(N * sizeof(T), n_bytes)); +} +// std::vector +template +mutable_buffer buffer(std::vector &data) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + return detail::buffar_mut_ds(data); +} +template +mutable_buffer buffer(std::vector &data, size_t n_bytes) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + return detail::buffar_mut_ds(data, n_bytes); +} +template +const_buffer buffer(const std::vector &data) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + return detail::buffar_const_ds(data); +} +template +const_buffer buffer(const std::vector &data, size_t n_bytes) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + return detail::buffar_const_ds(data, n_bytes); +} +// std::basic_string +template +mutable_buffer buffer(std::basic_string &data) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + // before C++17 string::data() returned const char* + return mutable_buffer(data.size() != 0u ? &data[0] : nullptr, + data.size() * sizeof(T)); +} +template +mutable_buffer buffer(std::basic_string &data, + size_t n_bytes) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + // before C++17 string::data() returned const char* + return mutable_buffer(data.size() != 0u ? &data[0] : nullptr, + (std::min)(data.size() * sizeof(T), n_bytes)); +} +template +const_buffer buffer(const std::basic_string &data) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + return detail::buffar_const_ds(data); +} +template +const_buffer buffer(const std::basic_string &data, + size_t n_bytes) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + return detail::buffar_const_ds(data, n_bytes); +} + +#ifdef ZMQ_CPP17 +// std::basic_string_view +template +const_buffer buffer(std::basic_string_view data) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + return detail::buffar_const_ds(data); +} +template +const_buffer buffer(std::basic_string_view data, size_t n_bytes) noexcept +{ + static_assert(detail::is_pod_like::value, "T must be POD"); + return detail::buffar_const_ds(data, n_bytes); +} +#endif + +#endif // ZMQ_CPP11 + +namespace detail +{ + class socket_base { public: @@ -688,6 +1025,9 @@ public: bool connected() const ZMQ_NOTHROW { return (_handle != ZMQ_NULLPTR); } +#ifdef ZMQ_CPP11 + ZMQ_DEPRECATED("from 4.3.1, use send taking a const_buffer and send_flags") +#endif size_t send(const void *buf_, size_t len_, int flags_ = 0) { int nbytes = zmq_send(_handle, buf_, len_, flags_); @@ -698,7 +1038,11 @@ public: throw error_t(); } - bool send(message_t &msg_, int flags_ = 0) +#ifdef ZMQ_CPP11 + ZMQ_DEPRECATED("from 4.3.1, use send taking message_t and send_flags") +#endif + bool send(message_t &msg_, + int flags_ = 0) // default until removed { int nbytes = zmq_msg_send(msg_.handle(), _handle, flags_); if (nbytes >= 0) @@ -715,9 +1059,51 @@ public: } #ifdef ZMQ_HAS_RVALUE_REFS - bool send(message_t &&msg_, int flags_ = 0) { return send(msg_, flags_); } +#ifdef ZMQ_CPP11 + ZMQ_DEPRECATED("from 4.3.1, use send taking message_t and send_flags") +#endif + bool send(message_t &&msg_, + int flags_ = 0) // default until removed + { + #ifdef ZMQ_CPP11 + return send(msg_, static_cast(flags_)).success; + #else + return send(msg_, flags_); + #endif + } #endif +#ifdef ZMQ_CPP11 + send_result send(const_buffer buf, send_flags flags = send_flags::none) + { + const int nbytes = + zmq_send(_handle, buf.data(), buf.size(), static_cast(flags)); + if (nbytes >= 0) + return {static_cast(nbytes), true}; + if (zmq_errno() == EAGAIN) + return {size_t{0}, false}; + throw error_t(); + } + + send_result send(message_t &msg, send_flags flags) + { + int nbytes = zmq_msg_send(msg.handle(), _handle, static_cast(flags)); + if (nbytes >= 0) + return {static_cast(nbytes), true}; + if (zmq_errno() == EAGAIN) + return {size_t{0}, false}; + throw error_t(); + } + + send_result send(message_t &&msg, send_flags flags) + { + return send(msg, flags); + } +#endif + +#ifdef ZMQ_CPP11 + ZMQ_DEPRECATED("from 4.3.1, use recv taking a mutable_buffer and recv_flags") +#endif size_t recv(void *buf_, size_t len_, int flags_ = 0) { int nbytes = zmq_recv(_handle, buf_, len_, flags_); @@ -728,7 +1114,14 @@ public: throw error_t(); } - bool recv(message_t *msg_, int flags_ = 0) +#ifdef ZMQ_CPP11 + ZMQ_DEPRECATED("from 4.3.1, use recv taking a reference to message_t and recv_flags") +#endif + bool recv(message_t *msg_, int flags_ +#ifndef ZMQ_CPP11 + = 0 +#endif + ) { int nbytes = zmq_msg_recv(msg_->handle(), _handle, flags_); if (nbytes >= 0) @@ -738,6 +1131,32 @@ public: throw error_t(); } +#ifdef ZMQ_CPP11 + recv_buffer_result recv(mutable_buffer buf, recv_flags flags = recv_flags::none) + { + const int nbytes = + zmq_recv(_handle, buf.data(), buf.size(), static_cast(flags)); + if (nbytes >= 0) + return {(std::min)(static_cast(nbytes), buf.size()), + static_cast(nbytes), true}; + if (zmq_errno() == EAGAIN) + return {size_t{0}, size_t{0}, false}; + throw error_t(); + } + + recv_result recv(message_t &msg, recv_flags flags = recv_flags::none) + { + const int nbytes = zmq_msg_recv(msg.handle(), _handle, static_cast(flags)); + if (nbytes >= 0) { + assert(msg.size() == static_cast(nbytes)); + return {static_cast(nbytes), true}; + } + if (zmq_errno() == EAGAIN) + return {size_t{0}, false}; + throw error_t(); + } +#endif + #if defined(ZMQ_BUILD_DRAFT_API) && ZMQ_VERSION >= ZMQ_MAKE_VERSION(4, 2, 0) void join(const char* group) { diff --git a/zmq_addon.hpp b/zmq_addon.hpp index 041e2fa..e0d9eb9 100644 --- a/zmq_addon.hpp +++ b/zmq_addon.hpp @@ -130,8 +130,13 @@ class multipart_t bool more = true; while (more) { message_t message; + #ifdef ZMQ_CPP11 + if (!socket.recv(message, static_cast(flags)).success) + return false; + #else if (!socket.recv(&message, flags)) return false; + #endif more = message.more(); add(std::move(message)); } @@ -146,8 +151,14 @@ class multipart_t while (more) { message_t message = pop(); more = size() > 0; + #ifdef ZMQ_CPP11 + if (!socket.send(message, + static_cast((more ? ZMQ_SNDMORE : 0) | flags)).success) + return false; + #else if (!socket.send(message, (more ? ZMQ_SNDMORE : 0) | flags)) return false; + #endif } clear(); return true;