Thread Safe Sockets.

1. Reorganise C API socket functions to eliminate bad practice
of public functions calling other public functions. This should
be done for msg's too but hasn't been in this patch.

2. Reorganise code in C API socket functions so that the
socket is cast on one line, the C++ function called on
the next with the result retained, then the result is returned.

This makes the code much simpler to read and also allows
pre- and post- call hooks to be inserted easily.

3. Insert pre- and post- call hooks which set and release
a mutex iff the thread_safe flag is on.

4. Add the thread_safe_flag to base_socket_t initialised to
false to preserve existing semantics. Add an accessor for
the flag, add a mutex, and add lock and unlock functions.

Note: as yet no code to actually set the flag.
This commit is contained in:
skaller 2012-02-04 01:41:09 +11:00
parent 4dd6ce0639
commit 988efbc73a
3 changed files with 96 additions and 35 deletions

View File

@ -121,7 +121,8 @@ zmq::socket_base_t::socket_base_t (ctx_t *parent_, uint32_t tid_) :
destroyed (false),
last_tsc (0),
ticks (0),
rcvmore (false)
rcvmore (false),
thread_safe_flag (false)
{
}
@ -873,3 +874,13 @@ void zmq::socket_base_t::extract_flags (msg_t *msg_)
rcvmore = msg_->flags () & msg_t::more ? true : false;
}
void zmq::socket_base_t::lock()
{
sync.lock();
}
void zmq::socket_base_t::unlock()
{
sync.unlock();
}

View File

@ -95,7 +95,9 @@ namespace zmq
void write_activated (pipe_t *pipe_);
void hiccuped (pipe_t *pipe_);
void terminated (pipe_t *pipe_);
bool thread_safe() const { return thread_safe_flag; }
void lock();
void unlock();
protected:
socket_base_t (zmq::ctx_t *parent_, uint32_t tid_);
@ -195,6 +197,8 @@ namespace zmq
socket_base_t (const socket_base_t&);
const socket_base_t &operator = (const socket_base_t&);
bool thread_safe_flag;
mutex_t sync;
};
}

View File

@ -194,8 +194,11 @@ int zmq_setsockopt (void *s_, int option_, const void *optval_,
errno = ENOTSOCK;
return -1;
}
return (((zmq::socket_base_t*) s_)->setsockopt (option_, optval_,
optvallen_));
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = s->setsockopt (option_, optval_, optvallen_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_getsockopt (void *s_, int option_, void *optval_, size_t *optvallen_)
@ -204,8 +207,11 @@ int zmq_getsockopt (void *s_, int option_, void *optval_, size_t *optvallen_)
errno = ENOTSOCK;
return -1;
}
return (((zmq::socket_base_t*) s_)->getsockopt (option_, optval_,
optvallen_));
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = s->getsockopt (option_, optval_, optvallen_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_bind (void *s_, const char *addr_)
@ -214,7 +220,11 @@ int zmq_bind (void *s_, const char *addr_)
errno = ENOTSOCK;
return -1;
}
return (((zmq::socket_base_t*) s_)->bind (addr_));
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = s->bind (addr_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_connect (void *s_, const char *addr_)
@ -223,7 +233,34 @@ int zmq_connect (void *s_, const char *addr_)
errno = ENOTSOCK;
return -1;
}
return (((zmq::socket_base_t*) s_)->connect (addr_));
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = s->connect (addr_);
if(s->thread_safe()) s->unlock();
return result;
}
// sending functions
static int inner_sendmsg (zmq::socket_base_t *s_, zmq_msg_t *msg_, int flags_)
{
int sz = (int) zmq_msg_size (msg_);
int rc = s_->send ((zmq::msg_t*) msg_, flags_);
if (unlikely (rc < 0))
return -1;
return sz;
}
int zmq_sendmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = inner_sendmsg (s, msg_, flags_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_send (void *s_, const void *buf_, size_t len_, int flags_)
@ -234,7 +271,10 @@ int zmq_send (void *s_, const void *buf_, size_t len_, int flags_)
return -1;
memcpy (zmq_msg_data (&msg), buf_, len_);
rc = zmq_sendmsg (s_, &msg, flags_);
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
rc = inner_sendmsg (s, &msg, flags_);
if(s->thread_safe()) s->unlock();
if (unlikely (rc < 0)) {
int err = errno;
int rc2 = zmq_msg_close (&msg);
@ -248,13 +288,43 @@ int zmq_send (void *s_, const void *buf_, size_t len_, int flags_)
return rc;
}
// receiving functions
static int inner_recvmsg (zmq::socket_base_t *s_, zmq_msg_t *msg_, int flags_)
{
int rc = s_->recv ((zmq::msg_t*) msg_, flags_);
if (unlikely (rc < 0))
return -1;
return (int) zmq_msg_size (msg_);
}
int zmq_recvmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int result = inner_recvmsg(s, msg_, flags_);
if(s->thread_safe()) s->unlock();
return result;
}
int zmq_recv (void *s_, void *buf_, size_t len_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
zmq_msg_t msg;
int rc = zmq_msg_init (&msg);
errno_assert (rc == 0);
int nbytes = zmq_recvmsg (s_, &msg, flags_);
zmq::socket_base_t *s = (zmq::socket_base_t *) s_;
if(s->thread_safe()) s->lock();
int nbytes = inner_recvmsg (s, &msg, flags_);
if(s->thread_safe()) s->unlock();
if (unlikely (nbytes < 0)) {
int err = errno;
rc = zmq_msg_close (&msg);
@ -274,31 +344,7 @@ int zmq_recv (void *s_, void *buf_, size_t len_, int flags_)
return nbytes;
}
int zmq_sendmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
int sz = (int) zmq_msg_size (msg_);
int rc = (((zmq::socket_base_t*) s_)->send ((zmq::msg_t*) msg_, flags_));
if (unlikely (rc < 0))
return -1;
return sz;
}
int zmq_recvmsg (void *s_, zmq_msg_t *msg_, int flags_)
{
if (!s_ || !((zmq::socket_base_t*) s_)->check_tag ()) {
errno = ENOTSOCK;
return -1;
}
int rc = (((zmq::socket_base_t*) s_)->recv ((zmq::msg_t*) msg_, flags_));
if (unlikely (rc < 0))
return -1;
return (int) zmq_msg_size (msg_);
}
// message manipulators
int zmq_msg_init (zmq_msg_t *msg_)
{
return ((zmq::msg_t*) msg_)->init ();