diff --git a/python/msgpack/_msgpack.pyx b/python/msgpack/_msgpack.pyx index 66869c80..fb7f0c1e 100644 --- a/python/msgpack/_msgpack.pyx +++ b/python/msgpack/_msgpack.pyx @@ -20,6 +20,9 @@ cdef extern from "Python.h": cdef bint PyFloat_Check(object o) cdef bint PyBytes_Check(object o) cdef bint PyUnicode_Check(object o) + cdef bint PyCallable_Check(object o) + cdef void Py_INCREF(object o) + cdef void Py_DECREF(object o) cdef extern from "stdlib.h": void* malloc(size_t) @@ -60,6 +63,7 @@ cdef class Packer(object): astream.write(packer.pack(b)) """ cdef msgpack_packer pk + cdef object default def __cinit__(self): cdef int buf_size = 1024*1024 @@ -67,6 +71,12 @@ cdef class Packer(object): self.pk.buf_size = buf_size self.pk.length = 0 + def __init__(self, default=None): + if default is not None: + if not PyCallable_Check(default): + raise TypeError("default must be a callable.") + self.default = default + def __dealloc__(self): free(self.pk.buf); @@ -126,9 +136,18 @@ cdef class Packer(object): for v in o: ret = self._pack(v) if ret != 0: break + elif self.default is not None: + o = self.default(o) + d = o + ret = msgpack_pack_map(&self.pk, len(d)) + if ret == 0: + for k,v in d.items(): + ret = self._pack(k) + if ret != 0: break + ret = self._pack(v) + if ret != 0: break else: - # TODO: Serialize with defalt() like simplejson. - raise TypeError, "can't serialize %r" % (o,) + raise TypeError("can't serialize %r" % (o,)) return ret def pack(self, object obj): @@ -141,14 +160,14 @@ cdef class Packer(object): return buf -def pack(object o, object stream): +def pack(object o, object stream, default=None): """pack an object `o` and write it to stream).""" - packer = Packer() + packer = Packer(default) stream.write(packer.pack(o)) -def packb(object o): +def packb(object o, default=None): """pack o and return packed bytes.""" - packer = Packer() + packer = Packer(default=default) return packer.pack(o) packs = packb @@ -156,6 +175,7 @@ packs = packb cdef extern from "unpack.h": ctypedef struct msgpack_user: int use_list + PyObject* object_hook ctypedef struct template_context: msgpack_user user @@ -170,7 +190,7 @@ cdef extern from "unpack.h": object template_data(template_context* ctx) -def unpackb(bytes packed_bytes): +def unpackb(bytes packed_bytes, object object_hook=None): """Unpack packed_bytes to object. Returns an unpacked object.""" cdef const_char_ptr p = packed_bytes cdef template_context ctx @@ -178,7 +198,16 @@ def unpackb(bytes packed_bytes): cdef int ret template_init(&ctx) ctx.user.use_list = 0 + ctx.user.object_hook = NULL + if object_hook is not None: + if not PyCallable_Check(object_hook): + raise TypeError("object_hook must be a callable.") + Py_INCREF(object_hook) + ctx.user.object_hook = object_hook ret = template_execute(&ctx, p, len(packed_bytes), &off) + if object_hook is not None: + pass + #Py_DECREF(object_hook) if ret == 1: return template_data(&ctx) else: @@ -186,10 +215,10 @@ def unpackb(bytes packed_bytes): unpacks = unpackb -def unpack(object stream): +def unpack(object stream, object object_hook=None): """unpack an object from stream.""" packed = stream.read() - return unpackb(packed) + return unpackb(packed, object_hook=object_hook) cdef class UnpackIterator(object): cdef object unpacker @@ -234,6 +263,7 @@ cdef class Unpacker(object): cdef int read_size cdef object waiting_bytes cdef bint use_list + cdef object object_hook def __cinit__(self): self.buf = NULL @@ -242,7 +272,8 @@ cdef class Unpacker(object): if self.buf: free(self.buf); - def __init__(self, file_like=None, int read_size=0, bint use_list=0): + def __init__(self, file_like=None, int read_size=0, bint use_list=0, + object object_hook=None): if read_size == 0: read_size = 1024*1024 self.use_list = use_list @@ -255,6 +286,11 @@ cdef class Unpacker(object): self.buf_tail = 0 template_init(&self.ctx) self.ctx.user.use_list = use_list + self.ctx.user.object_hook = NULL + if object_hook is not None: + if not PyCallable_Check(object_hook): + raise TypeError("object_hook must be a callable.") + self.ctx.user.object_hook = object_hook def feed(self, bytes next_bytes): self.waiting_bytes.append(next_bytes) diff --git a/python/msgpack/unpack.h b/python/msgpack/unpack.h index 9eb8ce77..e4c03bdd 100644 --- a/python/msgpack/unpack.h +++ b/python/msgpack/unpack.h @@ -21,6 +21,7 @@ typedef struct unpack_user { int use_list; + PyObject *object_hook; } unpack_user; @@ -172,6 +173,19 @@ static inline int template_callback_map_item(unpack_user* u, msgpack_unpack_obje return -1; } +//static inline int template_callback_map_end(unpack_user* u, msgpack_unpack_object* c) +int template_callback_map_end(unpack_user* u, msgpack_unpack_object* c) +{ + if (u->object_hook) { + PyObject *arglist = Py_BuildValue("(O)", *c); + Py_INCREF(*c); + *c = PyEval_CallObject(u->object_hook, arglist); + Py_DECREF(arglist); + return 0; + } + return -1; +} + static inline int template_callback_raw(unpack_user* u, const char* b, const char* p, unsigned int l, msgpack_unpack_object* o) { PyObject *py; diff --git a/python/msgpack/unpack_template.h b/python/msgpack/unpack_template.h index ca6e1f32..1fdedd70 100644 --- a/python/msgpack/unpack_template.h +++ b/python/msgpack/unpack_template.h @@ -317,6 +317,7 @@ _push: case CT_MAP_VALUE: if(msgpack_unpack_callback(_map_item)(user, &c->obj, c->map_key, obj) < 0) { goto _failed; } if(--c->count == 0) { + msgpack_unpack_callback(_map_end)(user, &c->obj); obj = c->obj; --top; /*printf("stack pop %d\n", top);*/ diff --git a/python/test/test_obj.py b/python/test/test_obj.py new file mode 100644 index 00000000..64a63903 --- /dev/null +++ b/python/test/test_obj.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# coding: utf-8 + +from nose import main +from nose.tools import * + +from msgpack import packs, unpacks + +def _decode_complex(obj): + if '__complex__' in obj: + return complex(obj['real'], obj['imag']) + return obj + +def _encode_complex(obj): + if isinstance(obj, complex): + return {'__complex__': True, 'real': 1, 'imag': 2} + return obj + +def test_encode_hook(): + packed = packs([3, 1+2j], default=_encode_complex) + unpacked = unpacks(packed) + eq_(unpacked[1], {'__complex__': True, 'real': 1, 'imag': 2}) + +def test_decode_hook(): + packed = packs([3, {'__complex__': True, 'real': 1, 'imag': 2}]) + unpacked = unpacks(packed, object_hook=_decode_complex) + eq_(unpacked[1], 1+2j) + +if __name__ == '__main__': + #main() + test_decode_hook()