problem: ws_engine duplicate code from stream_engine

Solution: New class called stream_engine_base which is inherited by ws_engine, zmtp_engine and raw_engine.
This commit is contained in:
somdoron
2019-07-17 17:57:44 +03:00
parent 39941a0c82
commit 157b2a2ee0
21 changed files with 1858 additions and 1779 deletions

View File

@@ -70,79 +70,31 @@ zmq::ws_engine_t::ws_engine_t (fd_t fd_,
const options_t &options_,
const endpoint_uri_pair_t &endpoint_uri_pair_,
bool client_) :
stream_engine_base_t (fd_, options_, endpoint_uri_pair_),
_client (client_),
_plugged (false),
_socket (NULL),
_fd (fd_),
_session (NULL),
_handle (static_cast<handle_t> (NULL)),
_options (options_),
_endpoint_uri_pair (endpoint_uri_pair_),
_handshaking (true),
_client_handshake_state (client_handshake_initial),
_server_handshake_state (handshake_initial),
_header_name_position (0),
_header_value_position (0),
_header_upgrade_websocket (false),
_header_connection_upgrade (false),
_websocket_protocol (false),
_input_stopped (false),
_decoder (NULL),
_inpos (NULL),
_insize (0),
_output_stopped (false),
_outpos (NULL),
_outsize (0),
_encoder (NULL),
_sent_routing_id (false),
_received_routing_id (false)
_websocket_protocol (false)
{
// Put the socket into non-blocking mode.
unblock_socket (_fd);
memset (_websocket_key, 0, MAX_HEADER_VALUE_LENGTH + 1);
memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1);
int rc = _tx_msg.init ();
errno_assert (rc == 0);
_next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
&ws_engine_t::routing_id_msg);
_process_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
&ws_engine_t::process_routing_id_msg);
}
zmq::ws_engine_t::~ws_engine_t ()
{
zmq_assert (!_plugged);
if (_fd != retired_fd) {
#ifdef ZMQ_HAVE_WINDOWS
int rc = closesocket (_fd);
wsa_assert (rc != SOCKET_ERROR);
#else
int rc = close (_fd);
errno_assert (rc == 0);
#endif
_fd = retired_fd;
}
int rc = _tx_msg.close ();
errno_assert (rc == 0);
LIBZMQ_DELETE (_encoder);
LIBZMQ_DELETE (_decoder);
}
void zmq::ws_engine_t::plug (io_thread_t *io_thread_, session_base_t *session_)
void zmq::ws_engine_t::plug_internal ()
{
zmq_assert (!_plugged);
_plugged = true;
zmq_assert (!_session);
zmq_assert (session_);
_session = session_;
_socket = _session->get_socket ();
// Connect to I/O threads poller object.
io_object_t::plug (io_thread_);
_handle = add_fd (_fd);
if (_client) {
unsigned char nonce[16];
int *p = (int *) nonce;
@@ -170,215 +122,78 @@ void zmq::ws_engine_t::plug (io_thread_t *io_thread_, session_base_t *session_)
assert (size > 0 && size < WS_BUFFER_SIZE);
_outpos = _write_buffer;
_outsize = size;
_output_stopped = false;
set_pollout (_handle);
} else
_output_stopped = true;
set_pollout ();
}
_input_stopped = false;
set_pollin (_handle);
set_pollin ();
in_event ();
}
void zmq::ws_engine_t::unplug ()
int zmq::ws_engine_t::routing_id_msg (msg_t *msg_)
{
zmq_assert (_plugged);
_plugged = false;
rm_fd (_handle);
// Disconnect from I/O threads poller object.
io_object_t::unplug ();
int rc = msg_->init_size (_options.routing_id_size);
errno_assert (rc == 0);
if (_options.routing_id_size > 0)
memcpy (msg_->data (), _options.routing_id, _options.routing_id_size);
_next_msg = &ws_engine_t::pull_msg_from_session;
return 0;
}
void zmq::ws_engine_t::terminate ()
int zmq::ws_engine_t::process_routing_id_msg (msg_t *msg_)
{
unplug ();
delete this;
if (_options.recv_routing_id) {
msg_->set_flags (msg_t::routing_id);
int rc = session ()->push_msg (msg_);
errno_assert (rc == 0);
} else {
int rc = msg_->close ();
errno_assert (rc == 0);
rc = msg_->init ();
errno_assert (rc == 0);
}
_process_msg = &ws_engine_t::push_msg_to_session;
return 0;
}
void zmq::ws_engine_t::in_event ()
bool zmq::ws_engine_t::handshake ()
{
if (_handshaking) {
if (_client) {
if (!client_handshake ())
return;
} else if (!server_handshake ())
return;
bool complete;
if (_client)
complete = client_handshake ();
else
complete = server_handshake ();
if (complete) {
_encoder =
new (std::nothrow) ws_encoder_t (_options.out_batch_size, _client);
alloc_assert (_encoder);
_decoder = new (std::nothrow)
ws_decoder_t (_options.in_batch_size, _options.maxmsgsize,
_options.zero_copy, !_client);
alloc_assert (_decoder);
socket ()->event_handshake_succeeded (_endpoint_uri_pair, 0);
set_pollout ();
}
zmq_assert (_decoder);
// If there's no data to process in the buffer...
if (_insize == 0) {
// Retrieve the buffer and read as much data as possible.
// Note that buffer can be arbitrarily large. However, we assume
// the underlying TCP layer has fixed buffer size and thus the
// number of bytes read will be always limited.
size_t bufsize = 0;
_decoder->get_buffer (&_inpos, &bufsize);
const int rc = tcp_read (_fd, _inpos, bufsize);
if (rc == 0) {
// connection closed by peer
errno = EPIPE;
error (zmq::stream_engine_t::connection_error);
return;
}
if (rc == -1) {
if (errno != EAGAIN) {
error (zmq::stream_engine_t::connection_error);
return;
}
return;
}
// Adjust input size
_insize = static_cast<size_t> (rc);
// Adjust buffer size to received bytes
_decoder->resize_buffer (_insize);
}
int rc = 0;
size_t processed = 0;
while (_insize > 0) {
rc = _decoder->decode (_inpos, _insize, processed);
zmq_assert (processed <= _insize);
_inpos += processed;
_insize -= processed;
if (rc == 0 || rc == -1)
break;
if (!_received_routing_id) {
_received_routing_id = true;
if (_options.recv_routing_id)
_decoder->msg ()->set_flags (msg_t::routing_id);
else {
_decoder->msg ()->close ();
_decoder->msg ()->init ();
continue;
}
}
rc = _session->push_msg (_decoder->msg ());
if (rc == -1)
break;
}
// Tear down the connection if we have failed to decode input data
// or the session has rejected the message.
if (rc == -1) {
if (errno != EAGAIN) {
error (zmq::stream_engine_t::protocol_error);
return;
}
_input_stopped = true;
reset_pollin (_handle);
}
_session->flush ();
return;
}
void zmq::ws_engine_t::out_event ()
{
// If write buffer is empty, try to read new data from the encoder.
if (!_outsize) {
// Even when we stop polling as soon as there is no
// data to send, the poller may invoke out_event one
// more time due to 'speculative write' optimisation.
if (unlikely (_encoder == NULL)) {
zmq_assert (_handshaking);
return;
}
_outpos = NULL;
_outsize = _encoder->encode (&_outpos, 0);
while (_outsize < static_cast<size_t> (_options.out_batch_size)) {
if (!_sent_routing_id) {
_tx_msg.close ();
int rc = _tx_msg.init_size (_options.routing_id_size);
errno_assert (rc == 0);
if (_options.routing_id_size > 0)
memcpy (_tx_msg.data (), _options.routing_id,
_options.routing_id_size);
_sent_routing_id = true;
} else if (_session->pull_msg (&_tx_msg) == -1)
break;
_encoder->load_msg (&_tx_msg);
unsigned char *bufptr = _outpos + _outsize;
size_t n =
_encoder->encode (&bufptr, _options.out_batch_size - _outsize);
zmq_assert (n > 0);
if (_outpos == NULL)
_outpos = bufptr;
_outsize += n;
}
// If there is no data to send, stop polling for output.
if (_outsize == 0) {
_output_stopped = true;
reset_pollout (_handle);
return;
}
}
// If there are any data to write in write buffer, write as much as
// possible to the socket. Note that amount of data to write can be
// arbitrarily large. However, we assume that underlying TCP layer has
// limited transmission buffer and thus the actual number of bytes
// written should be reasonably modest.
const int nbytes = tcp_write (_fd, _outpos, _outsize);
// IO error has occurred. We stop waiting for output events.
// The engine is not terminated until we detect input error;
// this is necessary to prevent losing incoming messages.
if (nbytes == -1) {
_output_stopped = true;
reset_pollout (_handle);
return;
}
_outpos += nbytes;
_outsize -= nbytes;
// If we are still handshaking and there are no data
// to send, stop polling for output.
if (unlikely (_handshaking))
if (_outsize == 0) {
_output_stopped = true;
reset_pollout (_handle);
}
}
const zmq::endpoint_uri_pair_t &zmq::ws_engine_t::get_endpoint () const
{
return _endpoint_uri_pair;
}
void zmq::ws_engine_t::restart_output ()
{
if (likely (_output_stopped)) {
set_pollout (_handle);
_output_stopped = false;
}
return complete;
}
bool zmq::ws_engine_t::server_handshake ()
{
int nbytes = tcp_read (_fd, _read_buffer, WS_BUFFER_SIZE);
int nbytes = tcp_read (_read_buffer, WS_BUFFER_SIZE);
if (nbytes == 0) {
errno = EPIPE;
error (zmq::stream_engine_t::connection_error);
error (zmq::i_engine::connection_error);
return false;
} else if (nbytes == -1) {
if (errno != EAGAIN)
error (zmq::stream_engine_t::connection_error);
error (zmq::i_engine::connection_error);
return false;
}
@@ -573,20 +388,6 @@ bool zmq::ws_engine_t::server_handshake ()
if (_header_connection_upgrade && _header_upgrade_websocket
&& _websocket_protocol && _websocket_key[0] != '\0') {
_server_handshake_state = handshake_complete;
_handshaking = false;
// TODO: check which decoder/encoder to use according to selected protocol
_encoder = new (std::nothrow)
ws_encoder_t (_options.out_batch_size, false);
alloc_assert (_encoder);
_decoder = new (std::nothrow) ws_decoder_t (
_options.in_batch_size, _options.maxmsgsize,
_options.zero_copy, true);
alloc_assert (_decoder);
_socket->event_handshake_succeeded (_endpoint_uri_pair,
0);
const char *magic_string =
"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
@@ -623,16 +424,15 @@ bool zmq::ws_engine_t::server_handshake ()
_outpos = _write_buffer;
_outsize = written;
if (_output_stopped)
restart_output ();
_inpos++;
_insize--;
return true;
} else
_server_handshake_state = handshake_error;
} else
_server_handshake_state = handshake_error;
break;
case handshake_complete:
// no more bytes are allowed after complete
_server_handshake_state = handshake_error;
default:
assert (false);
}
@@ -643,26 +443,26 @@ bool zmq::ws_engine_t::server_handshake ()
if (_server_handshake_state == handshake_error) {
// TODO: send bad request
_socket->event_handshake_failed_protocol (
socket ()->event_handshake_failed_protocol (
_endpoint_uri_pair, ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED);
error (zmq::stream_engine_t::protocol_error);
error (zmq::i_engine::protocol_error);
return false;
}
}
return _server_handshake_state == handshake_complete;
return false;
}
bool zmq::ws_engine_t::client_handshake ()
{
int nbytes = tcp_read (_fd, _read_buffer, WS_BUFFER_SIZE);
int nbytes = tcp_read (_read_buffer, WS_BUFFER_SIZE);
if (nbytes == 0) {
errno = EPIPE;
error (zmq::stream_engine_t::connection_error);
error (zmq::i_engine::connection_error);
return false;
} else if (nbytes == -1) {
if (errno != EAGAIN)
error (zmq::stream_engine_t::connection_error);
error (zmq::i_engine::connection_error);
return false;
}
@@ -967,25 +767,9 @@ bool zmq::ws_engine_t::client_handshake ()
&& _websocket_protocol
&& _websocket_accept[0] != '\0') {
_client_handshake_state = client_handshake_complete;
_handshaking = false;
_encoder = new (std::nothrow)
ws_encoder_t (_options.out_batch_size, true);
alloc_assert (_encoder);
_decoder = new (std::nothrow) ws_decoder_t (
_options.in_batch_size, _options.maxmsgsize,
_options.zero_copy, false);
alloc_assert (_decoder);
_socket->event_handshake_succeeded (_endpoint_uri_pair,
0);
// TODO: validate accept key
if (_output_stopped)
restart_output ();
_inpos++;
_insize--;
@@ -1003,10 +787,10 @@ bool zmq::ws_engine_t::client_handshake ()
_insize--;
if (_client_handshake_state == client_handshake_error) {
_socket->event_handshake_failed_protocol (
socket ()->event_handshake_failed_protocol (
_endpoint_uri_pair, ZMQ_PROTOCOL_ERROR_WS_UNSPECIFIED);
error (zmq::stream_engine_t::protocol_error);
error (zmq::i_engine::protocol_error);
return false;
}
}
@@ -1014,33 +798,6 @@ bool zmq::ws_engine_t::client_handshake ()
return false;
}
void zmq::ws_engine_t::error (zmq::stream_engine_t::error_reason_t reason_)
{
zmq_assert (_session);
if (reason_ != zmq::stream_engine_t::protocol_error && _handshaking) {
int err = errno;
_socket->event_handshake_failed_no_detail (_endpoint_uri_pair, err);
}
_socket->event_disconnected (_endpoint_uri_pair, _fd);
_session->flush ();
_session->engine_error (reason_);
unplug ();
delete this;
}
bool zmq::ws_engine_t::restart_input ()
{
zmq_assert (_input_stopped);
_input_stopped = false;
set_pollin (_handle);
in_event ();
return true;
}
static int
encode_base64 (const unsigned char *in, int in_len, char *out, int out_len)
{