From bef8d2d0208768d006e639701658ad4c9e73a69f Mon Sep 17 00:00:00 2001 From: Jiayang Liu Date: Thu, 26 Mar 2015 14:38:46 -0700 Subject: [PATCH] Add a lock to NSSContext to fix data race BUG=crbug/466784 R=juberti@webrtc.org, tommi@webrtc.org Review URL: https://webrtc-codereview.appspot.com/44669005 Cr-Commit-Position: refs/heads/master@{#8871} --- webrtc/base/BUILD.gn | 1 + webrtc/base/base.gyp | 1 + webrtc/base/criticalsection.cc | 33 ++++ webrtc/base/criticalsection.h | 25 +++ webrtc/base/criticalsection_unittest.cc | 205 ++++++++++++++++++++---- webrtc/base/nssstreamadapter.cc | 12 +- webrtc/base/nssstreamadapter.h | 5 +- 7 files changed, 241 insertions(+), 41 deletions(-) create mode 100644 webrtc/base/criticalsection.cc diff --git a/webrtc/base/BUILD.gn b/webrtc/base/BUILD.gn index fba03aa9c..47b05be5e 100644 --- a/webrtc/base/BUILD.gn +++ b/webrtc/base/BUILD.gn @@ -188,6 +188,7 @@ static_library("rtc_base") { "cpumonitor.h", "crc32.cc", "crc32.h", + "criticalsection.cc", "criticalsection.h", "cryptstring.cc", "cryptstring.h", diff --git a/webrtc/base/base.gyp b/webrtc/base/base.gyp index 5da97a6b6..ac2352941 100644 --- a/webrtc/base/base.gyp +++ b/webrtc/base/base.gyp @@ -114,6 +114,7 @@ 'cpumonitor.h', 'crc32.cc', 'crc32.h', + 'criticalsection.cc', 'criticalsection.h', 'cryptstring.cc', 'cryptstring.h', diff --git a/webrtc/base/criticalsection.cc b/webrtc/base/criticalsection.cc new file mode 100644 index 000000000..fcad5c31d --- /dev/null +++ b/webrtc/base/criticalsection.cc @@ -0,0 +1,33 @@ +/* + * Copyright 2015 The WebRTC Project Authors. All rights reserved. + * + * Use of this source code is governed by a BSD-style license + * that can be found in the LICENSE file in the root of the source + * tree. An additional intellectual property rights grant can be found + * in the file PATENTS. All contributing project authors may + * be found in the AUTHORS file in the root of the source tree. + */ + +#include "webrtc/base/criticalsection.h" + +#include "webrtc/base/checks.h" +#include "webrtc/base/thread.h" + +namespace rtc { + +void GlobalLockPod::Lock() { + while (AtomicOps::CompareAndSwap(&lock_acquired, 0, 1)) { + Thread::SleepMs(0); + } +} + +void GlobalLockPod::Unlock() { + int old_value = AtomicOps::CompareAndSwap(&lock_acquired, 1, 0); + DCHECK_EQ(1, old_value) << "Unlock called without calling Lock first"; +} + +GlobalLock::GlobalLock() { + lock_acquired = 0; +} + +} // namespace rtc diff --git a/webrtc/base/criticalsection.h b/webrtc/base/criticalsection.h index 8d6ddbe88..db197d2e0 100644 --- a/webrtc/base/criticalsection.h +++ b/webrtc/base/criticalsection.h @@ -169,6 +169,11 @@ class AtomicOps { static void Store(volatile int* i, int value) { *i = value; } + static int CompareAndSwap(volatile int* i, int old_value, int new_value) { + return ::InterlockedCompareExchange(reinterpret_cast(i), + new_value, + old_value); + } #else static int Increment(volatile int* i) { return __sync_add_and_fetch(i, 1); @@ -184,9 +189,29 @@ class AtomicOps { __sync_synchronize(); *i = value; } + static int CompareAndSwap(volatile int* i, int old_value, int new_value) { + return __sync_val_compare_and_swap(i, old_value, new_value); + } #endif }; + +// A POD lock used to protect global variables. Do NOT use for other purposes. +// No custom constructor or private data member should be added. +class LOCKABLE GlobalLockPod { + public: + void Lock() EXCLUSIVE_LOCK_FUNCTION(); + + void Unlock() UNLOCK_FUNCTION(); + + volatile int lock_acquired; +}; + +class GlobalLock : public GlobalLockPod { + public: + GlobalLock(); +}; + } // namespace rtc #endif // WEBRTC_BASE_CRITICALSECTION_H__ diff --git a/webrtc/base/criticalsection_unittest.cc b/webrtc/base/criticalsection_unittest.cc index 8c36e3d50..6f3c7e931 100644 --- a/webrtc/base/criticalsection_unittest.cc +++ b/webrtc/base/criticalsection_unittest.cc @@ -26,16 +26,54 @@ const int kLongTime = 10000; // 10 seconds const int kNumThreads = 16; const int kOperationsToRun = 1000; -template -class AtomicOpRunner : public MessageHandler { +class UniqueValueVerifier { public: - explicit AtomicOpRunner(int initial_value) - : value_(initial_value), - threads_active_(0), - start_event_(true, false), - done_event_(true, false) {} + void Verify(const std::vector& values) { + for (size_t i = 0; i < values.size(); ++i) { + std::pair::iterator, bool> result = + all_values_.insert(values[i]); + // Each value should only be taken by one thread, so if this value + // has already been added, something went wrong. + EXPECT_TRUE(result.second) + << " Thread=" << Thread::Current() << " value=" << values[i]; + } + } - int value() const { return value_; } + void Finalize() {} + + private: + std::set all_values_; +}; + +class CompareAndSwapVerifier { + public: + CompareAndSwapVerifier() : zero_count_(0) {} + + void Verify(const std::vector& values) { + for (auto v : values) { + if (v == 0) { + EXPECT_EQ(0, zero_count_) << "Thread=" << Thread::Current(); + ++zero_count_; + } else { + EXPECT_EQ(1, v) << " Thread=" << Thread::Current(); + } + } + } + + void Finalize() { + EXPECT_EQ(1, zero_count_); + } + private: + int zero_count_; +}; + +class RunnerBase : public MessageHandler { + public: + explicit RunnerBase(int value) + : threads_active_(0), + start_event_(true, false), + done_event_(true, false), + shared_value_(value) {} bool Run() { // Signal all threads to start. @@ -49,43 +87,101 @@ class AtomicOpRunner : public MessageHandler { threads_active_ = count; } - virtual void OnMessage(Message* msg) { + int shared_value() const { return shared_value_; } + + protected: + // Derived classes must override OnMessage, and call BeforeStart and AfterEnd + // at the beginning and the end of OnMessage respectively. + void BeforeStart() { + ASSERT_TRUE(start_event_.Wait(kLongTime)); + } + + // Returns true if all threads have finished. + bool AfterEnd() { + if (AtomicOps::Decrement(&threads_active_) == 0) { + done_event_.Set(); + return true; + } + return false; + } + + int threads_active_; + Event start_event_; + Event done_event_; + int shared_value_; +}; + +class LOCKABLE CriticalSectionLock { + public: + void Lock() EXCLUSIVE_LOCK_FUNCTION() { + cs_.Enter(); + } + void Unlock() UNLOCK_FUNCTION() { + cs_.Leave(); + } + + private: + CriticalSection cs_; +}; + +template +class LockRunner : public RunnerBase { + public: + LockRunner() : RunnerBase(0) {} + + void OnMessage(Message* msg) override { + BeforeStart(); + + lock_.Lock(); + + EXPECT_EQ(0, shared_value_); + int old = shared_value_; + + // Use a loop to increase the chance of race. + for (int i = 0; i < kOperationsToRun; ++i) { + ++shared_value_; + } + EXPECT_EQ(old + kOperationsToRun, shared_value_); + shared_value_ = 0; + + lock_.Unlock(); + + AfterEnd(); + } + + private: + Lock lock_; +}; + +template +class AtomicOpRunner : public RunnerBase { + public: + explicit AtomicOpRunner(int initial_value) : RunnerBase(initial_value) {} + + void OnMessage(Message* msg) override { + BeforeStart(); + std::vector values; values.reserve(kOperationsToRun); - // Wait to start. - ASSERT_TRUE(start_event_.Wait(kLongTime)); - - // Generate a bunch of values by updating value_ atomically. + // Generate a bunch of values by updating shared_value_ atomically. for (int i = 0; i < kOperationsToRun; ++i) { - values.push_back(T::AtomicOp(&value_)); + values.push_back(Op::AtomicOp(&shared_value_)); } { // Add them all to the set. CritScope cs(&all_values_crit_); - for (size_t i = 0; i < values.size(); ++i) { - std::pair::iterator, bool> result = - all_values_.insert(values[i]); - // Each value should only be taken by one thread, so if this value - // has already been added, something went wrong. - EXPECT_TRUE(result.second) - << "Thread=" << Thread::Current() << " value=" << values[i]; - } + verifier_.Verify(values); } - // Signal that we're done. - if (AtomicOps::Decrement(&threads_active_) == 0) { - done_event_.Set(); + if (AfterEnd()) { + verifier_.Finalize(); } } private: - int value_; - int threads_active_; CriticalSection all_values_crit_; - std::set all_values_; - Event start_event_; - Event done_event_; + Verifier verifier_; }; struct IncrementOp { @@ -96,6 +192,10 @@ struct DecrementOp { static int AtomicOp(int* i) { return AtomicOps::Decrement(i); } }; +struct CompareAndSwapOp { + static int AtomicOp(int* i) { return AtomicOps::CompareAndSwap(i, 0, 1); } +}; + void StartThreads(ScopedPtrCollection* threads, MessageHandler* handler) { for (int i = 0; i < kNumThreads; ++i) { @@ -122,26 +222,63 @@ TEST(AtomicOpsTest, Simple) { TEST(AtomicOpsTest, Increment) { // Create and start lots of threads. - AtomicOpRunner runner(0); + AtomicOpRunner runner(0); ScopedPtrCollection threads; StartThreads(&threads, &runner); runner.SetExpectedThreadCount(kNumThreads); // Release the hounds! EXPECT_TRUE(runner.Run()); - EXPECT_EQ(kOperationsToRun * kNumThreads, runner.value()); + EXPECT_EQ(kOperationsToRun * kNumThreads, runner.shared_value()); } TEST(AtomicOpsTest, Decrement) { // Create and start lots of threads. - AtomicOpRunner runner(kOperationsToRun * kNumThreads); + AtomicOpRunner runner( + kOperationsToRun * kNumThreads); ScopedPtrCollection threads; StartThreads(&threads, &runner); runner.SetExpectedThreadCount(kNumThreads); // Release the hounds! EXPECT_TRUE(runner.Run()); - EXPECT_EQ(0, runner.value()); + EXPECT_EQ(0, runner.shared_value()); +} + +TEST(AtomicOpsTest, CompareAndSwap) { + // Create and start lots of threads. + AtomicOpRunner runner(0); + ScopedPtrCollection threads; + StartThreads(&threads, &runner); + runner.SetExpectedThreadCount(kNumThreads); + + // Release the hounds! + EXPECT_TRUE(runner.Run()); + EXPECT_EQ(1, runner.shared_value()); +} + +TEST(GlobalLockTest, Basic) { + // Create and start lots of threads. + LockRunner runner; + ScopedPtrCollection threads; + StartThreads(&threads, &runner); + runner.SetExpectedThreadCount(kNumThreads); + + // Release the hounds! + EXPECT_TRUE(runner.Run()); + EXPECT_EQ(0, runner.shared_value()); +} + +TEST(CriticalSectionTest, Basic) { + // Create and start lots of threads. + LockRunner runner; + ScopedPtrCollection threads; + StartThreads(&threads, &runner); + runner.SetExpectedThreadCount(kNumThreads); + + // Release the hounds! + EXPECT_TRUE(runner.Run()); + EXPECT_EQ(0, runner.shared_value()); } } // namespace rtc diff --git a/webrtc/base/nssstreamadapter.cc b/webrtc/base/nssstreamadapter.cc index 044d00ba3..fe1692cc3 100644 --- a/webrtc/base/nssstreamadapter.cc +++ b/webrtc/base/nssstreamadapter.cc @@ -978,25 +978,27 @@ bool NSSStreamAdapter::GetDtlsSrtpCipher(std::string* cipher) { } -bool NSSContext::initialized; +GlobalLockPod NSSContext::lock; NSSContext *NSSContext::global_nss_context; // Static initialization and shutdown NSSContext *NSSContext::Instance() { + lock.Lock(); if (!global_nss_context) { - scoped_ptr new_ctx(new NSSContext()); - new_ctx->slot_ = PK11_GetInternalSlot(); + scoped_ptr new_ctx(new NSSContext(PK11_GetInternalSlot())); if (new_ctx->slot_) global_nss_context = new_ctx.release(); } + lock.Unlock(); + return global_nss_context; } - - bool NSSContext::InitializeSSL(VerificationCallback callback) { ASSERT(!callback); + static bool initialized = false; + if (!initialized) { SECStatus rv; diff --git a/webrtc/base/nssstreamadapter.h b/webrtc/base/nssstreamadapter.h index 8b588852a..fcacb9539 100644 --- a/webrtc/base/nssstreamadapter.h +++ b/webrtc/base/nssstreamadapter.h @@ -19,6 +19,7 @@ #include "secmodt.h" #include "webrtc/base/buffer.h" +#include "webrtc/base/criticalsection.h" #include "webrtc/base/nssidentity.h" #include "webrtc/base/ssladapter.h" #include "webrtc/base/sslstreamadapter.h" @@ -29,7 +30,7 @@ namespace rtc { // Singleton class NSSContext { public: - NSSContext() {} + explicit NSSContext(PK11SlotInfo* slot) : slot_(slot) {} ~NSSContext() { } @@ -44,7 +45,7 @@ class NSSContext { private: PK11SlotInfo *slot_; // The PKCS-11 slot - static bool initialized; // Was this initialized? + static GlobalLockPod lock; // To protect the global context static NSSContext *global_nss_context; // The global context };