diff --git a/webrtc/base/openssladapter.cc b/webrtc/base/openssladapter.cc index feb01d36d..603a55bb0 100644 --- a/webrtc/base/openssladapter.cc +++ b/webrtc/base/openssladapter.cc @@ -37,6 +37,7 @@ #include "webrtc/base/safe_conversions.h" #include "webrtc/base/sslroots.h" #include "webrtc/base/stringutils.h" +#include "webrtc/base/thread.h" // TODO: Use a nicer abstraction for mutex. @@ -266,6 +267,7 @@ OpenSSLAdapter::OpenSSLAdapter(AsyncSocket* socket) ssl_write_needs_read_(false), restartable_(false), ssl_(NULL), ssl_ctx_(NULL), + ssl_mode_(SSL_MODE_TLS), custom_verification_succeeded_(false) { } @@ -273,6 +275,12 @@ OpenSSLAdapter::~OpenSSLAdapter() { Cleanup(); } +void +OpenSSLAdapter::SetMode(SSLMode mode) { + ASSERT(state_ == SSL_NONE); + ssl_mode_ = mode; +} + int OpenSSLAdapter::StartSSL(const char* hostname, bool restartable) { if (state_ != SSL_NONE) @@ -352,6 +360,9 @@ int OpenSSLAdapter::ContinueSSL() { ASSERT(state_ == SSL_CONNECTING); + // Clear the DTLS timer + Thread::Current()->Clear(this, MSG_TIMEOUT); + int code = SSL_connect(ssl_); switch (SSL_get_error(ssl_, code)) { case SSL_ERROR_NONE: @@ -376,6 +387,15 @@ OpenSSLAdapter::ContinueSSL() { break; case SSL_ERROR_WANT_READ: + LOG(LS_VERBOSE) << " -- error want read"; + struct timeval timeout; + if (DTLSv1_get_timeout(ssl_, &timeout)) { + int delay = timeout.tv_sec * 1000 + timeout.tv_usec/1000; + + Thread::Current()->PostDelayed(delay, this, MSG_TIMEOUT, 0); + } + break; + case SSL_ERROR_WANT_WRITE: break; @@ -416,6 +436,9 @@ OpenSSLAdapter::Cleanup() { SSL_CTX_free(ssl_ctx_); ssl_ctx_ = NULL; } + + // Clear the DTLS timer + Thread::Current()->Clear(this, MSG_TIMEOUT); } // @@ -477,6 +500,18 @@ OpenSSLAdapter::Send(const void* pv, size_t cb) { return SOCKET_ERROR; } +int +OpenSSLAdapter::SendTo(const void* pv, size_t cb, const SocketAddress& addr) { + if (socket_->GetState() == Socket::CS_CONNECTED && + addr == socket_->GetRemoteAddress()) { + return Send(pv, cb); + } + + SetError(ENOTCONN); + + return SOCKET_ERROR; +} + int OpenSSLAdapter::Recv(void* pv, size_t cb) { //LOG(LS_INFO) << "OpenSSLAdapter::Recv(" << cb << ")"; @@ -532,6 +567,21 @@ OpenSSLAdapter::Recv(void* pv, size_t cb) { return SOCKET_ERROR; } +int +OpenSSLAdapter::RecvFrom(void* pv, size_t cb, SocketAddress* paddr) { + if (socket_->GetState() == Socket::CS_CONNECTED) { + int ret = Recv(pv, cb); + + *paddr = GetRemoteAddress(); + + return ret; + } + + SetError(ENOTCONN); + + return SOCKET_ERROR; +} + int OpenSSLAdapter::Close() { Cleanup(); @@ -550,6 +600,15 @@ OpenSSLAdapter::GetState() const { return state; } +void +OpenSSLAdapter::OnMessage(Message* msg) { + if (MSG_TIMEOUT == msg->message_id) { + LOG(LS_INFO) << "DTLS timeout expired"; + DTLSv1_handle_timeout(ssl_); + ContinueSSL(); + } +} + void OpenSSLAdapter::OnConnectEvent(AsyncSocket* socket) { LOG(LS_INFO) << "OpenSSLAdapter::OnConnectEvent"; @@ -861,7 +920,8 @@ bool OpenSSLAdapter::ConfigureTrustedRootCertificates(SSL_CTX* ctx) { SSL_CTX* OpenSSLAdapter::SetupSSLContext() { - SSL_CTX* ctx = SSL_CTX_new(TLSv1_client_method()); + SSL_CTX* ctx = SSL_CTX_new(ssl_mode_ == SSL_MODE_DTLS ? + DTLSv1_client_method() : TLSv1_client_method()); if (ctx == NULL) { unsigned long error = ERR_get_error(); // NOLINT: type used by OpenSSL. LOG(LS_WARNING) << "SSL_CTX creation failed: " @@ -882,6 +942,10 @@ OpenSSLAdapter::SetupSSLContext() { SSL_CTX_set_verify_depth(ctx, 4); SSL_CTX_set_cipher_list(ctx, "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + if (ssl_mode_ == SSL_MODE_DTLS) { + SSL_CTX_set_read_ahead(ctx, 1); + } + return ctx; } diff --git a/webrtc/base/openssladapter.h b/webrtc/base/openssladapter.h index d244a7f5c..079f0016d 100644 --- a/webrtc/base/openssladapter.h +++ b/webrtc/base/openssladapter.h @@ -12,6 +12,8 @@ #define WEBRTC_BASE_OPENSSLADAPTER_H__ #include +#include "webrtc/base/messagehandler.h" +#include "webrtc/base/messagequeue.h" #include "webrtc/base/ssladapter.h" typedef struct ssl_st SSL; @@ -22,7 +24,7 @@ namespace rtc { /////////////////////////////////////////////////////////////////////////////// -class OpenSSLAdapter : public SSLAdapter { +class OpenSSLAdapter : public SSLAdapter, public MessageHandler { public: static bool InitializeSSL(VerificationCallback callback); static bool InitializeSSLThread(); @@ -31,9 +33,12 @@ public: OpenSSLAdapter(AsyncSocket* socket); virtual ~OpenSSLAdapter(); + virtual void SetMode(SSLMode mode); virtual int StartSSL(const char* hostname, bool restartable); virtual int Send(const void* pv, size_t cb); + virtual int SendTo(const void* pv, size_t cb, const SocketAddress& addr); virtual int Recv(void* pv, size_t cb); + virtual int RecvFrom(void* pv, size_t cb, SocketAddress* paddr); virtual int Close(); // Note that the socket returns ST_CONNECTING while SSL is being negotiated. @@ -50,11 +55,15 @@ private: SSL_NONE, SSL_WAIT, SSL_CONNECTING, SSL_CONNECTED, SSL_ERROR }; + enum { MSG_TIMEOUT }; + int BeginSSL(); int ContinueSSL(); void Error(const char* context, int err, bool signal = true); void Cleanup(); + virtual void OnMessage(Message* msg); + static bool VerifyServerName(SSL* ssl, const char* host, bool ignore_bad_cert); bool SSLPostConnectionCheck(SSL* ssl, const char* host); @@ -66,7 +75,7 @@ private: friend class OpenSSLStreamAdapter; // for custom_verify_callback_; static bool ConfigureTrustedRootCertificates(SSL_CTX* ctx); - static SSL_CTX* SetupSSLContext(); + SSL_CTX* SetupSSLContext(); SSLState state_; bool ssl_read_needs_write_; @@ -77,6 +86,8 @@ private: SSL* ssl_; SSL_CTX* ssl_ctx_; std::string ssl_host_name_; + // Do DTLS or not + SSLMode ssl_mode_; bool custom_verification_succeeded_; }; diff --git a/webrtc/base/ssladapter.h b/webrtc/base/ssladapter.h index 87b993ffb..57e8ee9cc 100644 --- a/webrtc/base/ssladapter.h +++ b/webrtc/base/ssladapter.h @@ -12,6 +12,7 @@ #define WEBRTC_BASE_SSLADAPTER_H_ #include "webrtc/base/asyncsocket.h" +#include "webrtc/base/sslstreamadapter.h" namespace rtc { @@ -25,6 +26,9 @@ class SSLAdapter : public AsyncSocketAdapter { bool ignore_bad_cert() const { return ignore_bad_cert_; } void set_ignore_bad_cert(bool ignore) { ignore_bad_cert_ = ignore; } + // Do DTLS or TLS (default is TLS, if unspecified) + virtual void SetMode(SSLMode mode) = 0; + // StartSSL returns 0 if successful. // If StartSSL is called while the socket is closed or connecting, the SSL // negotiation will begin as soon as the socket connects. diff --git a/webrtc/base/ssladapter_unittest.cc b/webrtc/base/ssladapter_unittest.cc index 4590db919..9a0548602 100644 --- a/webrtc/base/ssladapter_unittest.cc +++ b/webrtc/base/ssladapter_unittest.cc @@ -44,6 +44,8 @@ class SSLAdapterTestDummyClient : public sigslot::has_slots<> { ssl_adapter_.reset(rtc::SSLAdapter::Create(socket)); + ssl_adapter_->SetMode(ssl_mode_); + // Ignore any certificate errors for the purpose of testing. // Note: We do this only because we don't have a real certificate. // NEVER USE THIS IN PRODUCTION CODE! @@ -55,6 +57,10 @@ class SSLAdapterTestDummyClient : public sigslot::has_slots<> { &SSLAdapterTestDummyClient::OnSSLAdapterCloseEvent); } + rtc::SocketAddress GetAddress() const { + return ssl_adapter_->GetLocalAddress(); + } + rtc::AsyncSocket::ConnState GetState() const { return ssl_adapter_->GetState(); } @@ -64,16 +70,20 @@ class SSLAdapterTestDummyClient : public sigslot::has_slots<> { } int Connect(const std::string& hostname, const rtc::SocketAddress& address) { - LOG(LS_INFO) << "Starting " << GetSSLProtocolName(ssl_mode_) - << " handshake with " << hostname; - - if (ssl_adapter_->StartSSL(hostname.c_str(), false) != 0) { - return -1; - } - LOG(LS_INFO) << "Initiating connection with " << address; - return ssl_adapter_->Connect(address); + int rv = ssl_adapter_->Connect(address); + + if (rv == 0) { + LOG(LS_INFO) << "Starting " << GetSSLProtocolName(ssl_mode_) + << " handshake with " << hostname; + + if (ssl_adapter_->StartSSL(hostname.c_str(), false) != 0) { + return -1; + } + } + + return rv; } int Close() { @@ -126,10 +136,12 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> { server_socket_.reset(CreateSocket(ssl_mode_)); - server_socket_->SignalReadEvent.connect(this, - &SSLAdapterTestDummyServer::OnServerSocketReadEvent); + if (ssl_mode_ == rtc::SSL_MODE_TLS) { + server_socket_->SignalReadEvent.connect(this, + &SSLAdapterTestDummyServer::OnServerSocketReadEvent); - server_socket_->Listen(1); + server_socket_->Listen(1); + } LOG(LS_INFO) << ((ssl_mode_ == rtc::SSL_MODE_DTLS) ? "UDP" : "TCP") << " server listening on " << server_socket_->GetLocalAddress(); @@ -170,38 +182,26 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> { } } + void AcceptConnection(const rtc::SocketAddress& address) { + // Only a single connection is supported. + ASSERT_TRUE(ssl_stream_adapter_ == NULL); + + // This is only for DTLS. + ASSERT_EQ(rtc::SSL_MODE_DTLS, ssl_mode_); + + // Transfer ownership of the socket to the SSLStreamAdapter object. + rtc::AsyncSocket* socket = server_socket_.release(); + + socket->Connect(address); + + DoHandshake(socket); + } + void OnServerSocketReadEvent(rtc::AsyncSocket* socket) { - if (ssl_stream_adapter_ != NULL) { - // Only a single connection is supported. - return; - } + // Only a single connection is supported. + ASSERT_TRUE(ssl_stream_adapter_ == NULL); - rtc::SocketAddress address; - rtc::AsyncSocket* new_socket = socket->Accept(&address); - rtc::SocketStream* stream = new rtc::SocketStream(new_socket); - - ssl_stream_adapter_.reset(rtc::SSLStreamAdapter::Create(stream)); - ssl_stream_adapter_->SetServerRole(); - - // SSLStreamAdapter is normally used for peer-to-peer communication, but - // here we're testing communication between a client and a server - // (e.g. a WebRTC-based application and an RFC 5766 TURN server), where - // clients are not required to provide a certificate during handshake. - // Accordingly, we must disable client authentication here. - ssl_stream_adapter_->set_client_auth_enabled(false); - - ssl_stream_adapter_->SetIdentity(ssl_identity_->GetReference()); - - // Set a bogus peer certificate digest. - unsigned char digest[20]; - size_t digest_len = sizeof(digest); - ssl_stream_adapter_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest, - digest_len); - - ssl_stream_adapter_->StartSSLWithPeer(); - - ssl_stream_adapter_->SignalEvent.connect(this, - &SSLAdapterTestDummyServer::OnSSLStreamAdapterEvent); + DoHandshake(server_socket_->Accept(NULL)); } void OnSSLStreamAdapterEvent(rtc::StreamInterface* stream, int sig, int err) { @@ -226,6 +226,35 @@ class SSLAdapterTestDummyServer : public sigslot::has_slots<> { } private: + void DoHandshake(rtc::AsyncSocket* socket) { + rtc::SocketStream* stream = new rtc::SocketStream(socket); + + ssl_stream_adapter_.reset(rtc::SSLStreamAdapter::Create(stream)); + + ssl_stream_adapter_->SetMode(ssl_mode_); + ssl_stream_adapter_->SetServerRole(); + + // SSLStreamAdapter is normally used for peer-to-peer communication, but + // here we're testing communication between a client and a server + // (e.g. a WebRTC-based application and an RFC 5766 TURN server), where + // clients are not required to provide a certificate during handshake. + // Accordingly, we must disable client authentication here. + ssl_stream_adapter_->set_client_auth_enabled(false); + + ssl_stream_adapter_->SetIdentity(ssl_identity_->GetReference()); + + // Set a bogus peer certificate digest. + unsigned char digest[20]; + size_t digest_len = sizeof(digest); + ssl_stream_adapter_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest, + digest_len); + + ssl_stream_adapter_->StartSSLWithPeer(); + + ssl_stream_adapter_->SignalEvent.connect(this, + &SSLAdapterTestDummyServer::OnSSLStreamAdapterEvent); + } + const rtc::SSLMode ssl_mode_; rtc::scoped_ptr server_socket_; @@ -263,6 +292,11 @@ class SSLAdapterTestBase : public testing::Test, // Now the state should be CS_CONNECTING ASSERT_EQ(rtc::AsyncSocket::CS_CONNECTING, client_->GetState()); + if (ssl_mode_ == rtc::SSL_MODE_DTLS) { + // For DTLS, call AcceptConnection() with the client's address. + server_->AcceptConnection(client_->GetAddress()); + } + if (expect_success) { // If expecting success, the client should end up in the CS_CONNECTED // state after handshake. @@ -314,6 +348,10 @@ class SSLAdapterTestTLS : public SSLAdapterTestBase { SSLAdapterTestTLS() : SSLAdapterTestBase(rtc::SSL_MODE_TLS) {} }; +class SSLAdapterTestDTLS : public SSLAdapterTestBase { + public: + SSLAdapterTestDTLS() : SSLAdapterTestBase(rtc::SSL_MODE_DTLS) {} +}; #if SSL_USE_OPENSSL @@ -330,5 +368,18 @@ TEST_F(SSLAdapterTestTLS, TestTLSTransfer) { TestTransfer("Hello, world!"); } +// Basic tests: DTLS + +// Test that handshake works +TEST_F(SSLAdapterTestDTLS, TestDTLSConnect) { + TestHandshake(true); +} + +// Test transfer between client and server +TEST_F(SSLAdapterTestDTLS, TestDTLSTransfer) { + TestHandshake(true); + TestTransfer("Hello, world!"); +} + #endif // SSL_USE_OPENSSL