/* * libjingle * Copyright 2004, Google Inc. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * 3. The name of the author may not be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ #include <string> #include "talk/base/gunit.h" #include "talk/base/logging.h" #include "talk/base/natserver.h" #include "talk/base/natsocketfactory.h" #include "talk/base/nethelpers.h" #include "talk/base/network.h" #include "talk/base/physicalsocketserver.h" #include "talk/base/testclient.h" #include "talk/base/virtualsocketserver.h" using namespace talk_base; bool CheckReceive( TestClient* client, bool should_receive, const char* buf, size_t size) { return (should_receive) ? client->CheckNextPacket(buf, size, 0) : client->CheckNoPacket(); } TestClient* CreateTestClient( SocketFactory* factory, const SocketAddress& local_addr) { AsyncUDPSocket* socket = AsyncUDPSocket::Create(factory, local_addr); return new TestClient(socket); } // Tests that when sending from internal_addr to external_addrs through the // NAT type specified by nat_type, all external addrs receive the sent packet // and, if exp_same is true, all use the same mapped-address on the NAT. void TestSend( SocketServer* internal, const SocketAddress& internal_addr, SocketServer* external, const SocketAddress external_addrs[4], NATType nat_type, bool exp_same) { Thread th_int(internal); Thread th_ext(external); SocketAddress server_addr = internal_addr; server_addr.SetPort(0); // Auto-select a port NATServer* nat = new NATServer( nat_type, internal, server_addr, external, external_addrs[0]); NATSocketFactory* natsf = new NATSocketFactory(internal, nat->internal_address()); TestClient* in = CreateTestClient(natsf, internal_addr); TestClient* out[4]; for (int i = 0; i < 4; i++) out[i] = CreateTestClient(external, external_addrs[i]); th_int.Start(); th_ext.Start(); const char* buf = "filter_test"; size_t len = strlen(buf); in->SendTo(buf, len, out[0]->address()); SocketAddress trans_addr; EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr)); for (int i = 1; i < 4; i++) { in->SendTo(buf, len, out[i]->address()); SocketAddress trans_addr2; EXPECT_TRUE(out[i]->CheckNextPacket(buf, len, &trans_addr2)); bool are_same = (trans_addr == trans_addr2); ASSERT_EQ(are_same, exp_same) << "same translated address"; ASSERT_NE(AF_UNSPEC, trans_addr.family()); ASSERT_NE(AF_UNSPEC, trans_addr2.family()); } th_int.Stop(); th_ext.Stop(); delete nat; delete natsf; delete in; for (int i = 0; i < 4; i++) delete out[i]; } // Tests that when sending from external_addrs to internal_addr, the packet // is delivered according to the specified filter_ip and filter_port rules. void TestRecv( SocketServer* internal, const SocketAddress& internal_addr, SocketServer* external, const SocketAddress external_addrs[4], NATType nat_type, bool filter_ip, bool filter_port) { Thread th_int(internal); Thread th_ext(external); SocketAddress server_addr = internal_addr; server_addr.SetPort(0); // Auto-select a port NATServer* nat = new NATServer( nat_type, internal, server_addr, external, external_addrs[0]); NATSocketFactory* natsf = new NATSocketFactory(internal, nat->internal_address()); TestClient* in = CreateTestClient(natsf, internal_addr); TestClient* out[4]; for (int i = 0; i < 4; i++) out[i] = CreateTestClient(external, external_addrs[i]); th_int.Start(); th_ext.Start(); const char* buf = "filter_test"; size_t len = strlen(buf); in->SendTo(buf, len, out[0]->address()); SocketAddress trans_addr; EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr)); out[1]->SendTo(buf, len, trans_addr); EXPECT_TRUE(CheckReceive(in, !filter_ip, buf, len)); out[2]->SendTo(buf, len, trans_addr); EXPECT_TRUE(CheckReceive(in, !filter_port, buf, len)); out[3]->SendTo(buf, len, trans_addr); EXPECT_TRUE(CheckReceive(in, !filter_ip && !filter_port, buf, len)); th_int.Stop(); th_ext.Stop(); delete nat; delete natsf; delete in; for (int i = 0; i < 4; i++) delete out[i]; } // Tests that NATServer allocates bindings properly. void TestBindings( SocketServer* internal, const SocketAddress& internal_addr, SocketServer* external, const SocketAddress external_addrs[4]) { TestSend(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE, true); TestSend(internal, internal_addr, external, external_addrs, NAT_ADDR_RESTRICTED, true); TestSend(internal, internal_addr, external, external_addrs, NAT_PORT_RESTRICTED, true); TestSend(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC, false); } // Tests that NATServer filters packets properly. void TestFilters( SocketServer* internal, const SocketAddress& internal_addr, SocketServer* external, const SocketAddress external_addrs[4]) { TestRecv(internal, internal_addr, external, external_addrs, NAT_OPEN_CONE, false, false); TestRecv(internal, internal_addr, external, external_addrs, NAT_ADDR_RESTRICTED, true, false); TestRecv(internal, internal_addr, external, external_addrs, NAT_PORT_RESTRICTED, true, true); TestRecv(internal, internal_addr, external, external_addrs, NAT_SYMMETRIC, true, true); } bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) { // The physical NAT tests require connectivity to the selected ip from the // internal address used for the NAT. Things like firewalls can break that, so // check to see if it's worth even trying with this ip. scoped_ptr<PhysicalSocketServer> pss(new PhysicalSocketServer()); scoped_ptr<AsyncSocket> client(pss->CreateAsyncSocket(src.family(), SOCK_DGRAM)); scoped_ptr<AsyncSocket> server(pss->CreateAsyncSocket(src.family(), SOCK_DGRAM)); if (client->Bind(SocketAddress(src.ipaddr(), 0)) != 0 || server->Bind(SocketAddress(dst, 0)) != 0) { return false; } const char* buf = "hello other socket"; size_t len = strlen(buf); int sent = client->SendTo(buf, len, server->GetLocalAddress()); SocketAddress addr; const size_t kRecvBufSize = 64; char recvbuf[kRecvBufSize]; Thread::Current()->SleepMs(100); int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr); return received == sent && ::memcmp(buf, recvbuf, len) == 0; } void TestPhysicalInternal(const SocketAddress& int_addr) { BasicNetworkManager network_manager; network_manager.set_ipv6_enabled(true); network_manager.StartUpdating(); // Process pending messages so the network list is updated. Thread::Current()->ProcessMessages(0); std::vector<Network*> networks; network_manager.GetNetworks(&networks); if (networks.empty()) { LOG(LS_WARNING) << "Not enough network adapters for test."; return; } SocketAddress ext_addr1(int_addr); SocketAddress ext_addr2; // Find an available IP with matching family. The test breaks if int_addr // can't talk to ip, so check for connectivity as well. for (std::vector<Network*>::iterator it = networks.begin(); it != networks.end(); ++it) { const IPAddress& ip = (*it)->ip(); if (ip.family() == int_addr.family() && TestConnectivity(int_addr, ip)) { ext_addr2.SetIP(ip); break; } } if (ext_addr2.IsNil()) { LOG(LS_WARNING) << "No available IP of same family as " << int_addr; return; } LOG(LS_INFO) << "selected ip " << ext_addr2.ipaddr(); SocketAddress ext_addrs[4] = { SocketAddress(ext_addr1), SocketAddress(ext_addr2), SocketAddress(ext_addr1), SocketAddress(ext_addr2) }; scoped_ptr<PhysicalSocketServer> int_pss(new PhysicalSocketServer()); scoped_ptr<PhysicalSocketServer> ext_pss(new PhysicalSocketServer()); TestBindings(int_pss.get(), int_addr, ext_pss.get(), ext_addrs); TestFilters(int_pss.get(), int_addr, ext_pss.get(), ext_addrs); } TEST(NatTest, TestPhysicalIPv4) { TestPhysicalInternal(SocketAddress("127.0.0.1", 0)); } TEST(NatTest, TestPhysicalIPv6) { if (HasIPv6Enabled()) { TestPhysicalInternal(SocketAddress("::1", 0)); } else { LOG(LS_WARNING) << "No IPv6, skipping"; } } class TestVirtualSocketServer : public VirtualSocketServer { public: explicit TestVirtualSocketServer(SocketServer* ss) : VirtualSocketServer(ss), ss_(ss) {} // Expose this publicly IPAddress GetNextIP(int af) { return VirtualSocketServer::GetNextIP(af); } private: scoped_ptr<SocketServer> ss_; }; void TestVirtualInternal(int family) { scoped_ptr<TestVirtualSocketServer> int_vss(new TestVirtualSocketServer( new PhysicalSocketServer())); scoped_ptr<TestVirtualSocketServer> ext_vss(new TestVirtualSocketServer( new PhysicalSocketServer())); SocketAddress int_addr; SocketAddress ext_addrs[4]; int_addr.SetIP(int_vss->GetNextIP(family)); ext_addrs[0].SetIP(ext_vss->GetNextIP(int_addr.family())); ext_addrs[1].SetIP(ext_vss->GetNextIP(int_addr.family())); ext_addrs[2].SetIP(ext_addrs[0].ipaddr()); ext_addrs[3].SetIP(ext_addrs[1].ipaddr()); TestBindings(int_vss.get(), int_addr, ext_vss.get(), ext_addrs); TestFilters(int_vss.get(), int_addr, ext_vss.get(), ext_addrs); } TEST(NatTest, TestVirtualIPv4) { TestVirtualInternal(AF_INET); } TEST(NatTest, TestVirtualIPv6) { if (HasIPv6Enabled()) { TestVirtualInternal(AF_INET6); } else { LOG(LS_WARNING) << "No IPv6, skipping"; } } // TODO: Finish this test class NatTcpTest : public testing::Test, public sigslot::has_slots<> { public: NatTcpTest() : connected_(false) {} virtual void SetUp() { int_vss_ = new TestVirtualSocketServer(new PhysicalSocketServer()); ext_vss_ = new TestVirtualSocketServer(new PhysicalSocketServer()); nat_ = new NATServer(NAT_OPEN_CONE, int_vss_, SocketAddress(), ext_vss_, SocketAddress()); natsf_ = new NATSocketFactory(int_vss_, nat_->internal_address()); } void OnConnectEvent(AsyncSocket* socket) { connected_ = true; } void OnAcceptEvent(AsyncSocket* socket) { accepted_ = server_->Accept(NULL); } void OnCloseEvent(AsyncSocket* socket, int error) { } void ConnectEvents() { server_->SignalReadEvent.connect(this, &NatTcpTest::OnAcceptEvent); client_->SignalConnectEvent.connect(this, &NatTcpTest::OnConnectEvent); } TestVirtualSocketServer* int_vss_; TestVirtualSocketServer* ext_vss_; NATServer* nat_; NATSocketFactory* natsf_; AsyncSocket* client_; AsyncSocket* server_; AsyncSocket* accepted_; bool connected_; }; TEST_F(NatTcpTest, DISABLED_TestConnectOut) { server_ = ext_vss_->CreateAsyncSocket(SOCK_STREAM); server_->Bind(SocketAddress()); server_->Listen(5); client_ = int_vss_->CreateAsyncSocket(SOCK_STREAM); EXPECT_GE(0, client_->Bind(SocketAddress())); EXPECT_GE(0, client_->Connect(server_->GetLocalAddress())); ConnectEvents(); EXPECT_TRUE_WAIT(connected_, 1000); EXPECT_EQ(client_->GetRemoteAddress(), server_->GetLocalAddress()); EXPECT_EQ(client_->GetRemoteAddress(), accepted_->GetLocalAddress()); EXPECT_EQ(client_->GetLocalAddress(), accepted_->GetRemoteAddress()); client_->Close(); } //#endif