diff --git a/src/ws_decoder.cpp b/src/ws_decoder.cpp
index a87bdd48..cd44250d 100644
--- a/src/ws_decoder.cpp
+++ b/src/ws_decoder.cpp
@@ -49,6 +49,7 @@ zmq::ws_decoder_t::ws_decoder_t (size_t bufsize_,
_must_mask (must_mask_),
_size (0)
{
+ memset (_tmpbuf, 0, sizeof (_tmpbuf));
int rc = _in_progress.init ();
errno_assert (rc == 0);
diff --git a/src/ws_engine.cpp b/src/ws_engine.cpp
index a20e12c2..20da9fe4 100644
--- a/src/ws_engine.cpp
+++ b/src/ws_engine.cpp
@@ -60,6 +60,14 @@ along with this program. If not, see .
#include "random.hpp"
#include "ws_decoder.hpp"
#include "ws_encoder.hpp"
+#include "null_mechanism.hpp"
+#include "plain_server.hpp"
+#include "plain_client.hpp"
+
+#ifdef ZMQ_HAVE_CURVE
+#include "curve_client.hpp"
+#include "curve_server.hpp"
+#endif
#ifdef ZMQ_HAVE_WINDOWS
#define strcasecmp _stricmp
@@ -99,10 +107,8 @@ zmq::ws_engine_t::ws_engine_t (fd_t fd_,
memset (_websocket_accept, 0, MAX_HEADER_VALUE_LENGTH + 1);
memset (_websocket_protocol, 0, MAX_HEADER_VALUE_LENGTH + 1);
- _next_msg = static_cast (
- &ws_engine_t::routing_id_msg);
- _process_msg = static_cast (
- &ws_engine_t::process_routing_id_msg);
+ _next_msg = &ws_engine_t::next_handshake_command;
+ _process_msg = &ws_engine_t::process_handshake_command;
}
zmq::ws_engine_t::~ws_engine_t ()
@@ -112,6 +118,18 @@ zmq::ws_engine_t::~ws_engine_t ()
void zmq::ws_engine_t::start_ws_handshake ()
{
if (_client) {
+ char protocol[21];
+ if (_options.mechanism == ZMQ_NULL)
+ strcpy (protocol, "ZWS2.0/NULL,ZWS2.0");
+ else if (_options.mechanism == ZMQ_PLAIN)
+ strcpy (protocol, "ZWS2.0/PLAIN");
+#ifdef ZMQ_HAVE_CURVE
+ else if (_options.mechanism == ZMQ_CURVE)
+ strcpy (protocol, "ZWS2.0/CURVE");
+#endif
+ else
+ assert (false);
+
unsigned char nonce[16];
int *p = (int *) nonce;
@@ -131,9 +149,10 @@ void zmq::ws_engine_t::start_ws_handshake ()
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
"Sec-WebSocket-Key: %s\r\n"
- "Sec-WebSocket-Protocol: ZWS2.0\r\n"
+ "Sec-WebSocket-Protocol: %s\r\n"
"Sec-WebSocket-Version: 13\r\n\r\n",
- _address.path (), _address.host (), _websocket_key);
+ _address.path (), _address.host (), _websocket_key,
+ protocol);
assert (size > 0 && size < WS_BUFFER_SIZE);
_outpos = _write_buffer;
_outsize = size;
@@ -177,6 +196,48 @@ int zmq::ws_engine_t::process_routing_id_msg (msg_t *msg_)
return 0;
}
+bool zmq::ws_engine_t::select_protocol (char *protocol)
+{
+ if (_options.mechanism == ZMQ_NULL && (strcmp ("ZWS2.0", protocol) == 0)) {
+ _next_msg = static_cast (
+ &ws_engine_t::routing_id_msg);
+ _process_msg = static_cast (
+ &ws_engine_t::process_routing_id_msg);
+ return true;
+ } else if (_options.mechanism == ZMQ_NULL
+ && strcmp ("ZWS2.0/NULL", protocol) == 0) {
+ _mechanism = new (std::nothrow)
+ null_mechanism_t (session (), _peer_address, _options);
+ alloc_assert (_mechanism);
+ return true;
+ } else if (_options.mechanism == ZMQ_PLAIN
+ && strcmp ("ZWS2.0/PLAIN", protocol) == 0) {
+ if (_options.as_server)
+ _mechanism = new (std::nothrow)
+ plain_server_t (session (), _peer_address, _options);
+ else
+ _mechanism =
+ new (std::nothrow) plain_client_t (session (), _options);
+ alloc_assert (_mechanism);
+ return true;
+ }
+#ifdef ZMQ_HAVE_CURVE
+ else if (_options.mechanism == ZMQ_CURVE
+ && strcmp ("ZWS2.0/CURVE", protocol) == 0) {
+ if (_options.as_server)
+ _mechanism = new (std::nothrow)
+ curve_server_t (session (), _peer_address, _options);
+ else
+ _mechanism =
+ new (std::nothrow) curve_client_t (session (), _options);
+ alloc_assert (_mechanism);
+ return true;
+ }
+#endif
+
+ return false;
+}
+
bool zmq::ws_engine_t::handshake ()
{
bool complete;
@@ -390,7 +451,7 @@ bool zmq::ws_engine_t::server_handshake ()
if (*p == ' ')
p++;
- if (strcmp ("ZWS2.0", p) == 0) {
+ if (select_protocol (p)) {
strcpy (_websocket_protocol, p);
break;
}
@@ -760,7 +821,7 @@ bool zmq::ws_engine_t::client_handshake ()
strcpy (_websocket_accept, _header_value);
else if (strcasecmp ("Sec-WebSocket-Protocol", _header_name)
== 0) {
- if (strcmp ("ZWS2.0", _header_value) == 0)
+ if (select_protocol (_header_value))
strcpy (_websocket_protocol, _header_value);
}
_client_handshake_state = client_header_field_cr;
diff --git a/src/ws_engine.hpp b/src/ws_engine.hpp
index e3a15048..12d5e612 100644
--- a/src/ws_engine.hpp
+++ b/src/ws_engine.hpp
@@ -143,6 +143,8 @@ class ws_engine_t : public stream_engine_base_t
int routing_id_msg (msg_t *msg_);
int process_routing_id_msg (msg_t *msg_);
+ bool select_protocol (char *protocol);
+
bool client_handshake ();
bool server_handshake ();
diff --git a/src/zmq.cpp b/src/zmq.cpp
index 0931e61f..787e80bb 100644
--- a/src/zmq.cpp
+++ b/src/zmq.cpp
@@ -1480,6 +1480,14 @@ int zmq_has (const char *capability_)
#if defined(ZMQ_BUILD_DRAFT_API)
if (strcmp (capability_, "draft") == 0)
return true;
+#endif
+#if defined(ZMQ_HAVE_WS)
+ if (strcmp (capability_, "WS") == 0)
+ return true;
+#endif
+#if defined(ZMQ_HAVE_WSS)
+ if (strcmp (capability_, "WSS") == 0)
+ return true;
#endif
// Whatever the application asked for, we don't have
return false;
diff --git a/tests/test_ws_transport.cpp b/tests/test_ws_transport.cpp
index cf241afc..0a9272af 100644
--- a/tests/test_ws_transport.cpp
+++ b/tests/test_ws_transport.cpp
@@ -106,6 +106,43 @@ void test_large_message ()
test_context_socket_close (sb);
}
+void test_curve ()
+{
+ char client_public[41];
+ char client_secret[41];
+ char server_public[41];
+ char server_secret[41];
+
+ TEST_ASSERT_SUCCESS_ERRNO (
+ zmq_curve_keypair (server_public, server_secret));
+ TEST_ASSERT_SUCCESS_ERRNO (
+ zmq_curve_keypair (client_public, client_secret));
+
+ void *server = test_context_socket (ZMQ_REP);
+ int as_server = 1;
+ TEST_ASSERT_SUCCESS_ERRNO (
+ zmq_setsockopt (server, ZMQ_CURVE_SERVER, &as_server, sizeof (int)));
+ TEST_ASSERT_SUCCESS_ERRNO (
+ zmq_setsockopt (server, ZMQ_CURVE_SECRETKEY, server_secret, 41));
+ TEST_ASSERT_SUCCESS_ERRNO (zmq_bind (server, "ws://*:5556/roundtrip"));
+
+
+ void *client = test_context_socket (ZMQ_REQ);
+ TEST_ASSERT_SUCCESS_ERRNO (
+ zmq_setsockopt (client, ZMQ_CURVE_SERVERKEY, server_public, 41));
+ TEST_ASSERT_SUCCESS_ERRNO (
+ zmq_setsockopt (client, ZMQ_CURVE_PUBLICKEY, client_public, 41));
+ TEST_ASSERT_SUCCESS_ERRNO (
+ zmq_setsockopt (client, ZMQ_CURVE_SECRETKEY, client_secret, 41));
+ TEST_ASSERT_SUCCESS_ERRNO (
+ zmq_connect (client, "ws://127.0.0.1:5556/roundtrip"));
+
+ bounce (server, client);
+
+ test_context_socket_close (client);
+ test_context_socket_close (server);
+}
+
int main ()
{
setup_test_environment ();
@@ -114,5 +151,9 @@ int main ()
RUN_TEST (test_roundtrip);
RUN_TEST (test_short_message);
RUN_TEST (test_large_message);
+
+ if (zmq_has ("curve"))
+ RUN_TEST (test_curve);
+
return UNITY_END ();
}