From e46ec31209bc660c3463d70b84d8c38591ca1a43 Mon Sep 17 00:00:00 2001 From: Martin Hurton Date: Sun, 12 Jan 2014 21:58:36 +0100 Subject: [PATCH] Implement socket_base_t::get_credential member function The get_credential () member function returns credential for the last peer we received message for. The idea is that this function is used to implement user-level API. --- src/curve_server.cpp | 3 +++ src/dealer.cpp | 6 ++++++ src/dealer.hpp | 1 + src/fq.cpp | 16 +++++++++++++++- src/fq.hpp | 10 ++++++++++ src/mechanism.cpp | 10 ++++++++++ src/mechanism.hpp | 6 ++++++ src/msg.cpp | 5 +++++ src/msg.hpp | 2 ++ src/null_mechanism.cpp | 3 +++ src/pair.cpp | 16 ++++++++++++++-- src/pair.hpp | 6 ++++++ src/pipe.cpp | 15 +++++++++++++++ src/pipe.hpp | 5 +++++ src/plain_mechanism.cpp | 3 +++ src/pull.cpp | 5 +++++ src/pull.hpp | 1 + src/router.cpp | 5 +++++ src/router.hpp | 1 + src/socket_base.cpp | 5 +++++ src/socket_base.hpp | 6 ++++++ src/stream_engine.cpp | 25 ++++++++++++++++++++++++- src/stream_engine.hpp | 1 + src/xsub.cpp | 5 +++++ src/xsub.hpp | 1 + 25 files changed, 158 insertions(+), 4 deletions(-) diff --git a/src/curve_server.cpp b/src/curve_server.cpp index cf11a90f..4b151157 100644 --- a/src/curve_server.cpp +++ b/src/curve_server.cpp @@ -613,6 +613,9 @@ int zmq::curve_server_t::receive_and_process_zap_reply () goto error; } + // Save user id + set_user_id (msg [5].data (), msg [5].size ()); + // Process metadata frame rc = parse_metadata (static_cast (msg [6].data ()), msg [6].size ()); diff --git a/src/dealer.cpp b/src/dealer.cpp index 8c9d142f..6739c20a 100644 --- a/src/dealer.cpp +++ b/src/dealer.cpp @@ -98,6 +98,12 @@ bool zmq::dealer_t::xhas_out () return lb.has_out (); } +zmq::blob_t zmq::dealer_t::get_credential () const +{ + return fq.get_credential (); +} + + void zmq::dealer_t::xread_activated (pipe_t *pipe_) { fq.activated (pipe_); diff --git a/src/dealer.hpp b/src/dealer.hpp index f25b39dd..abcfef63 100644 --- a/src/dealer.hpp +++ b/src/dealer.hpp @@ -51,6 +51,7 @@ namespace zmq int xrecv (zmq::msg_t *msg_); bool xhas_in (); bool xhas_out (); + blob_t get_credential () const; void xread_activated (zmq::pipe_t *pipe_); void xwrite_activated (zmq::pipe_t *pipe_); void xpipe_terminated (zmq::pipe_t *pipe_); diff --git a/src/fq.cpp b/src/fq.cpp index f43d7881..932503e5 100644 --- a/src/fq.cpp +++ b/src/fq.cpp @@ -24,6 +24,7 @@ zmq::fq_t::fq_t () : active (0), + last_in (NULL), current (0), more (false) { @@ -54,6 +55,11 @@ void zmq::fq_t::pipe_terminated (pipe_t *pipe_) current = 0; } pipes.erase (pipe_); + + if (last_in == pipe_) { + saved_credential = last_in->get_credential (); + last_in = NULL; + } } void zmq::fq_t::activated (pipe_t *pipe_) @@ -88,8 +94,10 @@ int zmq::fq_t::recvpipe (msg_t *msg_, pipe_t **pipe_) if (pipe_) *pipe_ = pipes [current]; more = msg_->flags () & msg_t::more? true: false; - if (!more) + if (!more) { + last_in = pipes [current]; current = (current + 1) % active; + } return 0; } @@ -136,3 +144,9 @@ bool zmq::fq_t::has_in () return false; } +zmq::blob_t zmq::fq_t::get_credential () const +{ + return last_in? + last_in->get_credential (): saved_credential; +} + diff --git a/src/fq.hpp b/src/fq.hpp index 04fc28b1..49f69059 100644 --- a/src/fq.hpp +++ b/src/fq.hpp @@ -21,6 +21,7 @@ #define __ZMQ_FQ_HPP_INCLUDED__ #include "array.hpp" +#include "blob.hpp" #include "pipe.hpp" #include "msg.hpp" @@ -45,6 +46,7 @@ namespace zmq int recv (msg_t *msg_); int recvpipe (msg_t *msg_, pipe_t **pipe_); bool has_in (); + blob_t get_credential () const; private: @@ -56,6 +58,11 @@ namespace zmq // beginning of the pipes array. pipes_t::size_type active; + // Pointer to the last pipe we received message from. + // NULL when no message has been received or the pipe + // has terminated. + pipe_t *last_in; + // Index of the next bound pipe to read a message from. pipes_t::size_type current; @@ -63,6 +70,9 @@ namespace zmq // there are following parts still waiting in the current pipe. bool more; + // Holds credential after the last_acive_pipe has terminated. + blob_t saved_credential; + fq_t (const fq_t&); const fq_t &operator = (const fq_t&); }; diff --git a/src/mechanism.cpp b/src/mechanism.cpp index 36b20956..d5e84a18 100644 --- a/src/mechanism.cpp +++ b/src/mechanism.cpp @@ -47,6 +47,16 @@ void zmq::mechanism_t::peer_identity (msg_t *msg_) msg_->set_flags (msg_t::identity); } +void zmq::mechanism_t::set_user_id (const void *data_, size_t size_) +{ + user_id = blob_t (static_cast (data_), size_); +} + +zmq::blob_t zmq::mechanism_t::get_user_id () const +{ + return user_id; +} + const char *zmq::mechanism_t::socket_type_string (int socket_type) const { static const char *names [] = {"PAIR", "PUB", "SUB", "REQ", "REP", diff --git a/src/mechanism.hpp b/src/mechanism.hpp index c0638190..f2a45d3a 100644 --- a/src/mechanism.hpp +++ b/src/mechanism.hpp @@ -60,6 +60,10 @@ namespace zmq void peer_identity (msg_t *msg_); + void set_user_id (const void *user_id, size_t size); + + blob_t get_user_id () const; + protected: // Only used to identify the socket for the Socket-Type @@ -91,6 +95,8 @@ namespace zmq blob_t identity; + blob_t user_id; + // Returns true iff socket associated with the mechanism // is compatible with a given socket type 'type_'. bool check_socket_type (const std::string type_) const; diff --git a/src/msg.cpp b/src/msg.cpp index fa38d6d5..0b9aade1 100644 --- a/src/msg.cpp +++ b/src/msg.cpp @@ -267,6 +267,11 @@ bool zmq::msg_t::is_identity () const return (u.base.flags & identity) == identity; } +bool zmq::msg_t::is_credential () const +{ + return (u.base.flags & credential) == credential; +} + bool zmq::msg_t::is_delimiter () const { return u.base.type == type_delimiter; diff --git a/src/msg.hpp b/src/msg.hpp index 7a47fff9..6fac144e 100644 --- a/src/msg.hpp +++ b/src/msg.hpp @@ -49,6 +49,7 @@ namespace zmq { more = 1, // Followed by more parts command = 2, // Command frame (see ZMTP spec) + credential = 32, identity = 64, shared = 128 }; @@ -70,6 +71,7 @@ namespace zmq int64_t fd (); void set_fd (int64_t fd_); bool is_identity () const; + bool is_credential () const; bool is_delimiter () const; bool is_vsm (); bool is_cmsg (); diff --git a/src/null_mechanism.cpp b/src/null_mechanism.cpp index 07ec7b9a..e6a742c0 100644 --- a/src/null_mechanism.cpp +++ b/src/null_mechanism.cpp @@ -268,6 +268,9 @@ int zmq::null_mechanism_t::receive_and_process_zap_reply () goto error; } + // Save user id + set_user_id (msg [5].data (), msg [5].size ()); + // Process metadata frame rc = parse_metadata (static_cast (msg [6].data ()), msg [6].size ()); diff --git a/src/pair.cpp b/src/pair.cpp index 1733d023..1ec431b9 100644 --- a/src/pair.cpp +++ b/src/pair.cpp @@ -24,7 +24,8 @@ zmq::pair_t::pair_t (class ctx_t *parent_, uint32_t tid_, int sid_) : socket_base_t (parent_, tid_, sid_), - pipe (NULL) + pipe (NULL), + last_in (NULL) { options.type = ZMQ_PAIR; } @@ -51,8 +52,13 @@ void zmq::pair_t::xattach_pipe (pipe_t *pipe_, bool subscribe_to_all_) void zmq::pair_t::xpipe_terminated (pipe_t *pipe_) { - if (pipe_ == pipe) + if (pipe_ == pipe) { + if (last_in == pipe) { + saved_credential = last_in->get_credential (); + last_in = NULL; + } pipe = NULL; + } } void zmq::pair_t::xread_activated (pipe_t *) @@ -99,6 +105,7 @@ int zmq::pair_t::xrecv (msg_t *msg_) errno = EAGAIN; return -1; } + last_in = pipe; return 0; } @@ -117,3 +124,8 @@ bool zmq::pair_t::xhas_out () return pipe->check_write (); } + +zmq::blob_t zmq::pair_t::get_credential () const +{ + return last_in? last_in->get_credential (): saved_credential; +} diff --git a/src/pair.hpp b/src/pair.hpp index f7a1ddd4..18d1bc67 100644 --- a/src/pair.hpp +++ b/src/pair.hpp @@ -20,6 +20,7 @@ #ifndef __ZMQ_PAIR_HPP_INCLUDED__ #define __ZMQ_PAIR_HPP_INCLUDED__ +#include "blob.hpp" #include "socket_base.hpp" #include "session_base.hpp" @@ -45,6 +46,7 @@ namespace zmq int xrecv (zmq::msg_t *msg_); bool xhas_in (); bool xhas_out (); + blob_t get_credential () const; void xread_activated (zmq::pipe_t *pipe_); void xwrite_activated (zmq::pipe_t *pipe_); void xpipe_terminated (zmq::pipe_t *pipe_); @@ -53,6 +55,10 @@ namespace zmq zmq::pipe_t *pipe; + zmq::pipe_t *last_in; + + blob_t saved_credential; + pair_t (const pair_t&); const pair_t &operator = (const pair_t&); }; diff --git a/src/pipe.cpp b/src/pipe.cpp index 4c485a0b..e0d27872 100644 --- a/src/pipe.cpp +++ b/src/pipe.cpp @@ -110,6 +110,11 @@ zmq::blob_t zmq::pipe_t::get_identity () return identity; } +zmq::blob_t zmq::pipe_t::get_credential () const +{ + return credential; +} + bool zmq::pipe_t::check_read () { if (unlikely (!in_active)) @@ -143,11 +148,21 @@ bool zmq::pipe_t::read (msg_t *msg_) if (unlikely (state != active && state != waiting_for_delimiter)) return false; +read_message: if (!inpipe->read (msg_)) { in_active = false; return false; } + // If this is a credential, save a copy and receive next message. + if (unlikely (msg_->is_credential ())) { + const unsigned char *data = static_cast (msg_->data ()); + credential = blob_t (data, msg_->size ()); + const int rc = msg_->close (); + zmq_assert (rc == 0); + goto read_message; + } + // If delimiter was read, start termination process of the pipe. if (msg_->is_delimiter ()) { process_delimiter (); diff --git a/src/pipe.hpp b/src/pipe.hpp index c9a2cade..d5a69a37 100644 --- a/src/pipe.hpp +++ b/src/pipe.hpp @@ -78,6 +78,8 @@ namespace zmq void set_identity (const blob_t &identity_); blob_t get_identity (); + blob_t get_credential () const; + // Returns true if there is at least one message to read in the pipe. bool check_read (); @@ -198,6 +200,9 @@ namespace zmq // Identity of the writer. Used uniquely by the reader side. blob_t identity; + // Pipe's credential. + blob_t credential; + // Returns true if the message is delimiter; false otherwise. static bool is_delimiter (const msg_t &msg_); diff --git a/src/plain_mechanism.cpp b/src/plain_mechanism.cpp index 92b1db87..2859377d 100644 --- a/src/plain_mechanism.cpp +++ b/src/plain_mechanism.cpp @@ -468,6 +468,9 @@ int zmq::plain_mechanism_t::receive_and_process_zap_reply () goto error; } + // Save user id + set_user_id (msg [5].data (), msg [5].size ()); + // Process metadata frame rc = parse_metadata (static_cast (msg [6].data ()), msg [6].size ()); diff --git a/src/pull.cpp b/src/pull.cpp index 07552e10..85854539 100644 --- a/src/pull.cpp +++ b/src/pull.cpp @@ -60,3 +60,8 @@ bool zmq::pull_t::xhas_in () { return fq.has_in (); } + +zmq::blob_t zmq::pull_t::get_credential () const +{ + return fq.get_credential (); +} diff --git a/src/pull.hpp b/src/pull.hpp index adb385c1..f1580b4b 100644 --- a/src/pull.hpp +++ b/src/pull.hpp @@ -46,6 +46,7 @@ namespace zmq void xattach_pipe (zmq::pipe_t *pipe_, bool subscribe_to_all_); int xrecv (zmq::msg_t *msg_); bool xhas_in (); + blob_t get_credential () const; void xread_activated (zmq::pipe_t *pipe_); void xpipe_terminated (zmq::pipe_t *pipe_); diff --git a/src/router.cpp b/src/router.cpp index 1bdfa157..1fa80e28 100644 --- a/src/router.cpp +++ b/src/router.cpp @@ -371,6 +371,11 @@ bool zmq::router_t::xhas_out () return true; } +zmq::blob_t zmq::router_t::get_credential () const +{ + return fq.get_credential (); +} + bool zmq::router_t::identify_peer (pipe_t *pipe_) { msg_t msg; diff --git a/src/router.hpp b/src/router.hpp index 3ef12b81..a2bdb2ec 100644 --- a/src/router.hpp +++ b/src/router.hpp @@ -59,6 +59,7 @@ namespace zmq // Rollback any message parts that were sent but not yet flushed. int rollback (); + blob_t get_credential () const; private: diff --git a/src/socket_base.cpp b/src/socket_base.cpp index 43937115..c47b71f8 100644 --- a/src/socket_base.cpp +++ b/src/socket_base.cpp @@ -1043,6 +1043,11 @@ int zmq::socket_base_t::xrecv (msg_t *) return -1; } +zmq::blob_t zmq::socket_base_t::get_credential () const +{ + return blob_t (); +} + void zmq::socket_base_t::xread_activated (pipe_t *) { zmq_assert (false); diff --git a/src/socket_base.hpp b/src/socket_base.hpp index d4cd358d..4554b406 100644 --- a/src/socket_base.hpp +++ b/src/socket_base.hpp @@ -26,6 +26,7 @@ #include "own.hpp" #include "array.hpp" +#include "blob.hpp" #include "stdint.hpp" #include "poller.hpp" #include "atomic_counter.hpp" @@ -144,6 +145,11 @@ namespace zmq virtual bool xhas_in (); virtual int xrecv (zmq::msg_t *msg_); + // Returns the credential for the peer from which we have received + // the last message. If no message has been received yet, + // the function returns empty credential. + virtual blob_t get_credential () const; + // i_pipe_events will be forwarded to these functions. virtual void xread_activated (pipe_t *pipe_); virtual void xwrite_activated (pipe_t *pipe_); diff --git a/src/stream_engine.cpp b/src/stream_engine.cpp index 20d8f395..c4880564 100644 --- a/src/stream_engine.cpp +++ b/src/stream_engine.cpp @@ -692,7 +692,7 @@ void zmq::stream_engine_t::mechanism_ready () } read_msg = &stream_engine_t::pull_and_encode; - write_msg = &stream_engine_t::decode_and_push; + write_msg = &stream_engine_t::write_credential; } int zmq::stream_engine_t::pull_msg_from_session (msg_t *msg_) @@ -705,6 +705,29 @@ int zmq::stream_engine_t::push_msg_to_session (msg_t *msg_) return session->push_msg (msg_); } +int zmq::stream_engine_t::write_credential (msg_t *msg_) +{ + zmq_assert (mechanism != NULL); + zmq_assert (session != NULL); + + const blob_t credential = mechanism->get_user_id (); + if (credential.size () > 0) { + msg_t msg; + int rc = msg.init_size (credential.size ()); + zmq_assert (rc == 0); + memcpy (msg.data (), credential.data (), credential.size ()); + msg.set_flags (msg_t::credential); + rc = session->push_msg (&msg); + if (rc == -1) { + rc = msg.close (); + errno_assert (rc == 0); + return -1; + } + } + write_msg = &stream_engine_t::decode_and_push; + return decode_and_push (msg_); +} + int zmq::stream_engine_t::pull_and_encode (msg_t *msg_) { zmq_assert (mechanism != NULL); diff --git a/src/stream_engine.hpp b/src/stream_engine.hpp index 9a472659..4e9e7915 100644 --- a/src/stream_engine.hpp +++ b/src/stream_engine.hpp @@ -101,6 +101,7 @@ namespace zmq int pull_msg_from_session (msg_t *msg_); int push_msg_to_session (msg_t *msg); + int write_credential (msg_t *msg_); int pull_and_encode (msg_t *msg_); int decode_and_push (msg_t *msg_); int push_one_then_decode_and_push (msg_t *msg_); diff --git a/src/xsub.cpp b/src/xsub.cpp index e72088e8..aaf51310 100644 --- a/src/xsub.cpp +++ b/src/xsub.cpp @@ -199,6 +199,11 @@ bool zmq::xsub_t::xhas_in () } } +zmq::blob_t zmq::xsub_t::get_credential () const +{ + return fq.get_credential (); +} + bool zmq::xsub_t::match (msg_t *msg_) { return subscriptions.check ((unsigned char*) msg_->data (), msg_->size ()); diff --git a/src/xsub.hpp b/src/xsub.hpp index f5a179b6..7b3a3884 100644 --- a/src/xsub.hpp +++ b/src/xsub.hpp @@ -49,6 +49,7 @@ namespace zmq bool xhas_out (); int xrecv (zmq::msg_t *msg_); bool xhas_in (); + blob_t get_credential () const; void xread_activated (zmq::pipe_t *pipe_); void xwrite_activated (zmq::pipe_t *pipe_); void xhiccuped (pipe_t *pipe_);