diff --git a/tests/codec_multipart.cpp b/tests/codec_multipart.cpp index 7be90d4..43d444d 100644 --- a/tests/codec_multipart.cpp +++ b/tests/codec_multipart.cpp @@ -69,7 +69,7 @@ TEST_CASE("multipart codec decode bad data overflow", "[codec_multipart]") CHECK_THROWS_AS( multipart_t::decode(wrong_size), - std::out_of_range); + const std::out_of_range&); } TEST_CASE("multipart codec decode bad data extra data", "[codec_multipart]") @@ -83,7 +83,7 @@ TEST_CASE("multipart codec decode bad data extra data", "[codec_multipart]") CHECK_THROWS_AS( multipart_t::decode(wrong_size), - std::out_of_range); + const std::out_of_range&); } diff --git a/zmq_addon.hpp b/zmq_addon.hpp index ea93e44..327899c 100644 --- a/zmq_addon.hpp +++ b/zmq_addon.hpp @@ -67,6 +67,39 @@ recv_multipart_n(socket_ref s, OutputIt out, size_t n, recv_flags flags) } return msg_count; } + +inline bool is_little_endian() +{ + const uint16_t i = 0x01; + return *reinterpret_cast(&i) == 0x01; +} + +inline void write_network_order(unsigned char *buf, const uint32_t value) +{ + if (is_little_endian()) { + ZMQ_CONSTEXPR_VAR uint32_t mask = std::numeric_limits::max(); + *buf++ = (value >> 24) & mask; + *buf++ = (value >> 16) & mask; + *buf++ = (value >> 8) & mask; + *buf++ = value & mask; + } else { + std::memcpy(buf, &value, sizeof(value)); + } +} + +inline uint32_t read_u32_network_order(const unsigned char *buf) +{ + if (is_little_endian()) { + return (static_cast(buf[0]) << 24) + + (static_cast(buf[1]) << 16) + + (static_cast(buf[2]) << 8) + + static_cast(buf[3]); + } else { + uint32_t value; + std::memcpy(&value, buf, sizeof(value)); + return value; + } +} } // namespace detail /* Receive a multipart message. @@ -190,42 +223,37 @@ message_t encode(const Range &parts) // First pass check sizes for (const auto &part : parts) { - size_t part_size = part.size(); + const size_t part_size = part.size(); if (part_size > std::numeric_limits::max()) { // Size value must fit into uint32_t. throw std::range_error("Invalid size, message part too large"); } - size_t count_size = 5; - if (part_size < std::numeric_limits::max()) { - count_size = 1; - } + const size_t count_size = + part_size < std::numeric_limits::max() ? 1 : 5; mmsg_size += part_size + count_size; } message_t encoded(mmsg_size); unsigned char *buf = encoded.data(); for (const auto &part : parts) { - uint32_t part_size = part.size(); + const uint32_t part_size = part.size(); const unsigned char *part_data = static_cast(part.data()); - // small part if (part_size < std::numeric_limits::max()) { + // small part *buf++ = (unsigned char) part_size; - memcpy(buf, part_data, part_size); - buf += part_size; - continue; + } else { + // big part + *buf++ = std::numeric_limits::max(); + detail::write_network_order(buf, part_size); + buf += sizeof(part_size); } - - // big part - *buf++ = std::numeric_limits::max(); - *buf++ = (part_size >> 24) & std::numeric_limits::max(); - *buf++ = (part_size >> 16) & std::numeric_limits::max(); - *buf++ = (part_size >> 8) & std::numeric_limits::max(); - *buf++ = part_size & std::numeric_limits::max(); - memcpy(buf, part_data, part_size); + std::memcpy(buf, part_data, part_size); buf += part_size; } + + assert(static_cast(buf - encoded.data()) == mmsg_size); return encoded; } @@ -252,22 +280,24 @@ template OutputIt decode(const message_t &encoded, OutputIt out) while (source < limit) { size_t part_size = *source++; if (part_size == std::numeric_limits::max()) { - if (source > limit - 4) { + if (static_cast(limit - source) < sizeof(uint32_t)) { throw std::out_of_range( "Malformed encoding, overflow in reading size"); } - part_size = ((uint32_t) source[0] << 24) + ((uint32_t) source[1] << 16) - + ((uint32_t) source[2] << 8) + (uint32_t) source[3]; - source += 4; + part_size = detail::read_u32_network_order(source); + // the part size is allowed to be less than 0xFF + source += sizeof(uint32_t); } - if (source > limit - part_size) { + if (static_cast(limit - source) < part_size) { throw std::out_of_range("Malformed encoding, overflow in reading part"); } *out = message_t(source, part_size); ++out; source += part_size; } + + assert(source == limit); return out; }