diff --git a/include/msgpack/v1/adaptor/cpp17/variant.hpp b/include/msgpack/v1/adaptor/cpp17/variant.hpp index 6d0ca1f7..e9efa757 100644 --- a/include/msgpack/v1/adaptor/cpp17/variant.hpp +++ b/include/msgpack/v1/adaptor/cpp17/variant.hpp @@ -11,8 +11,6 @@ #ifndef MSGPACK_V1_TYPE_VARIANT_HPP #define MSGPACK_V1_TYPE_VARIANT_HPP -#define MSGPACK_USE_STD_VARIANT_ADAPTOR - #if defined(MSGPACK_USE_STD_VARIANT_ADAPTOR) #include "msgpack/cpp_version.hpp" @@ -43,20 +41,31 @@ MSGPACK_API_VERSION_NAMESPACE(v1) { } } + struct object_variant_overload { + object_variant_overload(msgpack::object &obj, msgpack::zone &zone) + : obj{obj} + , zone{zone} {} + + template + void operator()(T const &value) { + obj = msgpack::object(value, zone); + } + + msgpack::object &obj; + msgpack::zone &zone; + }; } // namespace detail template struct as, typename std::enable_if<(msgpack::has_as::value && ...)>::type> { std::variant operator()(msgpack::object const &o) const { - if (o.type != msgpack::type::ARRAY) { - throw msgpack::type_error{}; - } - if (o.via.array.size != 2) { - throw msgpack::type_error{}; - } - if (o.via.array.ptr[0].type != msgpack::type::POSITIVE_INTEGER) { + if ( o.type != msgpack::type::ARRAY + || o.via.array.size != 2 + || o.via.array.ptr[0].type != msgpack::type::POSITIVE_INTEGER + || o.via.array.ptr[0].via.u64 >= sizeof...(Ts)) { throw msgpack::type_error{}; } + return detail::construct_variant, Ts...>( o.via.array.ptr[0].as(), o.via.array.ptr[1], @@ -68,15 +77,13 @@ MSGPACK_API_VERSION_NAMESPACE(v1) { template struct convert> { msgpack::object const &operator()(msgpack::object const &o, std::variant &v) const { - if (o.type != msgpack::type::ARRAY) { - throw msgpack::type_error{}; - } - if (o.via.array.size != 2) { - throw msgpack::type_error{}; - } - if (o.via.array.ptr[0].type != msgpack::type::POSITIVE_INTEGER) { + if ( o.type != msgpack::type::ARRAY + || o.via.array.size != 2 + || o.via.array.ptr[0].type != msgpack::type::POSITIVE_INTEGER + || o.via.array.ptr[0].via.u64 >= sizeof...(Ts)) { throw msgpack::type_error{}; } + v = detail::construct_variant, Ts...>( o.via.array.ptr[0].as(), o.via.array.ptr[1], @@ -92,38 +99,24 @@ MSGPACK_API_VERSION_NAMESPACE(v1) { msgpack::packer& operator()(msgpack::packer &o, std::variant const &v) const { o.pack_array(2); o.pack_uint64(v.index()); - std::visit([&o](auto const &real_value){o.pack(real_value);}, v); + std::visit([&o](auto const &value){o.pack(value);}, v); return o; } }; - // template - // struct object> { - // void operator()(msgpack::object &o, std::variant const &v) const { - // o.type = msgpack::type::ARRAY; - // o.via.array.size = 2; - // msgpack::adaptor::object(o.via.array.ptr[0], v.index()); - // std::visit([&o](auto const &value) { - // msgpack::adaptor::object(o.via.array.ptr[1], value); - // }, v); - // } - // }; - // - // template - // struct object_with_zone> { - // void operator()(msgpack::object::with_zone &o, std::variant const &v) const { - // o.type = msgpack::type::ARRAY; - // - // msgpack::object *p = static_cast(o.zone.allocate_align(sizeof(msgpack::object) * 2, MSGPACK_ZONE_ALIGNOF(msgpack::object))); - // - // o.via.array.size = 2; - // o.via.array.ptr = p; - // msgpack::adaptor::object_with_zone()(o.via.array.ptr[0], v.index(), o.zone); - // std::visit([&o](auto const &real_value){ - // o.via.array.ptr[1] = msgpack::adaptor::object()(real_value, o.zone); - // }, v); - // } - // }; + + template + struct object_with_zone> { + void operator()(msgpack::object::with_zone &o, std::variant const &v) const { + msgpack::object *p = static_cast(o.zone.allocate_align(sizeof(msgpack::object) * 2, MSGPACK_ZONE_ALIGNOF(msgpack::object))); + + o.type = msgpack::type::ARRAY; + o.via.array.size = 2; + o.via.array.ptr = p; + o.via.array.ptr[0]= msgpack::object(v.index(), o.zone); + std::visit(detail::object_variant_overload(o.via.array.ptr[1], o.zone), v); + } + }; } // namespace adaptor } } // namespace msgpack diff --git a/test/msgpack_cpp17.cpp b/test/msgpack_cpp17.cpp index 742a58b6..30a76770 100644 --- a/test/msgpack_cpp17.cpp +++ b/test/msgpack_cpp17.cpp @@ -461,7 +461,9 @@ BOOST_AUTO_TEST_CASE(carray_byte_object_with_zone) } } -BOOST_AUTO_TEST_CASE(variant_as) { +#if defined(MSGPACK_USE_STD_VARIANT_ADAPTOR) + +BOOST_AUTO_TEST_CASE(variant_pack_unpack_as) { std::stringstream ss; std::variant val1{1.0}; msgpack::pack(ss, val1); @@ -470,6 +472,18 @@ BOOST_AUTO_TEST_CASE(variant_as) { msgpack::unpack(str.data(), str.size()); std::variant val2 = oh.get().as >(); BOOST_CHECK(val1 == val2); + BOOST_CHECK_THROW((oh.get().as>()), msgpack::type_error); } +BOOST_AUTO_TEST_CASE(variant_with_zone) { + msgpack::zone z; + std::variant val1{1.0}; + msgpack::object obj(val1, z); + std::variant val2 = obj.as>(); + BOOST_CHECK(val1 == val2); + BOOST_CHECK_THROW((obj.as>()), msgpack::type_error); +} + +#endif // defined(MSGPACK_USE_STD_VARIANT_ADAPTOR) + #endif // MSGPACK_CPP_VERSION >= 201703