diff --git a/tests/socket.cpp b/tests/socket.cpp index 92e7893..7eab2bd 100644 --- a/tests/socket.cpp +++ b/tests/socket.cpp @@ -86,11 +86,12 @@ TEST_CASE("socket sends and receives const buffer", "[socket]") const char* str = "Hi"; #ifdef ZMQ_CPP11 - CHECK(2 == sender.send(zmq::buffer(str, 2)).size); + CHECK(2 == *sender.send(zmq::buffer(str, 2))); char buf[2]; const auto res = receiver.recv(zmq::buffer(buf)); - CHECK(!res.truncated()); - CHECK(2 == res.size); + CHECK(res); + CHECK(!res->truncated()); + CHECK(2 == res->size); #else CHECK(2 == sender.send(str, 2)); char buf[2]; @@ -109,11 +110,11 @@ TEST_CASE("socket send none sndmore", "[socket]") std::vector buf(4); auto res = s.send(zmq::buffer(buf), zmq::send_flags::sndmore); - CHECK(res.size == buf.size()); - CHECK(res.success); + CHECK(res); + CHECK(*res == buf.size()); res = s.send(zmq::buffer(buf)); - CHECK(res.size == buf.size()); - CHECK(res.success); + CHECK(res); + CHECK(*res == buf.size()); } TEST_CASE("socket send dontwait", "[socket]") @@ -124,17 +125,14 @@ TEST_CASE("socket send dontwait", "[socket]") std::vector buf(4); auto res = s.send(zmq::buffer(buf), zmq::send_flags::dontwait); - CHECK(!res.success); - CHECK(res.size == 0); + CHECK(!res); res = s.send(zmq::buffer(buf), zmq::send_flags::dontwait | zmq::send_flags::sndmore); - CHECK(!res.success); - CHECK(res.size == 0); + CHECK(!res); zmq::message_t msg; auto resm = s.send(msg, zmq::send_flags::dontwait); - CHECK(!resm.success); - CHECK(resm.size == 0); + CHECK(!resm); CHECK(msg.size() == 0); } @@ -158,23 +156,24 @@ TEST_CASE("socket recv none", "[socket]") std::vector sbuf(4); const auto res_send = s2.send(zmq::buffer(sbuf)); - CHECK(res_send.success); + CHECK(res_send); + CHECK(res_send.has_value()); std::vector buf(2); const auto res = s.recv(zmq::buffer(buf)); - CHECK(res.success); - CHECK(res.truncated()); - CHECK(res.untruncated_size == sbuf.size()); - CHECK(res.size == buf.size()); + CHECK(res.has_value()); + CHECK(res->truncated()); + CHECK(res->untruncated_size == sbuf.size()); + CHECK(res->size == buf.size()); const auto res_send2 = s2.send(zmq::buffer(sbuf)); - CHECK(res_send2.success); + CHECK(res_send2.has_value()); std::vector buf2(10); const auto res2 = s.recv(zmq::buffer(buf2)); - CHECK(res2.success); - CHECK(!res2.truncated()); - CHECK(res2.untruncated_size == sbuf.size()); - CHECK(res2.size == sbuf.size()); + CHECK(res2.has_value()); + CHECK(!res2->truncated()); + CHECK(res2->untruncated_size == sbuf.size()); + CHECK(res2->size == sbuf.size()); } TEST_CASE("socket send recv message_t", "[socket]") @@ -187,15 +186,16 @@ TEST_CASE("socket send recv message_t", "[socket]") zmq::message_t smsg(size_t{10}); const auto res_send = s2.send(smsg, zmq::send_flags::none); - CHECK(res_send.success); - CHECK(res_send.size == 10); + CHECK(res_send); + CHECK(*res_send == 10); CHECK(smsg.size() == 0); zmq::message_t rmsg; const auto res = s.recv(rmsg); - CHECK(res.success); - CHECK(res.size == 10); - CHECK(rmsg.size() == res.size); + CHECK(res); + CHECK(*res == 10); + CHECK(res.value() == 10); + CHECK(rmsg.size() == *res); } TEST_CASE("socket recv dontwait", "[socket]") @@ -207,13 +207,12 @@ TEST_CASE("socket recv dontwait", "[socket]") std::vector buf(4); constexpr auto flags = zmq::recv_flags::none | zmq::recv_flags::dontwait; auto res = s.recv(zmq::buffer(buf), flags); - CHECK(!res.success); - CHECK(res.size == 0); + CHECK(!res); zmq::message_t msg; auto resm = s.recv(msg, flags); - CHECK(!resm.success); - CHECK(resm.size == 0); + CHECK(!resm); + CHECK_THROWS_AS(resm.value(), const std::exception &); CHECK(msg.size() == 0); } diff --git a/zmq.hpp b/zmq.hpp index 2f50669..935b890 100644 --- a/zmq.hpp +++ b/zmq.hpp @@ -72,16 +72,21 @@ #include #include -#ifdef ZMQ_CPP11 -#include -#endif #include #include #include -#include #include #include #include +#ifdef ZMQ_CPP11 +#include +#include +#include +#include +#endif +#ifdef ZMQ_CPP17 +#include +#endif /* Version macros for compile-time API version detection */ #define CPPZMQ_VERSION_MAJOR 4 @@ -92,12 +97,6 @@ ZMQ_MAKE_VERSION(CPPZMQ_VERSION_MAJOR, CPPZMQ_VERSION_MINOR, \ CPPZMQ_VERSION_PATCH) -#ifdef ZMQ_CPP11 -#include -#include -#include -#endif - // Detect whether the compiler supports C++11 rvalue references. #if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) \ && defined(__GXX_EXPERIMENTAL_CXX0X__)) @@ -621,23 +620,11 @@ inline void swap(context_t &a, context_t &b) ZMQ_NOTHROW { } #ifdef ZMQ_CPP11 -struct send_result -{ - size_t size; // message size in bytes - bool success; -}; -struct recv_result -{ - size_t size; // message size in bytes - bool success; -}; - -struct recv_buffer_result +struct recv_buffer_size { size_t size; // number of bytes written to buffer size_t untruncated_size; // untruncated message size in bytes - bool success; ZMQ_NODISCARD bool truncated() const noexcept { @@ -647,6 +634,71 @@ struct recv_buffer_result namespace detail { + +#ifdef ZMQ_CPP17 +using send_result_t = std::optional; +using recv_result_t = std::optional; +using recv_buffer_result_t = std::optional; +#else +// A C++11 type emulating the most basic +// operations of std::optional for trivial types +template class trivial_optional +{ + public: + static_assert(std::is_trivial::value, "T must be trivial"); + using value_type = T; + + trivial_optional() = default; + trivial_optional(T value) noexcept : _value(value), _has_value(true) {} + + const T *operator->() const noexcept + { + assert(_has_value); + return &_value; + } + T *operator->() noexcept + { + assert(_has_value); + return &_value; + } + + const T &operator*() const noexcept + { + assert(_has_value); + return _value; + } + T &operator*() noexcept + { + assert(_has_value); + return _value; + } + + T &value() + { + if (!_has_value) + throw std::exception(); + return _value; + } + const T &value() const + { + if (!_has_value) + throw std::exception(); + return _value; + } + + explicit operator bool() const noexcept { return _has_value; } + bool has_value() const noexcept { return _has_value; } + + private: + T _value{}; + bool _has_value{false}; +}; + +using send_result_t = trivial_optional; +using recv_result_t = trivial_optional; +using recv_buffer_result_t = trivial_optional; +#endif + template constexpr T enum_bit_or(T a, T b) noexcept { @@ -1111,7 +1163,7 @@ public: int flags_ = 0) // default until removed { #ifdef ZMQ_CPP11 - return send(msg_, static_cast(flags_)).success; + return send(msg_, static_cast(flags_)).has_value(); #else return send(msg_, flags_); #endif @@ -1119,28 +1171,28 @@ public: #endif #ifdef ZMQ_CPP11 - send_result send(const_buffer buf, send_flags flags = send_flags::none) + detail::send_result_t send(const_buffer buf, send_flags flags = send_flags::none) { const int nbytes = zmq_send(_handle, buf.data(), buf.size(), static_cast(flags)); if (nbytes >= 0) - return {static_cast(nbytes), true}; + return static_cast(nbytes); if (zmq_errno() == EAGAIN) - return {size_t{0}, false}; + return {}; throw error_t(); } - send_result send(message_t &msg, send_flags flags) + detail::send_result_t send(message_t &msg, send_flags flags) { int nbytes = zmq_msg_send(msg.handle(), _handle, static_cast(flags)); if (nbytes >= 0) - return {static_cast(nbytes), true}; + return static_cast(nbytes); if (zmq_errno() == EAGAIN) - return {size_t{0}, false}; + return {}; throw error_t(); } - send_result send(message_t &&msg, send_flags flags) + detail::send_result_t send(message_t &&msg, send_flags flags) { return send(msg, flags); } @@ -1177,27 +1229,29 @@ public: } #ifdef ZMQ_CPP11 - recv_buffer_result recv(mutable_buffer buf, recv_flags flags = recv_flags::none) + detail::recv_buffer_result_t recv(mutable_buffer buf, + recv_flags flags = recv_flags::none) { const int nbytes = zmq_recv(_handle, buf.data(), buf.size(), static_cast(flags)); - if (nbytes >= 0) - return {(std::min)(static_cast(nbytes), buf.size()), - static_cast(nbytes), true}; + if (nbytes >= 0) { + return recv_buffer_size{(std::min)(static_cast(nbytes), buf.size()), + static_cast(nbytes)}; + } if (zmq_errno() == EAGAIN) - return {size_t{0}, size_t{0}, false}; + return {}; throw error_t(); } - recv_result recv(message_t &msg, recv_flags flags = recv_flags::none) + detail::recv_result_t recv(message_t &msg, recv_flags flags = recv_flags::none) { const int nbytes = zmq_msg_recv(msg.handle(), _handle, static_cast(flags)); if (nbytes >= 0) { assert(msg.size() == static_cast(nbytes)); - return {static_cast(nbytes), true}; + return static_cast(nbytes); } if (zmq_errno() == EAGAIN) - return {size_t{0}, false}; + return {}; throw error_t(); } #endif diff --git a/zmq_addon.hpp b/zmq_addon.hpp index e0d9eb9..9d62c7a 100644 --- a/zmq_addon.hpp +++ b/zmq_addon.hpp @@ -131,7 +131,7 @@ class multipart_t while (more) { message_t message; #ifdef ZMQ_CPP11 - if (!socket.recv(message, static_cast(flags)).success) + if (!socket.recv(message, static_cast(flags))) return false; #else if (!socket.recv(&message, flags)) @@ -153,7 +153,7 @@ class multipart_t more = size() > 0; #ifdef ZMQ_CPP11 if (!socket.send(message, - static_cast((more ? ZMQ_SNDMORE : 0) | flags)).success) + static_cast((more ? ZMQ_SNDMORE : 0) | flags))) return false; #else if (!socket.send(message, (more ? ZMQ_SNDMORE : 0) | flags))