diff --git a/webrtc/base/asyncinvoker.cc b/webrtc/base/asyncinvoker.cc index 35ce2fb12..56f1a19d5 100644 --- a/webrtc/base/asyncinvoker.cc +++ b/webrtc/base/asyncinvoker.cc @@ -64,6 +64,18 @@ void AsyncInvoker::DoInvoke(Thread* thread, thread->Post(this, id, new ScopedRefMessageData(closure)); } +void AsyncInvoker::DoInvokeDelayed(Thread* thread, + const scoped_refptr& closure, + uint32 delay_ms, + uint32 id) { + if (destroying_) { + LOG(LS_WARNING) << "Tried to invoke while destroying the invoker."; + return; + } + thread->PostDelayed(delay_ms, this, id, + new ScopedRefMessageData(closure)); +} + NotifyingAsyncClosureBase::NotifyingAsyncClosureBase(AsyncInvoker* invoker, Thread* calling_thread) : invoker_(invoker), calling_thread_(calling_thread) { diff --git a/webrtc/base/asyncinvoker.h b/webrtc/base/asyncinvoker.h index 6e298e60c..fe1506d99 100644 --- a/webrtc/base/asyncinvoker.h +++ b/webrtc/base/asyncinvoker.h @@ -82,6 +82,18 @@ class AsyncInvoker : public MessageHandler { DoInvoke(thread, closure, id); } + // Call |functor| asynchronously on |thread| with |delay_ms|, with no callback + // upon completion. Returns immediately. + template + void AsyncInvokeDelayed(Thread* thread, + const FunctorT& functor, + uint32 delay_ms, + uint32 id = 0) { + scoped_refptr closure( + new RefCountedObject >(functor)); + DoInvokeDelayed(thread, closure, delay_ms, id); + } + // Call |functor| asynchronously on |thread|, calling |callback| when done. template void AsyncInvoke(Thread* thread, @@ -123,7 +135,10 @@ class AsyncInvoker : public MessageHandler { void OnMessage(Message* msg) override; void DoInvoke(Thread* thread, const scoped_refptr& closure, uint32 id); - + void DoInvokeDelayed(Thread* thread, + const scoped_refptr& closure, + uint32 delay_ms, + uint32 id); bool destroying_; DISALLOW_COPY_AND_ASSIGN(AsyncInvoker); diff --git a/webrtc/base/base.gyp b/webrtc/base/base.gyp index cc1ebf43d..66571f804 100644 --- a/webrtc/base/base.gyp +++ b/webrtc/base/base.gyp @@ -375,9 +375,6 @@ '../../boringssl/src/include', ], 'sources!': [ - 'asyncinvoker.cc', - 'asyncinvoker.h', - 'asyncinvoker-inl.h', 'atomicops.h', 'bandwidthsmoother.cc', 'bandwidthsmoother.h', diff --git a/webrtc/p2p/stunprober/main.cc b/webrtc/p2p/stunprober/main.cc index 99f66caa9..076113ce6 100644 --- a/webrtc/p2p/stunprober/main.cc +++ b/webrtc/p2p/stunprober/main.cc @@ -14,29 +14,22 @@ #include #include - #include "webrtc/base/checks.h" #include "webrtc/base/flags.h" #include "webrtc/base/helpers.h" #include "webrtc/base/nethelpers.h" +#include "webrtc/base/network.h" #include "webrtc/base/logging.h" #include "webrtc/base/scoped_ptr.h" #include "webrtc/base/ssladapter.h" #include "webrtc/base/stringutils.h" #include "webrtc/base/thread.h" #include "webrtc/base/timeutils.h" +#include "webrtc/p2p/base/basicpacketsocketfactory.cc" #include "webrtc/p2p/stunprober/stunprober.h" -#include "webrtc/p2p/stunprober/stunprober_dependencies.h" -using stunprober::HostNameResolverInterface; -using stunprober::TaskRunner; -using stunprober::SocketFactory; using stunprober::StunProber; using stunprober::AsyncCallback; -using stunprober::ClientSocketInterface; -using stunprober::ServerSocketInterface; -using stunprober::SocketFactory; -using stunprober::TaskRunner; DEFINE_bool(help, false, "Prints this message"); DEFINE_int(interval, 10, "Interval of consecutive stun pings in milliseconds"); @@ -54,58 +47,6 @@ DEFINE_string( namespace { -class HostNameResolver : public HostNameResolverInterface, - public sigslot::has_slots<> { - public: - HostNameResolver() {} - virtual ~HostNameResolver() {} - - void Resolve(const rtc::SocketAddress& addr, - std::vector* addresses, - AsyncCallback callback) override { - resolver_ = new rtc::AsyncResolver(); - DCHECK(callback_.empty()); - addr_ = addr; - callback_ = callback; - result_ = addresses; - resolver_->SignalDone.connect(this, &HostNameResolver::OnResolveResult); - resolver_->Start(addr); - } - - void OnResolveResult(rtc::AsyncResolverInterface* resolver) { - DCHECK(resolver); - int rv = resolver_->GetError(); - LOG(LS_INFO) << "ResolveResult for " << addr_.ToString() << " : " << rv; - if (rv == 0 && result_) { - for (auto addr : resolver_->addresses()) { - rtc::SocketAddress ip(addr, addr_.port()); - result_->push_back(ip); - LOG(LS_INFO) << "\t" << ip.ToString(); - } - } - if (!callback_.empty()) { - // Need to be the last statement as the object could be deleted by the - // callback_ in the failure case. - AsyncCallback callback = callback_; - callback_ = AsyncCallback(); - - // rtc::AsyncResolver inherits from SignalThread which requires explicit - // Release(). - resolver_->Release(); - resolver_ = nullptr; - callback(rv); - } - } - - private: - AsyncCallback callback_; - rtc::SocketAddress addr_; - std::vector* result_; - - // Not using smart ptr here as this requires specific release pattern. - rtc::AsyncResolver* resolver_; -}; - const char* PrintNatType(stunprober::NatType type) { switch (type) { case stunprober::NATTYPE_NONE: @@ -178,10 +119,17 @@ int main(int argc, char** argv) { rtc::InitializeSSL(); rtc::InitRandom(rtc::Time()); rtc::Thread* thread = rtc::ThreadManager::Instance()->WrapCurrentThread(); - StunProber* prober = new StunProber(new HostNameResolver(), - new SocketFactory(), new TaskRunner()); - auto finish_callback = - [thread, prober](int result) { StopTrial(thread, prober, result); }; + rtc::scoped_ptr socket_factory( + new rtc::BasicPacketSocketFactory()); + rtc::scoped_ptr network_manager( + new rtc::BasicNetworkManager()); + rtc::NetworkManager::NetworkList networks; + network_manager->GetNetworks(&networks); + StunProber* prober = + new StunProber(socket_factory.get(), rtc::Thread::Current(), networks); + auto finish_callback = [thread](StunProber* prober, int result) { + StopTrial(thread, prober, result); + }; prober->Start(server_addresses, FLAG_shared_socket, FLAG_interval, FLAG_pings_per_ip, FLAG_timeout, AsyncCallback(finish_callback)); diff --git a/webrtc/p2p/stunprober/stunprober.cc b/webrtc/p2p/stunprober/stunprober.cc index 8efe069b0..4a33444cc 100644 --- a/webrtc/p2p/stunprober/stunprober.cc +++ b/webrtc/p2p/stunprober/stunprober.cc @@ -13,10 +13,15 @@ #include #include +#include "webrtc/base/asyncpacketsocket.h" +#include "webrtc/base/asyncresolverinterface.h" #include "webrtc/base/bind.h" #include "webrtc/base/checks.h" #include "webrtc/base/helpers.h" +#include "webrtc/base/logging.h" #include "webrtc/base/timeutils.h" +#include "webrtc/base/thread.h" +#include "webrtc/p2p/base/packetsocketfactory.h" #include "webrtc/p2p/base/stun.h" #include "webrtc/p2p/stunprober/stunprober.h" @@ -24,20 +29,18 @@ namespace stunprober { namespace { +const int thread_wake_up_interval_ms = 5; + template void IncrementCounterByAddress(std::map* counter_per_ip, const T& ip) { counter_per_ip->insert(std::make_pair(ip, 0)).first->second++; } -bool behind_nat(NatType nat_type) { - return nat_type > stunprober::NATTYPE_NONE; -} - } // namespace // A requester tracks the requests and responses from a single socket to many // STUN servers -class StunProber::Requester { +class StunProber::Requester : public sigslot::has_slots<> { public: // Each Request maps to a request and response. struct Request { @@ -46,26 +49,20 @@ class StunProber::Requester { // Time the response was received. int64 received_time_ms = 0; - // See whether the observed address returned matches the - // local address as in StunProber.local_addr_. - bool behind_nat = false; - // Server reflexive address from STUN response for this given request. rtc::SocketAddress srflx_addr; rtc::IPAddress server_addr; int64 rtt() { return received_time_ms - sent_time_ms; } - void ProcessResponse(rtc::ByteBuffer* message, - int buf_len, - const rtc::IPAddress& local_addr); + void ProcessResponse(const char* buf, size_t buf_len); }; // StunProber provides |server_ips| for Requester to probe. For shared // socket mode, it'll be all the resolved IP addresses. For non-shared mode, // it'll just be a single address. Requester(StunProber* prober, - ServerSocketInterface* socket, + rtc::AsyncPacketSocket* socket, const std::vector& server_ips); virtual ~Requester(); @@ -74,11 +71,11 @@ class StunProber::Requester { // and move to the next one. void SendStunRequest(); - void ReadStunResponse(); - - // |result| is the positive return value from RecvFrom when data is - // available. - void OnStunResponseReceived(int result); + void OnStunResponseReceived(rtc::AsyncPacketSocket* socket, + const char* buf, + size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketTime& time); const std::vector& requests() { return requests_; } @@ -93,7 +90,7 @@ class StunProber::Requester { StunProber* prober_; // The socket for this session. - rtc::scoped_ptr socket_; + rtc::scoped_ptr socket_; // Temporary SocketAddress and buffer for RecvFrom. rtc::SocketAddress addr_; @@ -111,13 +108,15 @@ class StunProber::Requester { StunProber::Requester::Requester( StunProber* prober, - ServerSocketInterface* socket, + rtc::AsyncPacketSocket* socket, const std::vector& server_ips) : prober_(prober), socket_(socket), response_packet_(new rtc::ByteBuffer(nullptr, kMaxUdpBufferSize)), server_ips_(server_ips), thread_checker_(prober->thread_checker_) { + socket_->SignalReadPacket.connect( + this, &StunProber::Requester::OnStunResponseReceived); } StunProber::Requester::~Requester() { @@ -145,7 +144,7 @@ void StunProber::Requester::SendStunRequest() { rtc::scoped_ptr request_packet( new rtc::ByteBuffer(nullptr, kMaxUdpBufferSize)); if (!message.Write(request_packet.get())) { - prober_->End(WRITE_FAILED, 0); + prober_->End(WRITE_FAILED); return; } @@ -155,48 +154,26 @@ void StunProber::Requester::SendStunRequest() { // The write must succeed immediately. Otherwise, the calculating of the STUN // request timing could become too complicated. Callback is ignored by passing // empty AsyncCallback. - int rv = socket_->SendTo(addr, const_cast(request_packet->Data()), - request_packet->Length(), AsyncCallback()); + rtc::PacketOptions options; + int rv = socket_->SendTo(const_cast(request_packet->Data()), + request_packet->Length(), addr, options); if (rv < 0) { - prober_->End(WRITE_FAILED, rv); + prober_->End(WRITE_FAILED); return; } request.sent_time_ms = rtc::Time(); - // Post a read waiting for response. For share mode, the subsequent read will - // be posted inside OnStunResponseReceived. - if (num_request_sent_ == 0) { - ReadStunResponse(); - } - num_request_sent_++; DCHECK(static_cast(num_request_sent_) <= server_ips_.size()); } -void StunProber::Requester::ReadStunResponse() { - DCHECK(thread_checker_.CalledOnValidThread()); - if (!socket_) { - return; - } - - int rv = socket_->RecvFrom( - response_packet_->ReserveWriteBuffer(kMaxUdpBufferSize), - kMaxUdpBufferSize, &addr_, - [this](int result) { this->OnStunResponseReceived(result); }); - if (rv != SocketInterface::IO_PENDING) { - OnStunResponseReceived(rv); - } -} - -void StunProber::Requester::Request::ProcessResponse( - rtc::ByteBuffer* message, - int buf_len, - const rtc::IPAddress& local_addr) { +void StunProber::Requester::Request::ProcessResponse(const char* buf, + size_t buf_len) { int64 now = rtc::Time(); - + rtc::ByteBuffer message(buf, buf_len); cricket::StunMessage stun_response; - if (!stun_response.Read(message)) { + if (!stun_response.Read(&message)) { // Invalid or incomplete STUN packet. received_time_ms = 0; return; @@ -218,40 +195,25 @@ void StunProber::Requester::Request::ProcessResponse( received_time_ms = now; srflx_addr = addr_attr->GetAddress(); - - // Calculate behind_nat. - behind_nat = (srflx_addr.ipaddr() != local_addr); } -void StunProber::Requester::OnStunResponseReceived(int result) { +void StunProber::Requester::OnStunResponseReceived( + rtc::AsyncPacketSocket* socket, + const char* buf, + size_t size, + const rtc::SocketAddress& addr, + const rtc::PacketTime& time) { DCHECK(thread_checker_.CalledOnValidThread()); DCHECK(socket_); - - if (result < 0) { - // Something is wrong, finish the test. - prober_->End(READ_FAILED, result); - return; - } - - Request* request = GetRequestByAddress(addr_.ipaddr()); + Request* request = GetRequestByAddress(addr.ipaddr()); if (!request) { // Something is wrong, finish the test. - prober_->End(GENERIC_FAILURE, result); + prober_->End(GENERIC_FAILURE); return; } num_response_received_++; - - // Resize will set the end_ to indicate that there are data available in this - // ByteBuffer. - response_packet_->Resize(result); - request->ProcessResponse(response_packet_.get(), result, - prober_->local_addr_); - - if (static_cast(num_response_received_) < server_ips_.size()) { - // Post another read response. - ReadStunResponse(); - } + request->ProcessResponse(buf, size); } StunProber::Requester::Request* StunProber::Requester::GetRequestByAddress( @@ -266,13 +228,13 @@ StunProber::Requester::Request* StunProber::Requester::GetRequestByAddress( return nullptr; } -StunProber::StunProber(HostNameResolverInterface* host_name_resolver, - SocketFactoryInterface* socket_factory, - TaskRunnerInterface* task_runner) +StunProber::StunProber(rtc::PacketSocketFactory* socket_factory, + rtc::Thread* thread, + const rtc::NetworkManager::NetworkList& networks) : interval_ms_(0), socket_factory_(socket_factory), - resolver_(host_name_resolver), - task_runner_(task_runner) { + thread_(thread), + networks_(networks) { } StunProber::~StunProber() { @@ -281,6 +243,11 @@ StunProber::~StunProber() { delete req; } } + for (auto s : sockets_) { + if (s) { + delete s; + } + } } bool StunProber::Start(const std::vector& servers, @@ -301,92 +268,98 @@ bool StunProber::Start(const std::vector& servers, timeout_ms_ = timeout_ms; servers_ = servers; finished_callback_ = callback; - resolver_->Resolve(servers_[0], &resolved_ips_, - [this](int result) { this->OnServerResolved(0, result); }); + return ResolveServerName(servers_.back()); +} + +bool StunProber::ResolveServerName(const rtc::SocketAddress& addr) { + rtc::AsyncResolverInterface* resolver = + socket_factory_->CreateAsyncResolver(); + if (!resolver) { + return false; + } + resolver->SignalDone.connect(this, &StunProber::OnServerResolved); + resolver->Start(addr); return true; } -void StunProber::OnServerResolved(int index, int result) { +void StunProber::OnSocketReady(rtc::AsyncPacketSocket* socket, + const rtc::SocketAddress& addr) { + total_ready_sockets_++; + if (total_ready_sockets_ == total_socket_required()) { + MaybeScheduleStunRequests(); + } +} + +void StunProber::OnServerResolved(rtc::AsyncResolverInterface* resolver) { DCHECK(thread_checker_.CalledOnValidThread()); - if (result == 0) { - all_servers_ips_.insert(all_servers_ips_.end(), resolved_ips_.begin(), - resolved_ips_.end()); - resolved_ips_.clear(); + if (resolver->GetError() == 0) { + rtc::SocketAddress addr(resolver->address().ipaddr(), + resolver->address().port()); + all_servers_addrs_.push_back(addr); } - index++; + // Deletion of AsyncResolverInterface can't be done in OnResolveResult which + // handles SignalDone. + invoker_.AsyncInvoke( + thread_, + rtc::Bind(&rtc::AsyncResolverInterface::Destroy, resolver, false)); + servers_.pop_back(); - if (static_cast(index) < servers_.size()) { - resolver_->Resolve( - servers_[index], &resolved_ips_, - [this, index](int result) { this->OnServerResolved(index, result); }); + if (servers_.size()) { + if (!ResolveServerName(servers_.back())) { + End(RESOLVE_FAILED); + } return; } - if (all_servers_ips_.size() == 0) { - End(RESOLVE_FAILED, result); + if (all_servers_addrs_.size() == 0) { + End(RESOLVE_FAILED); return; } // Dedupe. - std::set addrs(all_servers_ips_.begin(), - all_servers_ips_.end()); - all_servers_ips_.assign(addrs.begin(), addrs.end()); + std::set addrs(all_servers_addrs_.begin(), + all_servers_addrs_.end()); + all_servers_addrs_.assign(addrs.begin(), addrs.end()); - rtc::IPAddress addr; - if (GetLocalAddress(&addr) != 0) { - End(GENERIC_FAILURE, result); - return; - } - - socket_factory_->Prepare(GetTotalClientSockets(), GetTotalServerSockets(), - [this](int result) { - if (result == 0) { - this->MaybeScheduleStunRequests(); - } - }); -} - -int StunProber::GetLocalAddress(rtc::IPAddress* addr) { - DCHECK(thread_checker_.CalledOnValidThread()); - if (local_addr_.family() == AF_UNSPEC) { - rtc::SocketAddress sock_addr; - rtc::scoped_ptr socket( - socket_factory_->CreateClientSocket()); - int rv = socket->Connect(all_servers_ips_[0]); - if (rv != SUCCESS) { - End(GENERIC_FAILURE, rv); - return rv; + // Prepare all the sockets beforehand. All of them will bind to "any" address. + while (sockets_.size() < total_socket_required()) { + rtc::scoped_ptr socket( + socket_factory_->CreateUdpSocket(rtc::SocketAddress(INADDR_ANY, 0), 0, + 0)); + if (!socket) { + End(GENERIC_FAILURE); + return; } - rv = socket->GetLocalAddress(&sock_addr); - if (rv != SUCCESS) { - End(GENERIC_FAILURE, rv); - return rv; + // Chrome and WebRTC behave differently in terms of the state of a socket + // once returned from PacketSocketFactory::CreateUdpSocket. + if (socket->GetState() == rtc::AsyncPacketSocket::STATE_BINDING) { + socket->SignalAddressReady.connect(this, &StunProber::OnSocketReady); + } else { + OnSocketReady(socket.get(), rtc::SocketAddress(INADDR_ANY, 0)); } - local_addr_ = sock_addr.ipaddr(); - socket->Close(); + sockets_.push_back(socket.release()); } - *addr = local_addr_; - return 0; } StunProber::Requester* StunProber::CreateRequester() { DCHECK(thread_checker_.CalledOnValidThread()); - rtc::scoped_ptr socket( - socket_factory_->CreateServerSocket(kMaxUdpBufferSize, - kMaxUdpBufferSize)); - if (!socket) { + if (!sockets_.size()) { return nullptr; } + StunProber::Requester* requester; if (shared_socket_mode_) { - return new Requester(this, socket.release(), all_servers_ips_); + requester = new Requester(this, sockets_.back(), all_servers_addrs_); } else { std::vector server_ip; server_ip.push_back( - all_servers_ips_[(num_request_sent_ % all_servers_ips_.size())]); - return new Requester(this, socket.release(), server_ip); + all_servers_addrs_[(num_request_sent_ % all_servers_addrs_.size())]); + requester = new Requester(this, sockets_.back(), server_ip); } + + sockets_.pop_back(); + return requester; } bool StunProber::SendNextRequest() { @@ -407,19 +380,20 @@ void StunProber::MaybeScheduleStunRequests() { uint32 now = rtc::Time(); if (Done()) { - task_runner_->PostTask(rtc::Bind(&StunProber::End, this, SUCCESS, 0), - timeout_ms_); + invoker_.AsyncInvokeDelayed( + thread_, rtc::Bind(&StunProber::End, this, SUCCESS), timeout_ms_); return; } - if (now >= next_request_time_ms_) { + if ((now + (thread_wake_up_interval_ms / 2)) >= next_request_time_ms_) { if (!SendNextRequest()) { - End(GENERIC_FAILURE, 0); + End(GENERIC_FAILURE); return; } next_request_time_ms_ = now + interval_ms_; } - task_runner_->PostTask( - rtc::Bind(&StunProber::MaybeScheduleStunRequests, this), 1 /* ms */); + invoker_.AsyncInvokeDelayed( + thread_, rtc::Bind(&StunProber::MaybeScheduleStunRequests, this), + thread_wake_up_interval_ms /* ms */); } bool StunProber::GetStats(StunProber::Stats* prob_stats) const { @@ -464,14 +438,7 @@ bool StunProber::GetStats(StunProber::Stats* prob_stats) const { IncrementCounterByAddress(&num_response_per_server, request->server_addr); IncrementCounterByAddress(&num_response_per_srflx_addr, request->srflx_addr); - rtt_sum += request->rtt(); - if (nat_type == NATTYPE_INVALID) { - nat_type = request->behind_nat ? NATTYPE_UNKNOWN : NATTYPE_NONE; - } else if (behind_nat(nat_type) != request->behind_nat) { - // Detect the inconsistency in NAT presence. - return false; - } stats.srflx_addrs.insert(request->srflx_addr.ToString()); srflx_ips.insert(request->srflx_addr.ipaddr()); } @@ -505,18 +472,34 @@ bool StunProber::GetStats(StunProber::Stats* prob_stats) const { return false; } - stats.nat_type = nat_type; - // Shared mode is only true if we use the shared socket and there are more // than 1 responding servers. stats.shared_socket_mode = shared_socket_mode_ && (num_server_ip_with_response > 1); - if (stats.shared_socket_mode && nat_type == NATTYPE_UNKNOWN) { - stats.nat_type = NATTYPE_NON_SYMMETRIC; + if (stats.shared_socket_mode && nat_type == NATTYPE_INVALID) { + nat_type = NATTYPE_NON_SYMMETRIC; } - stats.host_ip = local_addr_.ToString(); + // If we could find a local IP matching srflx, we're not behind a NAT. + rtc::SocketAddress srflx_addr; + if (!srflx_addr.FromString(*(stats.srflx_addrs.begin()))) { + return false; + } + for (const auto& net : networks_) { + if (srflx_addr.ipaddr() == net->GetBestIP()) { + nat_type = stunprober::NATTYPE_NONE; + stats.host_ip = net->GetBestIP().ToString(); + break; + } + } + + // Finally, we know we're behind a NAT but can't determine which type it is. + if (nat_type == NATTYPE_INVALID) { + nat_type = NATTYPE_UNKNOWN; + } + + stats.nat_type = nat_type; stats.num_request_sent = num_sent; stats.num_response_received = num_received; stats.target_request_interval_ns = interval_ms_ * 1000; @@ -538,14 +521,14 @@ bool StunProber::GetStats(StunProber::Stats* prob_stats) const { return true; } -void StunProber::End(StunProber::Status status, int result) { +void StunProber::End(StunProber::Status status) { DCHECK(thread_checker_.CalledOnValidThread()); if (!finished_callback_.empty()) { AsyncCallback callback = finished_callback_; finished_callback_ = AsyncCallback(); // Callback at the last since the prober might be deleted in the callback. - callback(status); + callback(this, status); } } diff --git a/webrtc/p2p/stunprober/stunprober.h b/webrtc/p2p/stunprober/stunprober.h index d773d5760..56d5369d8 100644 --- a/webrtc/p2p/stunprober/stunprober.h +++ b/webrtc/p2p/stunprober/stunprober.h @@ -15,20 +15,32 @@ #include #include +#include "webrtc/base/asyncinvoker.h" #include "webrtc/base/basictypes.h" #include "webrtc/base/bytebuffer.h" #include "webrtc/base/callback.h" #include "webrtc/base/ipaddress.h" +#include "webrtc/base/network.h" #include "webrtc/base/scoped_ptr.h" #include "webrtc/base/socketaddress.h" +#include "webrtc/base/thread.h" #include "webrtc/base/thread_checker.h" #include "webrtc/typedefs.h" +namespace rtc { +class AsyncPacketSocket; +class PacketSocketFactory; +class Thread; +class NetworkManager; +} // namespace rtc + namespace stunprober { +class StunProber; + static const int kMaxUdpBufferSize = 1200; -typedef rtc::Callback1 AsyncCallback; +typedef rtc::Callback2 AsyncCallback; enum NatType { NATTYPE_INVALID, @@ -38,104 +50,7 @@ enum NatType { NATTYPE_NON_SYMMETRIC // Behind a non-symmetric NAT. }; -class HostNameResolverInterface { - public: - HostNameResolverInterface() {} - - // Resolve should allow re-entry as |callback| could trigger another - // Resolve(). - virtual void Resolve(const rtc::SocketAddress& addr, - std::vector* addresses, - AsyncCallback callback) = 0; - - virtual ~HostNameResolverInterface() {} - - private: - DISALLOW_COPY_AND_ASSIGN(HostNameResolverInterface); -}; - -// Chrome has client and server socket. Client socket supports Connect but not -// Bind. Server is opposite. -class SocketInterface { - public: - enum { - IO_PENDING = -1, - FAILED = -2, - }; - SocketInterface() {} - virtual void Close() = 0; - virtual ~SocketInterface() {} - - private: - DISALLOW_COPY_AND_ASSIGN(SocketInterface); -}; - -class ClientSocketInterface : public SocketInterface { - public: - ClientSocketInterface() {} - // Even though we have SendTo and RecvFrom, if Connect is not called first, - // getsockname will only return 0.0.0.0. - virtual int Connect(const rtc::SocketAddress& addr) = 0; - - virtual int GetLocalAddress(rtc::SocketAddress* local_address) = 0; - - private: - DISALLOW_COPY_AND_ASSIGN(ClientSocketInterface); -}; - -class ServerSocketInterface : public SocketInterface { - public: - ServerSocketInterface() {} - - virtual int SendTo(const rtc::SocketAddress& addr, - char* buf, - size_t buf_len, - AsyncCallback callback) = 0; - - // If the returned value is positive, it means that buf has been - // sent. Otherwise, it should return IO_PENDING. Callback will be invoked - // after the data is successfully read into buf. - virtual int RecvFrom(char* buf, - size_t buf_len, - rtc::SocketAddress* addr, - AsyncCallback callback) = 0; - - private: - DISALLOW_COPY_AND_ASSIGN(ServerSocketInterface); -}; - -class SocketFactoryInterface { - public: - SocketFactoryInterface() {} - // To provide a chance to prepare the sockets that we need. This is - // implemented for chrome renderer process as the socket needs to be ready to - // use in browser process. - virtual void Prepare(size_t total_client_socket, - size_t total_server_socket, - AsyncCallback callback) { - callback(0); - } - virtual ClientSocketInterface* CreateClientSocket() = 0; - virtual ServerSocketInterface* CreateServerSocket( - size_t send_buffer_size, - size_t receive_buffer_size) = 0; - virtual ~SocketFactoryInterface() {} - - private: - DISALLOW_COPY_AND_ASSIGN(SocketFactoryInterface); -}; - -class TaskRunnerInterface { - public: - TaskRunnerInterface() {} - virtual void PostTask(rtc::Callback0, uint32_t delay_ms) = 0; - virtual ~TaskRunnerInterface() {} - - private: - DISALLOW_COPY_AND_ASSIGN(TaskRunnerInterface); -}; - -class StunProber { +class StunProber : public sigslot::has_slots<> { public: enum Status { // Used in UMA_HISTOGRAM_ENUMERATION. SUCCESS, // Successfully received bytes from the server. @@ -167,11 +82,9 @@ class StunProber { std::set srflx_addrs; }; - // StunProber is not thread safe. It's task_runner's responsibility to ensure - // all calls happen sequentially. - StunProber(HostNameResolverInterface* host_name_resolver, - SocketFactoryInterface* socket_factory, - TaskRunnerInterface* task_runner); + StunProber(rtc::PacketSocketFactory* socket_factory, + rtc::Thread* thread, + const rtc::NetworkManager::NetworkList& networks); virtual ~StunProber(); // Begin performing the probe test against the |servers|. If @@ -203,16 +116,19 @@ class StunProber { // STUN servers. class Requester; - void OnServerResolved(int index, int result); + bool ResolveServerName(const rtc::SocketAddress& addr); + void OnServerResolved(rtc::AsyncResolverInterface* resolver); + + void OnSocketReady(rtc::AsyncPacketSocket* socket, + const rtc::SocketAddress& addr); bool Done() { - return num_request_sent_ >= requests_per_ip_ * all_servers_ips_.size(); + return num_request_sent_ >= requests_per_ip_ * all_servers_addrs_.size(); } - int GetTotalClientSockets() { return 1; } - int GetTotalServerSockets() { - return static_cast( - (shared_socket_mode_ ? 1 : all_servers_ips_.size()) * requests_per_ip_); + size_t total_socket_required() { + return (shared_socket_mode_ ? 1 : all_servers_addrs_.size()) * + requests_per_ip_; } bool SendNextRequest(); @@ -223,14 +139,7 @@ class StunProber { // End the probe with the given |status|. Invokes |fininsh_callback|, which // may destroy the class. - void End(StunProber::Status status, int result); - - // Create a socket, connect to the first resolved server, and return the - // result of getsockname(). All Requesters will bind to this name. We do this - // because if a socket is not bound nor connected, getsockname will return - // 0.0.0.0. We can't connect to a single STUN server IP either as that will - // fail subsequent requests in shared mode. - int GetLocalAddress(rtc::IPAddress* addr); + void End(StunProber::Status status); Requester* CreateRequester(); @@ -256,19 +165,12 @@ class StunProber { // STUN server name to be resolved. std::vector servers_; - // The local address that each probing socket will be bound to. - rtc::IPAddress local_addr_; + // Weak references. + rtc::PacketSocketFactory* socket_factory_; + rtc::Thread* thread_; - // Owned pointers. - rtc::scoped_ptr socket_factory_; - rtc::scoped_ptr resolver_; - rtc::scoped_ptr task_runner_; - - // Addresses filled out by HostNameResolver for a single server. - std::vector resolved_ips_; - - // Accumulate all resolved IPs. - std::vector all_servers_ips_; + // Accumulate all resolved addresses. + std::vector all_servers_addrs_; // Caller-supplied callback executed when testing is completed, called by // End(). @@ -279,6 +181,15 @@ class StunProber { rtc::ThreadChecker thread_checker_; + // Temporary storage for created sockets. + std::vector sockets_; + // This tracks how many of the sockets are ready. + size_t total_ready_sockets_ = 0; + + rtc::AsyncInvoker invoker_; + + rtc::NetworkManager::NetworkList networks_; + DISALLOW_COPY_AND_ASSIGN(StunProber); }; diff --git a/webrtc/p2p/stunprober/stunprober_dependencies.h b/webrtc/p2p/stunprober/stunprober_dependencies.h deleted file mode 100644 index 4e64df71b..000000000 --- a/webrtc/p2p/stunprober/stunprober_dependencies.h +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Copyright 2015 The WebRTC Project Authors. All rights reserved. - * - * Use of this source code is governed by a BSD-style license - * that can be found in the LICENSE file in the root of the source - * tree. An additional intellectual property rights grant can be found - * in the file PATENTS. All contributing project authors may - * be found in the AUTHORS file in the root of the source tree. - */ - -#ifndef WEBRTC_P2P_STUNPROBER_STUNPROBER_DEPENDENCIES_H_ -#define WEBRTC_P2P_STUNPROBER_STUNPROBER_DEPENDENCIES_H_ - -#include "webrtc/base/checks.h" -#include "webrtc/base/helpers.h" -#include "webrtc/base/logging.h" -#include "webrtc/base/scoped_ptr.h" -#include "webrtc/base/thread.h" -#include "webrtc/base/timeutils.h" -#include "webrtc/p2p/stunprober/stunprober.h" - -// Common classes used by both the command line driver and the unit tests. -namespace stunprober { - -class Socket : public ClientSocketInterface, - public ServerSocketInterface, - public sigslot::has_slots<> { - public: - explicit Socket(rtc::AsyncSocket* socket) : socket_(socket) { - socket_->SignalReadEvent.connect(this, &Socket::OnReadEvent); - socket_->SignalWriteEvent.connect(this, &Socket::OnWriteEvent); - } - - int Connect(const rtc::SocketAddress& addr) override { - return MapResult(socket_->Connect(addr)); - } - - int SendTo(const rtc::SocketAddress& addr, - char* buf, - size_t buf_len, - AsyncCallback callback) override { - write_ = NetworkWrite(addr, buf, buf_len, callback); - return MapResult(socket_->SendTo(buf, buf_len, addr)); - } - - int RecvFrom(char* buf, - size_t buf_len, - rtc::SocketAddress* addr, - AsyncCallback callback) override { - read_ = NetworkRead(buf, buf_len, addr, callback); - return MapResult(socket_->RecvFrom(buf, buf_len, addr)); - } - - int GetLocalAddress(rtc::SocketAddress* local_address) override { - *local_address = socket_->GetLocalAddress(); - return 0; - } - - void Close() override { socket_->Close(); } - - virtual ~Socket() {} - - protected: - int MapResult(int rv) { - if (rv >= 0) { - return rv; - } - int err = socket_->GetError(); - if (err == EWOULDBLOCK || err == EAGAIN) { - return IO_PENDING; - } else { - return FAILED; - } - } - - void OnReadEvent(rtc::AsyncSocket* socket) { - DCHECK(socket_ == socket); - NetworkRead read = read_; - read_ = NetworkRead(); - if (!read.callback.empty()) { - read.callback(socket_->RecvFrom(read.buf, read.buf_len, read.addr)); - } - } - - void OnWriteEvent(rtc::AsyncSocket* socket) { - DCHECK(socket_ == socket); - NetworkWrite write = write_; - write_ = NetworkWrite(); - if (!write.callback.empty()) { - write.callback(socket_->SendTo(write.buf, write.buf_len, write.addr)); - } - } - - struct NetworkWrite { - NetworkWrite() : buf(nullptr), buf_len(0) {} - NetworkWrite(const rtc::SocketAddress& addr, - char* buf, - size_t buf_len, - AsyncCallback callback) - : buf(buf), buf_len(buf_len), addr(addr), callback(callback) {} - char* buf; - size_t buf_len; - rtc::SocketAddress addr; - AsyncCallback callback; - }; - - NetworkWrite write_; - - struct NetworkRead { - NetworkRead() : buf(nullptr), buf_len(0) {} - NetworkRead(char* buf, - size_t buf_len, - rtc::SocketAddress* addr, - AsyncCallback callback) - : buf(buf), buf_len(buf_len), addr(addr), callback(callback) {} - - char* buf; - size_t buf_len; - rtc::SocketAddress* addr; - AsyncCallback callback; - }; - - NetworkRead read_; - - rtc::scoped_ptr socket_; -}; - -class SocketFactory : public SocketFactoryInterface { - public: - ClientSocketInterface* CreateClientSocket() override { - return new Socket( - rtc::Thread::Current()->socketserver()->CreateAsyncSocket(SOCK_DGRAM)); - } - ServerSocketInterface* CreateServerSocket(size_t send_buffer_size, - size_t recv_buffer_size) override { - rtc::scoped_ptr socket( - rtc::Thread::Current()->socketserver()->CreateAsyncSocket(SOCK_DGRAM)); - - if (socket) { - socket->SetOption(rtc::AsyncSocket::OPT_SNDBUF, - static_cast(send_buffer_size)); - socket->SetOption(rtc::AsyncSocket::OPT_RCVBUF, - static_cast(recv_buffer_size)); - return new Socket(socket.release()); - } else { - return nullptr; - } - } -}; - -class TaskRunner : public TaskRunnerInterface, public rtc::MessageHandler { - protected: - class CallbackMessageData : public rtc::MessageData { - public: - explicit CallbackMessageData(rtc::Callback0 callback) - : callback_(callback) {} - rtc::Callback0 callback_; - }; - - public: - void PostTask(rtc::Callback0 callback, uint32_t delay_ms) { - if (delay_ms == 0) { - rtc::Thread::Current()->Post(this, 0, new CallbackMessageData(callback)); - } else { - rtc::Thread::Current()->PostDelayed(delay_ms, this, 0, - new CallbackMessageData(callback)); - } - } - - void OnMessage(rtc::Message* msg) { - rtc::scoped_ptr callback( - reinterpret_cast(msg->pdata)); - callback->callback_(); - } -}; - -} // namespace stunprober -#endif // WEBRTC_P2P_STUNPROBER_STUNPROBER_DEPENDENCIES_H_ diff --git a/webrtc/p2p/stunprober/stunprober_unittest.cc b/webrtc/p2p/stunprober/stunprober_unittest.cc index ffd5040e5..3e030144d 100644 --- a/webrtc/p2p/stunprober/stunprober_unittest.cc +++ b/webrtc/p2p/stunprober/stunprober_unittest.cc @@ -8,6 +8,7 @@ * be found in the AUTHORS file in the root of the source tree. */ +#include "webrtc/base/asyncresolverinterface.h" #include "webrtc/base/basictypes.h" #include "webrtc/base/bind.h" #include "webrtc/base/checks.h" @@ -16,20 +17,12 @@ #include "webrtc/base/scoped_ptr.h" #include "webrtc/base/ssladapter.h" #include "webrtc/base/virtualsocketserver.h" - +#include "webrtc/p2p/base/basicpacketsocketfactory.h" #include "webrtc/p2p/base/teststunserver.h" #include "webrtc/p2p/stunprober/stunprober.h" -#include "webrtc/p2p/stunprober/stunprober_dependencies.h" -using stunprober::HostNameResolverInterface; -using stunprober::TaskRunner; -using stunprober::SocketFactory; using stunprober::StunProber; using stunprober::AsyncCallback; -using stunprober::ClientSocketInterface; -using stunprober::ServerSocketInterface; -using stunprober::SocketFactory; -using stunprober::TaskRunner; namespace stunprober { @@ -41,58 +34,6 @@ const rtc::SocketAddress kStunAddr2("1.1.1.2", 3478); const rtc::SocketAddress kFailedStunAddr("1.1.1.3", 3478); const rtc::SocketAddress kStunMappedAddr("77.77.77.77", 0); -class TestSocketServer : public rtc::VirtualSocketServer { - public: - using rtc::VirtualSocketServer::CreateAsyncSocket; - explicit TestSocketServer(SocketServer* ss) : rtc::VirtualSocketServer(ss) {} - void SetLocalAddress(const rtc::SocketAddress& addr) { addr_ = addr; } - - // CreateAsyncSocket is used by StunProber to create both client and server - // sockets. The first socket is used to retrieve local address which will be - // used later for Bind(). - rtc::AsyncSocket* CreateAsyncSocket(int type) override { - rtc::VirtualSocket* socket = static_cast( - rtc::VirtualSocketServer::CreateAsyncSocket(type)); - if (!local_addr_set_) { - // Only the first socket can SetLocalAddress. For others, Bind will fail - // if local address is set. - socket->SetLocalAddress(addr_); - local_addr_set_ = true; - } else { - sockets_.push_back(socket); - } - return socket; - } - - size_t num_socket() { return sockets_.size(); } - - private: - bool local_addr_set_ = false; - std::vector sockets_; - rtc::SocketAddress addr_; -}; - -class FakeHostNameResolver : public HostNameResolverInterface { - public: - FakeHostNameResolver() {} - void set_result(int ret) { ret_ = ret; } - void set_addresses(const std::vector& addresses) { - server_ips_ = addresses; - } - const std::vector& get_addresses() { return server_ips_; } - void add_address(const rtc::SocketAddress& ip) { server_ips_.push_back(ip); } - void Resolve(const rtc::SocketAddress& addr, - std::vector* addresses, - stunprober::AsyncCallback callback) override { - *addresses = server_ips_; - callback(ret_); - } - - private: - int ret_ = 0; - std::vector server_ips_; -}; - } // namespace class StunProberTest : public testing::Test { @@ -100,7 +41,7 @@ class StunProberTest : public testing::Test { StunProberTest() : main_(rtc::Thread::Current()), pss_(new rtc::PhysicalSocketServer), - ss_(new TestSocketServer(pss_.get())), + ss_(new rtc::VirtualSocketServer(pss_.get())), ss_scope_(ss_.get()), result_(StunProber::SUCCESS), stun_server_1_(cricket::TestStunServer::Create(rtc::Thread::Current(), @@ -114,62 +55,56 @@ class StunProberTest : public testing::Test { void set_expected_result(int result) { result_ = result; } - void StartProbing(HostNameResolverInterface* resolver, - SocketFactoryInterface* socket_factory, - const rtc::SocketAddress& addr, + void StartProbing(rtc::PacketSocketFactory* socket_factory, + const std::vector& addrs, + const rtc::NetworkManager::NetworkList& networks, bool shared_socket, uint16 interval, uint16 pings_per_ip) { - std::vector addrs; - addrs.push_back(addr); - prober.reset(new StunProber(resolver, socket_factory, new TaskRunner())); + prober.reset( + new StunProber(socket_factory, rtc::Thread::Current(), networks)); prober->Start(addrs, shared_socket, interval, pings_per_ip, - 100 /* timeout_ms */, - [this](int result) { this->StopCallback(result); }); + 100 /* timeout_ms */, [this](StunProber* prober, int result) { + this->StopCallback(prober, result); + }); } void RunProber(bool shared_mode) { const int pings_per_ip = 3; - const uint16 port = kStunAddr1.port(); - rtc::SocketAddress addr("stun.l.google.com", port); std::vector addrs; - - // Set up the resolver for 2 stun server addresses. - rtc::scoped_ptr resolver(new FakeHostNameResolver()); - resolver->add_address(kStunAddr1); - resolver->add_address(kStunAddr2); + addrs.push_back(kStunAddr1); + addrs.push_back(kStunAddr2); // Add a non-existing server. This shouldn't pollute the result. - resolver->add_address(kFailedStunAddr); + addrs.push_back(kFailedStunAddr); - rtc::scoped_ptr socket_factory(new SocketFactory()); + rtc::Network ipv4_network1("test_eth0", "Test Network Adapter 1", + rtc::IPAddress(0x12345600U), 24); + ipv4_network1.AddIP(rtc::IPAddress(0x12345678)); + rtc::NetworkManager::NetworkList networks; + networks.push_back(&ipv4_network1); - // Set local address in socketserver so getsockname will return kLocalAddr - // instead of 0.0.0.0 for the first socket. - ss_->SetLocalAddress(kLocalAddr); + rtc::scoped_ptr socket_factory( + new rtc::BasicPacketSocketFactory()); // Set up the expected results for verification. std::set srflx_addresses; srflx_addresses.insert(kStunMappedAddr.ToString()); const uint32 total_pings_tried = - static_cast(pings_per_ip * resolver->get_addresses().size()); + static_cast(pings_per_ip * addrs.size()); // The reported total_pings should not count for pings sent to the // kFailedStunAddr. const uint32 total_pings_reported = total_pings_tried - pings_per_ip; - size_t total_sockets = shared_mode ? pings_per_ip : total_pings_tried; - - StartProbing(resolver.release(), socket_factory.release(), addr, - shared_mode, 3, pings_per_ip); + StartProbing(socket_factory.get(), addrs, networks, shared_mode, 3, + pings_per_ip); WAIT(stopped_, 1000); StunProber::Stats stats; - EXPECT_EQ(ss_->num_socket(), total_sockets); EXPECT_TRUE(prober->GetStats(&stats)); EXPECT_EQ(stats.success_percent, 100); EXPECT_TRUE(stats.nat_type > stunprober::NATTYPE_NONE); - EXPECT_EQ(stats.host_ip, kLocalAddr.ipaddr().ToString()); EXPECT_EQ(stats.srflx_addrs, srflx_addresses); EXPECT_EQ(static_cast(stats.num_request_sent), total_pings_reported); @@ -178,14 +113,14 @@ class StunProberTest : public testing::Test { } private: - void StopCallback(int result) { + void StopCallback(StunProber* prober, int result) { EXPECT_EQ(result, result_); stopped_ = true; } rtc::Thread* main_; rtc::scoped_ptr pss_; - rtc::scoped_ptr ss_; + rtc::scoped_ptr ss_; rtc::SocketServerScope ss_scope_; rtc::scoped_ptr prober; int result_ = 0; @@ -194,19 +129,6 @@ class StunProberTest : public testing::Test { rtc::scoped_ptr stun_server_2_; }; -TEST_F(StunProberTest, DNSFailure) { - rtc::SocketAddress addr("stun.l.google.com", 19302); - rtc::scoped_ptr resolver(new FakeHostNameResolver()); - rtc::scoped_ptr socket_factory(new SocketFactory()); - - set_expected_result(StunProber::RESOLVE_FAILED); - - // Non-0 value is treated as failure. - resolver->set_result(1); - StartProbing(resolver.release(), socket_factory.release(), addr, false, 10, - 30); -} - TEST_F(StunProberTest, NonSharedMode) { RunProber(false); }