diff --git a/talk/app/webrtc/peerconnection_unittest.cc b/talk/app/webrtc/peerconnection_unittest.cc index fe712e8d0..99ee9be86 100644 --- a/talk/app/webrtc/peerconnection_unittest.cc +++ b/talk/app/webrtc/peerconnection_unittest.cc @@ -82,6 +82,7 @@ using webrtc::MockDataChannelObserver; using webrtc::MockSetSessionDescriptionObserver; using webrtc::MockStatsObserver; using webrtc::PeerConnectionInterface; +using webrtc::PeerConnectionFactory; using webrtc::SessionDescriptionInterface; using webrtc::StreamCollectionInterface; @@ -501,7 +502,8 @@ class PeerConnectionTestClientBase video_decoder_factory_enabled_(false), signaling_message_receiver_(NULL) { } - bool Init(const MediaConstraintsInterface* constraints) { + bool Init(const MediaConstraintsInterface* constraints, + const PeerConnectionFactory::Options* options) { EXPECT_TRUE(!peer_connection_); EXPECT_TRUE(!peer_connection_factory_); allocator_factory_ = webrtc::FakePortAllocatorFactory::Create(); @@ -523,6 +525,9 @@ class PeerConnectionTestClientBase if (!peer_connection_factory_) { return false; } + if (options) { + peer_connection_factory_->SetOptions(*options); + } peer_connection_ = CreatePeerConnection(allocator_factory_.get(), constraints); return peer_connection_.get() != NULL; @@ -619,9 +624,10 @@ class JsepTestClient public: static JsepTestClient* CreateClient( const std::string& id, - const MediaConstraintsInterface* constraints) { + const MediaConstraintsInterface* constraints, + const PeerConnectionFactory::Options* options) { JsepTestClient* client(new JsepTestClient(id)); - if (!client->Init(constraints)) { + if (!client->Init(constraints, options)) { delete client; return NULL; } @@ -967,10 +973,19 @@ class P2PTestConductor : public testing::Test { bool CreateTestClients(MediaConstraintsInterface* init_constraints, MediaConstraintsInterface* recv_constraints) { + return CreateTestClients(init_constraints, NULL, recv_constraints, NULL); + } + + bool CreateTestClients(MediaConstraintsInterface* init_constraints, + PeerConnectionFactory::Options* init_options, + MediaConstraintsInterface* recv_constraints, + PeerConnectionFactory::Options* recv_options) { initiating_client_.reset(SignalingClass::CreateClient("Caller: ", - init_constraints)); + init_constraints, + init_options)); receiving_client_.reset(SignalingClass::CreateClient("Callee: ", - recv_constraints)); + recv_constraints, + recv_options)); if (!initiating_client_ || !receiving_client_) { return false; } @@ -1307,13 +1322,77 @@ TEST_F(JsepPeerConnectionP2PTestClient, GetBytesSentStats) { kMaxWaitForStatsMs); } -// Test that we can get negotiated ciphers. -TEST_F(JsepPeerConnectionP2PTestClient, GetNegotiatedCiphersStats) { - ASSERT_TRUE(CreateTestClients()); +// Test that DTLS 1.0 is used if both sides only support DTLS 1.0. +TEST_F(JsepPeerConnectionP2PTestClient, GetDtls12None) { + PeerConnectionFactory::Options init_options; + init_options.ssl_max_version = rtc::SSL_PROTOCOL_DTLS_10; + PeerConnectionFactory::Options recv_options; + recv_options.ssl_max_version = rtc::SSL_PROTOCOL_DTLS_10; + ASSERT_TRUE(CreateTestClients(NULL, &init_options, NULL, &recv_options)); + LocalP2PTest(); + + EXPECT_EQ_WAIT( + rtc::SSLStreamAdapter::GetDefaultSslCipher(rtc::SSL_PROTOCOL_DTLS_10), + initializing_client()->GetDtlsCipherStats(), + kMaxWaitForStatsMs); + + EXPECT_EQ_WAIT( + kDefaultSrtpCipher, + initializing_client()->GetSrtpCipherStats(), + kMaxWaitForStatsMs); +} + +// Test that DTLS 1.2 is used if both ends support it. +TEST_F(JsepPeerConnectionP2PTestClient, GetDtls12Both) { + PeerConnectionFactory::Options init_options; + init_options.ssl_max_version = rtc::SSL_PROTOCOL_DTLS_12; + PeerConnectionFactory::Options recv_options; + recv_options.ssl_max_version = rtc::SSL_PROTOCOL_DTLS_12; + ASSERT_TRUE(CreateTestClients(NULL, &init_options, NULL, &recv_options)); + LocalP2PTest(); + + EXPECT_EQ_WAIT( + rtc::SSLStreamAdapter::GetDefaultSslCipher(rtc::SSL_PROTOCOL_DTLS_12), + initializing_client()->GetDtlsCipherStats(), + kMaxWaitForStatsMs); + + EXPECT_EQ_WAIT( + kDefaultSrtpCipher, + initializing_client()->GetSrtpCipherStats(), + kMaxWaitForStatsMs); +} + +// Test that DTLS 1.0 is used if the initator supports DTLS 1.2 and the +// received supports 1.0. +TEST_F(JsepPeerConnectionP2PTestClient, GetDtls12Init) { + PeerConnectionFactory::Options init_options; + init_options.ssl_max_version = rtc::SSL_PROTOCOL_DTLS_12; + PeerConnectionFactory::Options recv_options; + recv_options.ssl_max_version = rtc::SSL_PROTOCOL_DTLS_10; + ASSERT_TRUE(CreateTestClients(NULL, &init_options, NULL, &recv_options)); + LocalP2PTest(); + + EXPECT_EQ_WAIT( + rtc::SSLStreamAdapter::GetDefaultSslCipher(rtc::SSL_PROTOCOL_DTLS_10), + initializing_client()->GetDtlsCipherStats(), + kMaxWaitForStatsMs); + + EXPECT_EQ_WAIT( + kDefaultSrtpCipher, + initializing_client()->GetSrtpCipherStats(), + kMaxWaitForStatsMs); +} + +// Test that DTLS 1.0 is used if the initator supports DTLS 1.0 and the +// received supports 1.2. +TEST_F(JsepPeerConnectionP2PTestClient, GetDtls12Recv) { + PeerConnectionFactory::Options init_options; + init_options.ssl_max_version = rtc::SSL_PROTOCOL_DTLS_10; + PeerConnectionFactory::Options recv_options; + recv_options.ssl_max_version = rtc::SSL_PROTOCOL_DTLS_12; + ASSERT_TRUE(CreateTestClients(NULL, &init_options, NULL, &recv_options)); LocalP2PTest(); - // TODO(jbauch): this should check for DTLS 1.0 / 1.2 when we have a way to - // enable DTLS 1.2 on peer connections. EXPECT_EQ_WAIT( rtc::SSLStreamAdapter::GetDefaultSslCipher(rtc::SSL_PROTOCOL_DTLS_10), initializing_client()->GetDtlsCipherStats(), diff --git a/talk/app/webrtc/peerconnectioninterface.h b/talk/app/webrtc/peerconnectioninterface.h index 329137f97..59cc7008e 100644 --- a/talk/app/webrtc/peerconnectioninterface.h +++ b/talk/app/webrtc/peerconnectioninterface.h @@ -79,6 +79,7 @@ #include "talk/app/webrtc/umametrics.h" #include "webrtc/base/fileutils.h" #include "webrtc/base/network.h" +#include "webrtc/base/sslstreamadapter.h" #include "webrtc/base/socketaddress.h" namespace rtc { @@ -518,7 +519,8 @@ class PeerConnectionFactoryInterface : public rtc::RefCountInterface { Options() : disable_encryption(false), disable_sctp_data_channels(false), - network_ignore_mask(rtc::kDefaultNetworkIgnoreMask) { + network_ignore_mask(rtc::kDefaultNetworkIgnoreMask), + ssl_max_version(rtc::SSL_PROTOCOL_DTLS_10) { } bool disable_encryption; bool disable_sctp_data_channels; @@ -527,6 +529,11 @@ class PeerConnectionFactoryInterface : public rtc::RefCountInterface { // ADAPTER_TYPE_ETHERNET | ADAPTER_TYPE_LOOPBACK will ignore Ethernet and // loopback interfaces. int network_ignore_mask; + + // Sets the maximum supported protocol version. The highest version + // supported by both ends will be used for the connection, i.e. if one + // party supports DTLS 1.0 and the other DTLS 1.2, DTLS 1.0 will be used. + rtc::SSLProtocolVersion ssl_max_version; }; virtual void SetOptions(const Options& options) = 0; diff --git a/talk/app/webrtc/webrtcsession.cc b/talk/app/webrtc/webrtcsession.cc index e4613650d..a058129ca 100644 --- a/talk/app/webrtc/webrtcsession.cc +++ b/talk/app/webrtc/webrtcsession.cc @@ -524,6 +524,7 @@ bool WebRtcSession::Initialize( const PeerConnectionInterface::RTCConfiguration& rtc_configuration) { bundle_policy_ = rtc_configuration.bundle_policy; rtcp_mux_policy_ = rtc_configuration.rtcp_mux_policy; + SetSslMaxProtocolVersion(options.ssl_max_version); // TODO(perkj): Take |constraints| into consideration. Return false if not all // mandatory constraints can be fulfilled. Note that |constraints| diff --git a/webrtc/p2p/base/dtlstransport.h b/webrtc/p2p/base/dtlstransport.h index 8e17ea660..27cece49d 100644 --- a/webrtc/p2p/base/dtlstransport.h +++ b/webrtc/p2p/base/dtlstransport.h @@ -33,7 +33,8 @@ class DtlsTransport : public Base { rtc::SSLIdentity* identity) : Base(signaling_thread, worker_thread, content_name, allocator), identity_(identity), - secure_role_(rtc::SSL_CLIENT) { + secure_role_(rtc::SSL_CLIENT), + ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_10) { } ~DtlsTransport() { @@ -50,6 +51,11 @@ class DtlsTransport : public Base { return true; } + virtual bool SetSslMaxProtocolVersion_w(rtc::SSLProtocolVersion version) { + ssl_max_version_ = version; + return true; + } + virtual bool ApplyLocalTransportDescription_w(TransportChannelImpl* channel, std::string* error_desc) { rtc::SSLFingerprint* local_fp = @@ -189,8 +195,10 @@ class DtlsTransport : public Base { } virtual DtlsTransportChannelWrapper* CreateTransportChannel(int component) { - return new DtlsTransportChannelWrapper( + DtlsTransportChannelWrapper* channel = new DtlsTransportChannelWrapper( this, Base::CreateTransportChannel(component)); + channel->SetSslMaxProtocolVersion(ssl_max_version_); + return channel; } virtual void DestroyTransportChannel(TransportChannelImpl* channel) { @@ -231,6 +239,7 @@ class DtlsTransport : public Base { rtc::SSLIdentity* identity_; rtc::SSLRole secure_role_; + rtc::SSLProtocolVersion ssl_max_version_; rtc::scoped_ptr remote_fingerprint_; }; diff --git a/webrtc/p2p/base/dtlstransportchannel.cc b/webrtc/p2p/base/dtlstransportchannel.cc index 6a206d420..e73cf688b 100644 --- a/webrtc/p2p/base/dtlstransportchannel.cc +++ b/webrtc/p2p/base/dtlstransportchannel.cc @@ -163,15 +163,16 @@ bool DtlsTransportChannelWrapper::GetLocalIdentity( return true; } -void DtlsTransportChannelWrapper::SetMaxProtocolVersion( +bool DtlsTransportChannelWrapper::SetSslMaxProtocolVersion( rtc::SSLProtocolVersion version) { if (dtls_state_ != STATE_NONE) { LOG(LS_ERROR) << "Not changing max. protocol version " << "while DTLS is negotiating"; - return; + return false; } ssl_max_version_ = version; + return true; } bool DtlsTransportChannelWrapper::SetSslRole(rtc::SSLRole role) { diff --git a/webrtc/p2p/base/dtlstransportchannel.h b/webrtc/p2p/base/dtlstransportchannel.h index a48d09fec..1af5efdd9 100644 --- a/webrtc/p2p/base/dtlstransportchannel.h +++ b/webrtc/p2p/base/dtlstransportchannel.h @@ -129,7 +129,7 @@ class DtlsTransportChannelWrapper : public TransportChannelImpl { return channel_->SessionId(); } - virtual void SetMaxProtocolVersion(rtc::SSLProtocolVersion version); + virtual bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version); // Set up the ciphers to use for DTLS-SRTP. If this method is not called // before DTLS starts, or |ciphers| is empty, SRTP keys won't be negotiated. diff --git a/webrtc/p2p/base/dtlstransportchannel_unittest.cc b/webrtc/p2p/base/dtlstransportchannel_unittest.cc index 670f8f92c..8c1c21cd1 100644 --- a/webrtc/p2p/base/dtlstransportchannel_unittest.cc +++ b/webrtc/p2p/base/dtlstransportchannel_unittest.cc @@ -90,7 +90,7 @@ class DtlsTestClient : public sigslot::has_slots<> { static_cast( transport_->CreateChannel(i)); ASSERT_TRUE(channel != NULL); - channel->SetMaxProtocolVersion(ssl_max_version_); + channel->SetSslMaxProtocolVersion(ssl_max_version_); channel->SignalWritableState.connect(this, &DtlsTestClient::OnTransportChannelWritableState); channel->SignalReadPacket.connect(this, diff --git a/webrtc/p2p/base/session.cc b/webrtc/p2p/base/session.cc index 8cbe128a7..d1f6e54a0 100644 --- a/webrtc/p2p/base/session.cc +++ b/webrtc/p2p/base/session.cc @@ -342,6 +342,7 @@ BaseSession::BaseSession(rtc::Thread* signaling_thread, transport_type_(NS_GINGLE_P2P), initiator_(initiator), identity_(NULL), + ssl_max_version_(rtc::SSL_PROTOCOL_DTLS_10), ice_tiebreaker_(rtc::CreateRandomId64()), role_switch_(false) { ASSERT(signaling_thread->IsCurrent()); @@ -404,6 +405,15 @@ bool BaseSession::SetIdentity(rtc::SSLIdentity* identity) { return true; } +bool BaseSession::SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) { + if (state_ != STATE_INIT) { + return false; + } + + ssl_max_version_ = version; + return true; +} + bool BaseSession::PushdownTransportDescription(ContentSource source, ContentAction action, std::string* error_desc) { @@ -501,6 +511,7 @@ TransportProxy* BaseSession::GetOrCreateTransportProxy( Transport* transport = CreateTransport(content_name); transport->SetIceRole(initiator_ ? ICEROLE_CONTROLLING : ICEROLE_CONTROLLED); transport->SetIceTiebreaker(ice_tiebreaker_); + transport->SetSslMaxProtocolVersion(ssl_max_version_); // TODO: Connect all the Transport signals to TransportProxy // then to the BaseSession. transport->SignalConnecting.connect( diff --git a/webrtc/p2p/base/session.h b/webrtc/p2p/base/session.h index ba4206aee..318a7888d 100644 --- a/webrtc/p2p/base/session.h +++ b/webrtc/p2p/base/session.h @@ -323,6 +323,8 @@ class BaseSession : public sigslot::has_slots<>, // Specifies the identity to use in this session. bool SetIdentity(rtc::SSLIdentity* identity); + bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version); + bool PushdownTransportDescription(ContentSource source, ContentAction action, std::string* error_desc); @@ -445,6 +447,7 @@ class BaseSession : public sigslot::has_slots<>, const std::string transport_type_; bool initiator_; rtc::SSLIdentity* identity_; + rtc::SSLProtocolVersion ssl_max_version_; rtc::scoped_ptr local_description_; rtc::scoped_ptr remote_description_; uint64 ice_tiebreaker_; diff --git a/webrtc/p2p/base/transport.cc b/webrtc/p2p/base/transport.cc index df452fdf3..0184af3da 100644 --- a/webrtc/p2p/base/transport.cc +++ b/webrtc/p2p/base/transport.cc @@ -458,6 +458,11 @@ bool Transport::GetSslRole(rtc::SSLRole* ssl_role) const { &Transport::GetSslRole_w, this, ssl_role)); } +bool Transport::SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) { + return worker_thread_->Invoke(Bind( + &Transport::SetSslMaxProtocolVersion_w, this, version)); +} + void Transport::OnRemoteCandidates(const std::vector& candidates) { for (std::vector::const_iterator iter = candidates.begin(); iter != candidates.end(); diff --git a/webrtc/p2p/base/transport.h b/webrtc/p2p/base/transport.h index c4ab03da8..d7ebc9cbe 100644 --- a/webrtc/p2p/base/transport.h +++ b/webrtc/p2p/base/transport.h @@ -264,6 +264,9 @@ class Transport : public rtc::MessageHandler, virtual bool GetSslRole(rtc::SSLRole* ssl_role) const; + // Must be called before channel is starting to connect. + virtual bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version); + protected: // These are called by Create/DestroyChannel above in order to create or // destroy the appropriate type of channel. @@ -320,6 +323,10 @@ class Transport : public rtc::MessageHandler, return false; } + virtual bool SetSslMaxProtocolVersion_w(rtc::SSLProtocolVersion version) { + return false; + } + private: struct ChannelMapEntry { ChannelMapEntry() : impl_(NULL), candidates_allocated_(false), ref_(0) {}