mirror of
https://github.com/zeromq/cppzmq.git
synced 2025-04-20 16:03:37 +02:00
Problem: No endian check in encoding
Solution: Always write message part sizes using network order.
This commit is contained in:
parent
a3e5b54c3c
commit
4784b74c37
@ -69,7 +69,7 @@ TEST_CASE("multipart codec decode bad data overflow", "[codec_multipart]")
|
|||||||
|
|
||||||
CHECK_THROWS_AS(
|
CHECK_THROWS_AS(
|
||||||
multipart_t::decode(wrong_size),
|
multipart_t::decode(wrong_size),
|
||||||
std::out_of_range);
|
const std::out_of_range&);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("multipart codec decode bad data extra data", "[codec_multipart]")
|
TEST_CASE("multipart codec decode bad data extra data", "[codec_multipart]")
|
||||||
@ -83,7 +83,7 @@ TEST_CASE("multipart codec decode bad data extra data", "[codec_multipart]")
|
|||||||
|
|
||||||
CHECK_THROWS_AS(
|
CHECK_THROWS_AS(
|
||||||
multipart_t::decode(wrong_size),
|
multipart_t::decode(wrong_size),
|
||||||
std::out_of_range);
|
const std::out_of_range&);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,6 +67,39 @@ recv_multipart_n(socket_ref s, OutputIt out, size_t n, recv_flags flags)
|
|||||||
}
|
}
|
||||||
return msg_count;
|
return msg_count;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool is_little_endian()
|
||||||
|
{
|
||||||
|
const uint16_t i = 0x01;
|
||||||
|
return *reinterpret_cast<const uint8_t *>(&i) == 0x01;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void write_network_order(unsigned char *buf, const uint32_t value)
|
||||||
|
{
|
||||||
|
if (is_little_endian()) {
|
||||||
|
ZMQ_CONSTEXPR_VAR uint32_t mask = std::numeric_limits<std::uint8_t>::max();
|
||||||
|
*buf++ = (value >> 24) & mask;
|
||||||
|
*buf++ = (value >> 16) & mask;
|
||||||
|
*buf++ = (value >> 8) & mask;
|
||||||
|
*buf++ = value & mask;
|
||||||
|
} else {
|
||||||
|
std::memcpy(buf, &value, sizeof(value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint32_t read_u32_network_order(const unsigned char *buf)
|
||||||
|
{
|
||||||
|
if (is_little_endian()) {
|
||||||
|
return (static_cast<uint32_t>(buf[0]) << 24)
|
||||||
|
+ (static_cast<uint32_t>(buf[1]) << 16)
|
||||||
|
+ (static_cast<uint32_t>(buf[2]) << 8)
|
||||||
|
+ static_cast<uint32_t>(buf[3]);
|
||||||
|
} else {
|
||||||
|
uint32_t value;
|
||||||
|
std::memcpy(&value, buf, sizeof(value));
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
/* Receive a multipart message.
|
/* Receive a multipart message.
|
||||||
@ -189,42 +222,37 @@ message_t encode(const Range &parts)
|
|||||||
|
|
||||||
// First pass check sizes
|
// First pass check sizes
|
||||||
for (const auto &part : parts) {
|
for (const auto &part : parts) {
|
||||||
size_t part_size = part.size();
|
const size_t part_size = part.size();
|
||||||
if (part_size > std::numeric_limits<std::uint32_t>::max()) {
|
if (part_size > std::numeric_limits<std::uint32_t>::max()) {
|
||||||
// Size value must fit into uint32_t.
|
// Size value must fit into uint32_t.
|
||||||
throw std::range_error("Invalid size, message part too large");
|
throw std::range_error("Invalid size, message part too large");
|
||||||
}
|
}
|
||||||
size_t count_size = 5;
|
const size_t count_size =
|
||||||
if (part_size < std::numeric_limits<std::uint8_t>::max()) {
|
part_size < std::numeric_limits<std::uint8_t>::max() ? 1 : 5;
|
||||||
count_size = 1;
|
|
||||||
}
|
|
||||||
mmsg_size += part_size + count_size;
|
mmsg_size += part_size + count_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
message_t encoded(mmsg_size);
|
message_t encoded(mmsg_size);
|
||||||
unsigned char *buf = encoded.data<unsigned char>();
|
unsigned char *buf = encoded.data<unsigned char>();
|
||||||
for (const auto &part : parts) {
|
for (const auto &part : parts) {
|
||||||
uint32_t part_size = part.size();
|
const uint32_t part_size = part.size();
|
||||||
const unsigned char *part_data =
|
const unsigned char *part_data =
|
||||||
static_cast<const unsigned char *>(part.data());
|
static_cast<const unsigned char *>(part.data());
|
||||||
|
|
||||||
// small part
|
|
||||||
if (part_size < std::numeric_limits<std::uint8_t>::max()) {
|
if (part_size < std::numeric_limits<std::uint8_t>::max()) {
|
||||||
|
// small part
|
||||||
*buf++ = (unsigned char) part_size;
|
*buf++ = (unsigned char) part_size;
|
||||||
memcpy(buf, part_data, part_size);
|
} else {
|
||||||
buf += part_size;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// big part
|
// big part
|
||||||
*buf++ = std::numeric_limits<uint8_t>::max();
|
*buf++ = std::numeric_limits<uint8_t>::max();
|
||||||
*buf++ = (part_size >> 24) & std::numeric_limits<std::uint8_t>::max();
|
detail::write_network_order(buf, part_size);
|
||||||
*buf++ = (part_size >> 16) & std::numeric_limits<std::uint8_t>::max();
|
buf += sizeof(part_size);
|
||||||
*buf++ = (part_size >> 8) & std::numeric_limits<std::uint8_t>::max();
|
}
|
||||||
*buf++ = part_size & std::numeric_limits<std::uint8_t>::max();
|
std::memcpy(buf, part_data, part_size);
|
||||||
memcpy(buf, part_data, part_size);
|
|
||||||
buf += part_size;
|
buf += part_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert(static_cast<size_t>(buf - encoded.data<unsigned char>()) == mmsg_size);
|
||||||
return encoded;
|
return encoded;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,22 +279,24 @@ template<class OutputIt> OutputIt decode(const message_t &encoded, OutputIt out)
|
|||||||
while (source < limit) {
|
while (source < limit) {
|
||||||
size_t part_size = *source++;
|
size_t part_size = *source++;
|
||||||
if (part_size == std::numeric_limits<std::uint8_t>::max()) {
|
if (part_size == std::numeric_limits<std::uint8_t>::max()) {
|
||||||
if (source > limit - 4) {
|
if (static_cast<size_t>(limit - source) < sizeof(uint32_t)) {
|
||||||
throw std::out_of_range(
|
throw std::out_of_range(
|
||||||
"Malformed encoding, overflow in reading size");
|
"Malformed encoding, overflow in reading size");
|
||||||
}
|
}
|
||||||
part_size = ((uint32_t) source[0] << 24) + ((uint32_t) source[1] << 16)
|
part_size = detail::read_u32_network_order(source);
|
||||||
+ ((uint32_t) source[2] << 8) + (uint32_t) source[3];
|
// the part size is allowed to be less than 0xFF
|
||||||
source += 4;
|
source += sizeof(uint32_t);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (source > limit - part_size) {
|
if (static_cast<size_t>(limit - source) < part_size) {
|
||||||
throw std::out_of_range("Malformed encoding, overflow in reading part");
|
throw std::out_of_range("Malformed encoding, overflow in reading part");
|
||||||
}
|
}
|
||||||
*out = message_t(source, part_size);
|
*out = message_t(source, part_size);
|
||||||
++out;
|
++out;
|
||||||
source += part_size;
|
source += part_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert(source == limit);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user