diff --git a/talk/app/webrtc/dtlsidentitystore.cc b/talk/app/webrtc/dtlsidentitystore.cc index 7beb68d13..82285da86 100644 --- a/talk/app/webrtc/dtlsidentitystore.cc +++ b/talk/app/webrtc/dtlsidentitystore.cc @@ -44,8 +44,50 @@ enum { }; typedef rtc::ScopedMessageData IdentityResultMessageData; + } // namespace +// This class runs on the worker thread to generate the identity. It's necessary +// to separate this class from DtlsIdentityStore so that it can live on the +// worker thread after DtlsIdentityStore is destroyed. +class DtlsIdentityStore::WorkerTask : public sigslot::has_slots<>, + public rtc::MessageHandler { + public: + explicit WorkerTask(DtlsIdentityStore* store) : store_(store) { + store_->SignalDestroyed.connect(this, &WorkerTask::OnStoreDestroyed); + }; + + virtual ~WorkerTask() {} + + void GenerateIdentity() { + rtc::scoped_ptr identity( + rtc::SSLIdentity::Generate(DtlsIdentityStore::kIdentityName)); + + { + rtc::CritScope cs(&cs_); + if (store_) { + store_->PostGenerateIdentityResult_w(identity.Pass()); + } + } + } + + void OnMessage(rtc::Message* msg) override { + DCHECK(msg->message_id == MSG_GENERATE_IDENTITY); + GenerateIdentity(); + // Deleting msg->pdata will destroy the WorkerTask. + delete msg->pdata; + } + + private: + void OnStoreDestroyed() { + rtc::CritScope cs(&cs_); + store_ = NULL; + } + + rtc::CriticalSection cs_; + DtlsIdentityStore* store_; +}; + // Arbitrary constant used as common name for the identity. // Chosen to make the certificates more readable. const char DtlsIdentityStore::kIdentityName[] = "WebRTC"; @@ -56,10 +98,16 @@ DtlsIdentityStore::DtlsIdentityStore(rtc::Thread* signaling_thread, worker_thread_(worker_thread), pending_jobs_(0) {} -DtlsIdentityStore::~DtlsIdentityStore() {} +DtlsIdentityStore::~DtlsIdentityStore() { + SignalDestroyed(); +} void DtlsIdentityStore::Initialize() { - GenerateIdentity(); + // Do not aggressively generate the free identity if the worker thread and the + // signaling thread are the same. + if (worker_thread_ != signaling_thread_) { + GenerateIdentity(); + } } void DtlsIdentityStore::RequestIdentity(DTLSIdentityRequestObserver* observer) { @@ -79,9 +127,6 @@ void DtlsIdentityStore::RequestIdentity(DTLSIdentityRequestObserver* observer) { void DtlsIdentityStore::OnMessage(rtc::Message* msg) { switch (msg->message_id) { - case MSG_GENERATE_IDENTITY: - GenerateIdentity_w(); - break; case MSG_GENERATE_IDENTITY_RESULT: { rtc::scoped_ptr pdata( static_cast(msg->pdata)); @@ -105,7 +150,12 @@ void DtlsIdentityStore::GenerateIdentity() { pending_jobs_++; LOG(LS_VERBOSE) << "New DTLS identity generation is posted, " << "pending_identities=" << pending_jobs_; - worker_thread_->Post(this, MSG_GENERATE_IDENTITY, NULL); + + WorkerTask* task = new WorkerTask(this); + // The WorkerTask is owned by the message data to make sure it will not be + // leaked even if the task does not get run. + IdentityTaskMessageData* msg = new IdentityTaskMessageData(task); + worker_thread_->Post(task, MSG_GENERATE_IDENTITY, msg); } void DtlsIdentityStore::OnIdentityGenerated( @@ -143,19 +193,22 @@ void DtlsIdentityStore::ReturnIdentity( LOG(LS_WARNING) << "Failed to generate SSL identity"; } - if (pending_observers_.empty() && pending_jobs_ == 0) { - // Generate a free identity in the background. - GenerateIdentity(); + // Do not aggressively generate the free identity if the worker thread and the + // signaling thread are the same. + if (worker_thread_ != signaling_thread_ && + pending_observers_.empty() && + pending_jobs_ == 0) { + // Generate a free identity in the background. + GenerateIdentity(); } } -void DtlsIdentityStore::GenerateIdentity_w() { +void DtlsIdentityStore::PostGenerateIdentityResult_w( + rtc::scoped_ptr identity) { DCHECK(rtc::Thread::Current() == worker_thread_); - rtc::SSLIdentity* identity = rtc::SSLIdentity::Generate(kIdentityName); - - IdentityResultMessageData* msg = new IdentityResultMessageData(identity); + IdentityResultMessageData* msg = + new IdentityResultMessageData(identity.release()); signaling_thread_->Post(this, MSG_GENERATE_IDENTITY_RESULT, msg); } - } // namespace webrtc diff --git a/talk/app/webrtc/dtlsidentitystore.h b/talk/app/webrtc/dtlsidentitystore.h index 2d52249bd..ffe8067aa 100644 --- a/talk/app/webrtc/dtlsidentitystore.h +++ b/talk/app/webrtc/dtlsidentitystore.h @@ -33,6 +33,7 @@ #include "talk/app/webrtc/peerconnectioninterface.h" #include "webrtc/base/messagehandler.h" +#include "webrtc/base/messagequeue.h" #include "webrtc/base/scoped_ptr.h" #include "webrtc/base/scoped_ref_ptr.h" @@ -51,7 +52,7 @@ class DtlsIdentityStore : public rtc::MessageHandler { DtlsIdentityStore(rtc::Thread* signaling_thread, rtc::Thread* worker_thread); - ~DtlsIdentityStore(); + virtual ~DtlsIdentityStore(); // Initialize will start generating the free identity in the background. void Initialize(); @@ -67,11 +68,16 @@ class DtlsIdentityStore : public rtc::MessageHandler { bool HasFreeIdentityForTesting() const; private: + sigslot::signal0 SignalDestroyed; + class WorkerTask; + typedef rtc::ScopedMessageData + IdentityTaskMessageData; + void GenerateIdentity(); void OnIdentityGenerated(rtc::scoped_ptr identity); void ReturnIdentity(rtc::scoped_ptr identity); - void GenerateIdentity_w(); + void PostGenerateIdentityResult_w(rtc::scoped_ptr identity); rtc::Thread* signaling_thread_; rtc::Thread* worker_thread_; diff --git a/talk/app/webrtc/dtlsidentitystore_unittest.cc b/talk/app/webrtc/dtlsidentitystore_unittest.cc index 67b7cb22a..c0b204a85 100644 --- a/talk/app/webrtc/dtlsidentitystore_unittest.cc +++ b/talk/app/webrtc/dtlsidentitystore_unittest.cc @@ -80,10 +80,12 @@ class MockDtlsIdentityRequestObserver : class DtlsIdentityStoreTest : public testing::Test { protected: DtlsIdentityStoreTest() - : store_(new DtlsIdentityStore(rtc::Thread::Current(), - rtc::Thread::Current())), + : worker_thread_(new rtc::Thread()), + store_(new DtlsIdentityStore(rtc::Thread::Current(), + worker_thread_.get())), observer_( new rtc::RefCountedObject()) { + CHECK(worker_thread_->Start()); store_->Initialize(); } ~DtlsIdentityStoreTest() {} @@ -95,6 +97,7 @@ class DtlsIdentityStoreTest : public testing::Test { rtc::CleanupSSL(); } + rtc::scoped_ptr worker_thread_; rtc::scoped_ptr store_; rtc::scoped_refptr observer_; }; @@ -106,4 +109,22 @@ TEST_F(DtlsIdentityStoreTest, RequestIdentitySuccess) { EXPECT_TRUE_WAIT(observer_->LastRequestSucceeded(), kTimeoutMs); EXPECT_TRUE_WAIT(store_->HasFreeIdentityForTesting(), kTimeoutMs); + + observer_->Reset(); + + // Verifies that the callback is async when a free identity is ready. + store_->RequestIdentity(observer_.get()); + EXPECT_FALSE(observer_->call_back_called()); + EXPECT_TRUE_WAIT(observer_->LastRequestSucceeded(), kTimeoutMs); } + +TEST_F(DtlsIdentityStoreTest, DeleteStoreEarlyNoCrash) { + EXPECT_FALSE(store_->HasFreeIdentityForTesting()); + + store_->RequestIdentity(observer_.get()); + store_.reset(); + + worker_thread_->Stop(); + EXPECT_FALSE(observer_->call_back_called()); +} + diff --git a/talk/app/webrtc/peerconnectionfactory.cc b/talk/app/webrtc/peerconnectionfactory.cc index 9f1e8588e..1c933764a 100644 --- a/talk/app/webrtc/peerconnectionfactory.cc +++ b/talk/app/webrtc/peerconnectionfactory.cc @@ -132,6 +132,11 @@ PeerConnectionFactory::~PeerConnectionFactory() { DCHECK(signaling_thread_->IsCurrent()); channel_manager_.reset(NULL); default_allocator_factory_ = NULL; + + // Make sure |worker_thread_| and |signaling_thread_| outlive + // |dtls_identity_store_|. + dtls_identity_store_.reset(NULL); + if (owns_ptrs_) { if (wraps_current_thread_) rtc::ThreadManager::Instance()->UnwrapCurrentThread(); diff --git a/talk/app/webrtc/peerconnectionfactory_unittest.cc b/talk/app/webrtc/peerconnectionfactory_unittest.cc index 2860ca70f..67a20331d 100644 --- a/talk/app/webrtc/peerconnectionfactory_unittest.cc +++ b/talk/app/webrtc/peerconnectionfactory_unittest.cc @@ -30,6 +30,7 @@ #include "talk/app/webrtc/fakeportallocatorfactory.h" #include "talk/app/webrtc/mediastreaminterface.h" #include "talk/app/webrtc/peerconnectionfactory.h" +#include "talk/app/webrtc/test/fakedtlsidentityservice.h" #include "talk/app/webrtc/test/fakevideotrackrenderer.h" #include "talk/app/webrtc/videosourceinterface.h" #include "talk/media/base/fakevideocapturer.h" @@ -157,7 +158,8 @@ TEST(PeerConnectionFactoryTestInternal, CreatePCUsingInternalModules) { webrtc::PeerConnectionInterface::IceServers servers; rtc::scoped_refptr pc( - factory->CreatePeerConnection(servers, NULL, NULL, NULL, &observer)); + factory->CreatePeerConnection( + servers, NULL, NULL, new FakeIdentityService(), &observer)); EXPECT_TRUE(pc.get() != NULL); } @@ -178,7 +180,7 @@ TEST_F(PeerConnectionFactoryTest, CreatePCUsingIceServers) { rtc::scoped_refptr pc( factory_->CreatePeerConnection(config, NULL, allocator_factory_.get(), - NULL, + new FakeIdentityService(), &observer_)); EXPECT_TRUE(pc.get() != NULL); StunConfigurations stun_configs; @@ -214,7 +216,7 @@ TEST_F(PeerConnectionFactoryTest, CreatePCUsingIceServersOldSignature) { rtc::scoped_refptr pc( factory_->CreatePeerConnection(ice_servers, NULL, allocator_factory_.get(), - NULL, + new FakeIdentityService(), &observer_)); EXPECT_TRUE(pc.get() != NULL); StunConfigurations stun_configs; @@ -244,7 +246,7 @@ TEST_F(PeerConnectionFactoryTest, CreatePCUsingNoUsernameInUri) { rtc::scoped_refptr pc( factory_->CreatePeerConnection(config, NULL, allocator_factory_.get(), - NULL, + new FakeIdentityService(), &observer_)); EXPECT_TRUE(pc.get() != NULL); TurnConfigurations turn_configs; @@ -265,7 +267,7 @@ TEST_F(PeerConnectionFactoryTest, CreatePCUsingTurnUrlWithTransportParam) { rtc::scoped_refptr pc( factory_->CreatePeerConnection(config, NULL, allocator_factory_.get(), - NULL, + new FakeIdentityService(), &observer_)); EXPECT_TRUE(pc.get() != NULL); TurnConfigurations turn_configs; @@ -290,7 +292,7 @@ TEST_F(PeerConnectionFactoryTest, CreatePCUsingSecureTurnUrl) { rtc::scoped_refptr pc( factory_->CreatePeerConnection(config, NULL, allocator_factory_.get(), - NULL, + new FakeIdentityService(), &observer_)); EXPECT_TRUE(pc.get() != NULL); TurnConfigurations turn_configs; @@ -327,7 +329,7 @@ TEST_F(PeerConnectionFactoryTest, CreatePCUsingIPLiteralAddress) { rtc::scoped_refptr pc( factory_->CreatePeerConnection(config, NULL, allocator_factory_.get(), - NULL, + new FakeIdentityService(), &observer_)); EXPECT_TRUE(pc.get() != NULL); StunConfigurations stun_configs; @@ -377,3 +379,4 @@ TEST_F(PeerConnectionFactoryTest, LocalRendering) { EXPECT_TRUE(capturer->CaptureFrame()); EXPECT_EQ(2, local_renderer.num_rendered_frames()); } +