diff --git a/src/ws_decoder.cpp b/src/ws_decoder.cpp index a87bdd48..cd44250d 100644 --- a/src/ws_decoder.cpp +++ b/src/ws_decoder.cpp @@ -49,6 +49,7 @@ zmq::ws_decoder_t::ws_decoder_t (size_t bufsize_, _must_mask (must_mask_), _size (0) { + memset (_tmpbuf, 0, sizeof (_tmpbuf)); int rc = _in_progress.init (); errno_assert (rc == 0); diff --git a/src/ws_engine.cpp b/src/ws_engine.cpp index a20e12c2..20da9fe4 100644 --- a/src/ws_engine.cpp +++ b/src/ws_engine.cpp @@ -60,6 +60,14 @@ along with this program. If not, see . #include "random.hpp" #include "ws_decoder.hpp" #include "ws_encoder.hpp" +#include "null_mechanism.hpp" +#include "plain_server.hpp" +#include "plain_client.hpp" + +#ifdef ZMQ_HAVE_CURVE +#include "curve_client.hpp" +#include "curve_server.hpp" +#endif #ifdef ZMQ_HAVE_WINDOWS #define strcasecmp _stricmp @@ -99,10 +107,8 @@ zmq::ws_engine_t::ws_engine_t (fd_t fd_, memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1); memset (_websocket_protocol, 0, MAX_HEADER_VALUE_LENGTH + 1); - _next_msg = static_cast ( - &ws_engine_t::routing_id_msg); - _process_msg = static_cast ( - &ws_engine_t::process_routing_id_msg); + _next_msg = &ws_engine_t::next_handshake_command; + _process_msg = &ws_engine_t::process_handshake_command; } zmq::ws_engine_t::~ws_engine_t () @@ -112,6 +118,18 @@ zmq::ws_engine_t::~ws_engine_t () void zmq::ws_engine_t::start_ws_handshake () { if (_client) { + char protocol[21]; + if (_options.mechanism == ZMQ_NULL) + strcpy (protocol, "ZWS2.0/NULL,ZWS2.0"); + else if (_options.mechanism == ZMQ_PLAIN) + strcpy (protocol, "ZWS2.0/PLAIN"); +#ifdef ZMQ_HAVE_CURVE + else if (_options.mechanism == ZMQ_CURVE) + strcpy (protocol, "ZWS2.0/CURVE"); +#endif + else + assert (false); + unsigned char nonce[16]; int *p = (int *) nonce; @@ -131,9 +149,10 @@ void zmq::ws_engine_t::start_ws_handshake () "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key: %s\r\n" - "Sec-WebSocket-Protocol: ZWS2.0\r\n" + "Sec-WebSocket-Protocol: %s\r\n" "Sec-WebSocket-Version: 13\r\n\r\n", - _address.path (), _address.host (), _websocket_key); + _address.path (), _address.host (), _websocket_key, + protocol); assert (size > 0 && size < WS_BUFFER_SIZE); _outpos = _write_buffer; _outsize = size; @@ -177,6 +196,48 @@ int zmq::ws_engine_t::process_routing_id_msg (msg_t *msg_) return 0; } +bool zmq::ws_engine_t::select_protocol (char *protocol) +{ + if (_options.mechanism == ZMQ_NULL && (strcmp ("ZWS2.0", protocol) == 0)) { + _next_msg = static_cast ( + &ws_engine_t::routing_id_msg); + _process_msg = static_cast ( + &ws_engine_t::process_routing_id_msg); + return true; + } else if (_options.mechanism == ZMQ_NULL + && strcmp ("ZWS2.0/NULL", protocol) == 0) { + _mechanism = new (std::nothrow) + null_mechanism_t (session (), _peer_address, _options); + alloc_assert (_mechanism); + return true; + } else if (_options.mechanism == ZMQ_PLAIN + && strcmp ("ZWS2.0/PLAIN", protocol) == 0) { + if (_options.as_server) + _mechanism = new (std::nothrow) + plain_server_t (session (), _peer_address, _options); + else + _mechanism = + new (std::nothrow) plain_client_t (session (), _options); + alloc_assert (_mechanism); + return true; + } +#ifdef ZMQ_HAVE_CURVE + else if (_options.mechanism == ZMQ_CURVE + && strcmp ("ZWS2.0/CURVE", protocol) == 0) { + if (_options.as_server) + _mechanism = new (std::nothrow) + curve_server_t (session (), _peer_address, _options); + else + _mechanism = + new (std::nothrow) curve_client_t (session (), _options); + alloc_assert (_mechanism); + return true; + } +#endif + + return false; +} + bool zmq::ws_engine_t::handshake () { bool complete; @@ -390,7 +451,7 @@ bool zmq::ws_engine_t::server_handshake () if (*p == ' ') p++; - if (strcmp ("ZWS2.0", p) == 0) { + if (select_protocol (p)) { strcpy (_websocket_protocol, p); break; } @@ -760,7 +821,7 @@ bool zmq::ws_engine_t::client_handshake () strcpy (_websocket_accept, _header_value); else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name) == 0) { - if (strcmp ("ZWS2.0", _header_value) == 0) + if (select_protocol (_header_value)) strcpy (_websocket_protocol, _header_value); } _client_handshake_state = client_header_field_cr; diff --git a/src/ws_engine.hpp b/src/ws_engine.hpp index e3a15048..12d5e612 100644 --- a/src/ws_engine.hpp +++ b/src/ws_engine.hpp @@ -143,6 +143,8 @@ class ws_engine_t : public stream_engine_base_t int routing_id_msg (msg_t *msg_); int process_routing_id_msg (msg_t *msg_); + bool select_protocol (char *protocol); + bool client_handshake (); bool server_handshake (); diff --git a/src/zmq.cpp b/src/zmq.cpp index 0931e61f..787e80bb 100644 --- a/src/zmq.cpp +++ b/src/zmq.cpp @@ -1480,6 +1480,14 @@ int zmq_has (const char *capability_) #if defined(ZMQ_BUILD_DRAFT_API) if (strcmp (capability_, "draft") == 0) return true; +#endif +#if defined(ZMQ_HAVE_WS) + if (strcmp (capability_, "WS") == 0) + return true; +#endif +#if defined(ZMQ_HAVE_WSS) + if (strcmp (capability_, "WSS") == 0) + return true; #endif // Whatever the application asked for, we don't have return false; diff --git a/tests/test_ws_transport.cpp b/tests/test_ws_transport.cpp index cf241afc..0a9272af 100644 --- a/tests/test_ws_transport.cpp +++ b/tests/test_ws_transport.cpp @@ -106,6 +106,43 @@ void test_large_message () test_context_socket_close (sb); } +void test_curve () +{ + char client_public[41]; + char client_secret[41]; + char server_public[41]; + char server_secret[41]; + + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_curve_keypair (server_public, server_secret)); + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_curve_keypair (client_public, client_secret)); + + void *server = test_context_socket (ZMQ_REP); + int as_server = 1; + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_setsockopt (server, ZMQ_CURVE_SERVER, &as_server, sizeof (int))); + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_setsockopt (server, ZMQ_CURVE_SECRETKEY, server_secret, 41)); + TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (server, "ws://*:5556/roundtrip")); + + + void *client = test_context_socket (ZMQ_REQ); + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_setsockopt (client, ZMQ_CURVE_SERVERKEY, server_public, 41)); + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_setsockopt (client, ZMQ_CURVE_PUBLICKEY, client_public, 41)); + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_setsockopt (client, ZMQ_CURVE_SECRETKEY, client_secret, 41)); + TEST_ASSERT_SUCCESS_ERRNO ( + zmq_connect (client, "ws://127.0.0.1:5556/roundtrip")); + + bounce (server, client); + + test_context_socket_close (client); + test_context_socket_close (server); +} + int main () { setup_test_environment (); @@ -114,5 +151,9 @@ int main () RUN_TEST (test_roundtrip); RUN_TEST (test_short_message); RUN_TEST (test_large_message); + + if (zmq_has ("curve")) + RUN_TEST (test_curve); + return UNITY_END (); }