Merge pull request #3702 from somdoron/ws_mechanism

problem: WS transport doesn't support mechanism
This commit is contained in:
Luca Boccassi 2019-10-04 15:38:49 +01:00 committed by GitHub
commit 8d9acb72c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 121 additions and 8 deletions

View File

@ -49,6 +49,7 @@ zmq::ws_decoder_t::ws_decoder_t (size_t bufsize_,
_must_mask (must_mask_),
_size (0)
{
memset (_tmpbuf, 0, sizeof (_tmpbuf));
int rc = _in_progress.init ();
errno_assert (rc == 0);

View File

@ -60,6 +60,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
#include "random.hpp"
#include "ws_decoder.hpp"
#include "ws_encoder.hpp"
#include "null_mechanism.hpp"
#include "plain_server.hpp"
#include "plain_client.hpp"
#ifdef ZMQ_HAVE_CURVE
#include "curve_client.hpp"
#include "curve_server.hpp"
#endif
#ifdef ZMQ_HAVE_WINDOWS
#define strcasecmp _stricmp
@ -99,10 +107,8 @@ zmq::ws_engine_t::ws_engine_t (fd_t fd_,
memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1);
memset (_websocket_protocol, 0, MAX_HEADER_VALUE_LENGTH + 1);
_next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
&ws_engine_t::routing_id_msg);
_process_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
&ws_engine_t::process_routing_id_msg);
_next_msg = &ws_engine_t::next_handshake_command;
_process_msg = &ws_engine_t::process_handshake_command;
}
zmq::ws_engine_t::~ws_engine_t ()
@ -112,6 +118,18 @@ zmq::ws_engine_t::~ws_engine_t ()
void zmq::ws_engine_t::start_ws_handshake ()
{
if (_client) {
char protocol[21];
if (_options.mechanism == ZMQ_NULL)
strcpy (protocol, "ZWS2.0/NULL,ZWS2.0");
else if (_options.mechanism == ZMQ_PLAIN)
strcpy (protocol, "ZWS2.0/PLAIN");
#ifdef ZMQ_HAVE_CURVE
else if (_options.mechanism == ZMQ_CURVE)
strcpy (protocol, "ZWS2.0/CURVE");
#endif
else
assert (false);
unsigned char nonce[16];
int *p = (int *) nonce;
@ -131,9 +149,10 @@ void zmq::ws_engine_t::start_ws_handshake ()
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Key: %s\r\n"
"Sec-WebSocket-Protocol: ZWS2.0\r\n"
"Sec-WebSocket-Protocol: %s\r\n"
"Sec-WebSocket-Version: 13\r\n\r\n",
_address.path (), _address.host (), _websocket_key);
_address.path (), _address.host (), _websocket_key,
protocol);
assert (size > 0 && size < WS_BUFFER_SIZE);
_outpos = _write_buffer;
_outsize = size;
@ -177,6 +196,48 @@ int zmq::ws_engine_t::process_routing_id_msg (msg_t *msg_)
return 0;
}
bool zmq::ws_engine_t::select_protocol (char *protocol)
{
if (_options.mechanism == ZMQ_NULL && (strcmp ("ZWS2.0", protocol) == 0)) {
_next_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
&ws_engine_t::routing_id_msg);
_process_msg = static_cast<int (stream_engine_base_t::*) (msg_t *)> (
&ws_engine_t::process_routing_id_msg);
return true;
} else if (_options.mechanism == ZMQ_NULL
&& strcmp ("ZWS2.0/NULL", protocol) == 0) {
_mechanism = new (std::nothrow)
null_mechanism_t (session (), _peer_address, _options);
alloc_assert (_mechanism);
return true;
} else if (_options.mechanism == ZMQ_PLAIN
&& strcmp ("ZWS2.0/PLAIN", protocol) == 0) {
if (_options.as_server)
_mechanism = new (std::nothrow)
plain_server_t (session (), _peer_address, _options);
else
_mechanism =
new (std::nothrow) plain_client_t (session (), _options);
alloc_assert (_mechanism);
return true;
}
#ifdef ZMQ_HAVE_CURVE
else if (_options.mechanism == ZMQ_CURVE
&& strcmp ("ZWS2.0/CURVE", protocol) == 0) {
if (_options.as_server)
_mechanism = new (std::nothrow)
curve_server_t (session (), _peer_address, _options);
else
_mechanism =
new (std::nothrow) curve_client_t (session (), _options);
alloc_assert (_mechanism);
return true;
}
#endif
return false;
}
bool zmq::ws_engine_t::handshake ()
{
bool complete;
@ -390,7 +451,7 @@ bool zmq::ws_engine_t::server_handshake ()
if (*p == ' ')
p++;
if (strcmp ("ZWS2.0", p) == 0) {
if (select_protocol (p)) {
strcpy (_websocket_protocol, p);
break;
}
@ -760,7 +821,7 @@ bool zmq::ws_engine_t::client_handshake ()
strcpy (_websocket_accept, _header_value);
else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
== 0) {
if (strcmp ("ZWS2.0", _header_value) == 0)
if (select_protocol (_header_value))
strcpy (_websocket_protocol, _header_value);
}
_client_handshake_state = client_header_field_cr;

View File

@ -143,6 +143,8 @@ class ws_engine_t : public stream_engine_base_t
int routing_id_msg (msg_t *msg_);
int process_routing_id_msg (msg_t *msg_);
bool select_protocol (char *protocol);
bool client_handshake ();
bool server_handshake ();

View File

@ -1480,6 +1480,14 @@ int zmq_has (const char *capability_)
#if defined(ZMQ_BUILD_DRAFT_API)
if (strcmp (capability_, "draft") == 0)
return true;
#endif
#if defined(ZMQ_HAVE_WS)
if (strcmp (capability_, "WS") == 0)
return true;
#endif
#if defined(ZMQ_HAVE_WSS)
if (strcmp (capability_, "WSS") == 0)
return true;
#endif
// Whatever the application asked for, we don't have
return false;

View File

@ -106,6 +106,43 @@ void test_large_message ()
test_context_socket_close (sb);
}
void test_curve ()
{
char client_public[41];
char client_secret[41];
char server_public[41];
char server_secret[41];
TEST_ASSERT_SUCCESS_ERRNO (
zmq_curve_keypair (server_public, server_secret));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_curve_keypair (client_public, client_secret));
void *server = test_context_socket (ZMQ_REP);
int as_server = 1;
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (server, ZMQ_CURVE_SERVER, &as_server, sizeof (int)));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (server, ZMQ_CURVE_SECRETKEY, server_secret, 41));
TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (server, "ws://*:5556/roundtrip"));
void *client = test_context_socket (ZMQ_REQ);
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (client, ZMQ_CURVE_SERVERKEY, server_public, 41));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (client, ZMQ_CURVE_PUBLICKEY, client_public, 41));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_setsockopt (client, ZMQ_CURVE_SECRETKEY, client_secret, 41));
TEST_ASSERT_SUCCESS_ERRNO (
zmq_connect (client, "ws://127.0.0.1:5556/roundtrip"));
bounce (server, client);
test_context_socket_close (client);
test_context_socket_close (server);
}
int main ()
{
setup_test_environment ();
@ -114,5 +151,9 @@ int main ()
RUN_TEST (test_roundtrip);
RUN_TEST (test_short_message);
RUN_TEST (test_large_message);
if (zmq_has ("curve"))
RUN_TEST (test_curve);
return UNITY_END ();
}