|  | // Copyright 2014 The Chromium Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
#include <queue>

#include "base/macros.h"
#include "base/memory/scoped_ptr.h"
#include "base/run_loop.h"
#include "mojo/public/cpp/application/application_connection.h"
#include "mojo/public/cpp/application/application_impl.h"
#include "mojo/public/cpp/application/application_test_base.h"
#include "mojo/public/cpp/bindings/callback.h"
#include "mojo/services/network/public/cpp/udp_socket_wrapper.h"
#include "mojo/services/network/public/interfaces/network_service.mojom.h"
#include "mojo/services/network/public/interfaces/udp_socket.mojom.h"
#include "net/base/net_errors.h"
#include "testing/gtest/include/gtest/gtest.h"
|  |  | 
|  | namespace mojo { | 
|  | namespace service { | 
|  | namespace { | 
|  |  | 
|  | NetAddressPtr GetLocalHostWithAnyPort() { | 
|  | NetAddressPtr addr(NetAddress::New()); | 
|  | addr->family = NET_ADDRESS_FAMILY_IPV4; | 
|  | addr->ipv4 = NetAddressIPv4::New(); | 
|  | addr->ipv4->port = 0; | 
|  | addr->ipv4->addr.resize(4); | 
|  | addr->ipv4->addr[0] = 127; | 
|  | addr->ipv4->addr[1] = 0; | 
|  | addr->ipv4->addr[2] = 0; | 
|  | addr->ipv4->addr[3] = 1; | 
|  |  | 
|  | return addr.Pass(); | 
|  | } | 
|  |  | 
|  | Array<uint8_t> CreateTestMessage(uint8_t initial, size_t size) { | 
|  | Array<uint8_t> array(size); | 
|  | for (size_t i = 0; i < size; ++i) | 
|  | array[i] = static_cast<uint8_t>((i + initial) % 256); | 
|  | return array.Pass(); | 
|  | } | 
|  |  | 
|  | template <typename CallbackType> | 
|  | class TestCallbackBase { | 
|  | public: | 
|  | TestCallbackBase() : state_(nullptr), run_loop_(nullptr), ran_(false) {} | 
|  |  | 
|  | ~TestCallbackBase() { | 
|  | state_->set_test_callback(nullptr); | 
|  | } | 
|  |  | 
|  | CallbackType callback() const { return callback_; } | 
|  |  | 
|  | void WaitForResult() { | 
|  | if (ran_) | 
|  | return; | 
|  |  | 
|  | base::RunLoop run_loop; | 
|  | run_loop_ = &run_loop; | 
|  | run_loop.Run(); | 
|  | run_loop_ = nullptr; | 
|  | } | 
|  |  | 
|  | protected: | 
|  | struct StateBase : public CallbackType::Runnable { | 
|  | StateBase() : test_callback_(nullptr) {} | 
|  | ~StateBase() override {} | 
|  |  | 
|  | void set_test_callback(TestCallbackBase* test_callback) { | 
|  | test_callback_ = test_callback; | 
|  | } | 
|  |  | 
|  | protected: | 
|  | void NotifyRun() const { | 
|  | if (test_callback_) { | 
|  | test_callback_->ran_ = true; | 
|  | if (test_callback_->run_loop_) | 
|  | test_callback_->run_loop_->Quit(); | 
|  | } | 
|  | } | 
|  |  | 
|  | TestCallbackBase* test_callback_; | 
|  |  | 
|  | private: | 
|  | DISALLOW_COPY_AND_ASSIGN(StateBase); | 
|  | }; | 
|  |  | 
|  | // Takes ownership of |state|, and guarantees that it lives at least as long | 
|  | // as this object. | 
|  | void Initialize(StateBase* state) { | 
|  | state_ = state; | 
|  | state_->set_test_callback(this); | 
|  | callback_ = CallbackType( | 
|  | static_cast<typename CallbackType::Runnable*>(state_)); | 
|  | } | 
|  |  | 
|  | private: | 
|  | // The lifespan is managed by |callback_| (and its copies). | 
|  | StateBase* state_; | 
|  | CallbackType callback_; | 
|  | base::RunLoop* run_loop_; | 
|  | bool ran_; | 
|  |  | 
|  | DISALLOW_COPY_AND_ASSIGN(TestCallbackBase); | 
|  | }; | 
|  |  | 
|  | class TestCallback : public TestCallbackBase<Callback<void(NetworkErrorPtr)>> { | 
|  | public: | 
|  | TestCallback() { | 
|  | Initialize(new State()); | 
|  | } | 
|  | ~TestCallback() {} | 
|  |  | 
|  | const NetworkErrorPtr& result() const { return result_; } | 
|  |  | 
|  | private: | 
|  | struct State: public StateBase { | 
|  | ~State() override {} | 
|  |  | 
|  | void Run(NetworkErrorPtr result) const override { | 
|  | if (test_callback_) { | 
|  | TestCallback* callback = static_cast<TestCallback*>(test_callback_); | 
|  | callback->result_ = result.Pass(); | 
|  | } | 
|  | NotifyRun(); | 
|  | } | 
|  | }; | 
|  |  | 
|  | NetworkErrorPtr result_; | 
|  | }; | 
|  |  | 
|  | class TestCallbackWithAddress | 
|  | : public TestCallbackBase<Callback<void(NetworkErrorPtr, NetAddressPtr)>> { | 
|  | public: | 
|  | TestCallbackWithAddress() { | 
|  | Initialize(new State()); | 
|  | } | 
|  | ~TestCallbackWithAddress() {} | 
|  |  | 
|  | const NetworkErrorPtr& result() const { return result_; } | 
|  | const NetAddressPtr& net_address() const { return net_address_; } | 
|  |  | 
|  | private: | 
|  | struct State : public StateBase { | 
|  | ~State() override {} | 
|  |  | 
|  | void Run(NetworkErrorPtr result, NetAddressPtr net_address) const override { | 
|  | if (test_callback_) { | 
|  | TestCallbackWithAddress* callback = | 
|  | static_cast<TestCallbackWithAddress*>(test_callback_); | 
|  | callback->result_ = result.Pass(); | 
|  | callback->net_address_ = net_address.Pass(); | 
|  | } | 
|  | NotifyRun(); | 
|  | } | 
|  | }; | 
|  |  | 
|  | NetworkErrorPtr result_; | 
|  | NetAddressPtr net_address_; | 
|  | }; | 
|  |  | 
|  | class TestCallbackWithUint32 | 
|  | : public TestCallbackBase<Callback<void(uint32_t)>> { | 
|  | public: | 
|  | TestCallbackWithUint32() : result_(0) { | 
|  | Initialize(new State()); | 
|  | } | 
|  | ~TestCallbackWithUint32() {} | 
|  |  | 
|  | uint32_t result() const { return result_; } | 
|  |  | 
|  | private: | 
|  | struct State : public StateBase { | 
|  | ~State() override {} | 
|  |  | 
|  | void Run(uint32_t result) const override { | 
|  | if (test_callback_) { | 
|  | TestCallbackWithUint32* callback = | 
|  | static_cast<TestCallbackWithUint32*>(test_callback_); | 
|  | callback->result_ = result; | 
|  | } | 
|  | NotifyRun(); | 
|  | } | 
|  | }; | 
|  |  | 
|  | uint32_t result_; | 
|  | }; | 
|  |  | 
|  | class TestReceiveCallback | 
|  | : public TestCallbackBase< | 
|  | Callback<void(NetworkErrorPtr, NetAddressPtr, Array<uint8_t>)>> { | 
|  | public: | 
|  | TestReceiveCallback() { | 
|  | Initialize(new State()); | 
|  | } | 
|  | ~TestReceiveCallback() {} | 
|  |  | 
|  | const NetworkErrorPtr& result() const { return result_; } | 
|  | const NetAddressPtr& src_addr() const { return src_addr_; } | 
|  | const Array<uint8_t>& data() const { return data_; } | 
|  |  | 
|  | private: | 
|  | struct State : public StateBase { | 
|  | ~State() override {} | 
|  |  | 
|  | void Run(NetworkErrorPtr result, | 
|  | NetAddressPtr src_addr, | 
|  | Array<uint8_t> data) const override { | 
|  | if (test_callback_) { | 
|  | TestReceiveCallback* callback = | 
|  | static_cast<TestReceiveCallback*>(test_callback_); | 
|  | callback->result_ = result.Pass(); | 
|  | callback->src_addr_ = src_addr.Pass(); | 
|  | callback->data_ = data.Pass(); | 
|  | } | 
|  | NotifyRun(); | 
|  | } | 
|  | }; | 
|  |  | 
|  | NetworkErrorPtr result_; | 
|  | NetAddressPtr src_addr_; | 
|  | Array<uint8_t> data_; | 
|  | }; | 
|  |  | 
|  | class UDPSocketAppTest : public test::ApplicationTestBase { | 
|  | public: | 
|  | UDPSocketAppTest() {} | 
|  | ~UDPSocketAppTest() override {} | 
|  |  | 
|  | void SetUp() override { | 
|  | ApplicationTestBase::SetUp(); | 
|  |  | 
|  | ApplicationConnection* connection = | 
|  | application_impl()->ConnectToApplication("mojo:network_service"); | 
|  | connection->ConnectToService(&network_service_); | 
|  |  | 
|  | network_service_->CreateUDPSocket(GetProxy(&udp_socket_)); | 
|  | udp_socket_.set_client(&udp_socket_client_); | 
|  | } | 
|  |  | 
|  | protected: | 
|  | struct ReceiveResult { | 
|  | NetworkErrorPtr result; | 
|  | NetAddressPtr addr; | 
|  | Array<uint8_t> data; | 
|  | }; | 
|  |  | 
|  | class UDPSocketClientImpl : public UDPSocketClient { | 
|  | public: | 
|  |  | 
|  | UDPSocketClientImpl() : run_loop_(nullptr), expected_receive_count_(0) {} | 
|  |  | 
|  | ~UDPSocketClientImpl() override { | 
|  | while (!results_.empty()) { | 
|  | delete results_.front(); | 
|  | results_.pop(); | 
|  | } | 
|  | } | 
|  |  | 
|  | void OnReceived(NetworkErrorPtr result, | 
|  | NetAddressPtr src_addr, | 
|  | Array<uint8_t> data) override { | 
|  | ReceiveResult* entry = new ReceiveResult(); | 
|  | entry->result = result.Pass(); | 
|  | entry->addr = src_addr.Pass(); | 
|  | entry->data = data.Pass(); | 
|  |  | 
|  | results_.push(entry); | 
|  |  | 
|  | if (results_.size() == expected_receive_count_ && run_loop_) { | 
|  | expected_receive_count_ = 0; | 
|  | run_loop_->Quit(); | 
|  | } | 
|  | } | 
|  |  | 
|  | base::RunLoop* run_loop_; | 
|  | std::queue<ReceiveResult*> results_; | 
|  | size_t expected_receive_count_; | 
|  |  | 
|  | DISALLOW_COPY_AND_ASSIGN(UDPSocketClientImpl); | 
|  | }; | 
|  |  | 
|  | std::queue<ReceiveResult*>* GetReceiveResults() { | 
|  | return &udp_socket_client_.results_; | 
|  | } | 
|  |  | 
|  | void WaitForReceiveResults(size_t count) { | 
|  | if (GetReceiveResults()->size() == count) | 
|  | return; | 
|  |  | 
|  | udp_socket_client_.expected_receive_count_ = count; | 
|  | base::RunLoop run_loop; | 
|  | udp_socket_client_.run_loop_ = &run_loop; | 
|  | run_loop.Run(); | 
|  | udp_socket_client_.run_loop_ = nullptr; | 
|  | } | 
|  |  | 
|  | NetworkServicePtr network_service_; | 
|  | UDPSocketPtr udp_socket_; | 
|  | UDPSocketClientImpl udp_socket_client_; | 
|  |  | 
|  | DISALLOW_COPY_AND_ASSIGN(UDPSocketAppTest); | 
|  | }; | 
|  |  | 
|  | }  // namespace | 
|  |  | 
|  | TEST_F(UDPSocketAppTest, Settings) { | 
|  | TestCallback callback1; | 
|  | udp_socket_->AllowAddressReuse(callback1.callback()); | 
|  | callback1.WaitForResult(); | 
|  | EXPECT_EQ(net::OK, callback1.result()->code); | 
|  |  | 
|  | // Should fail because the socket hasn't been bound. | 
|  | TestCallback callback2; | 
|  | udp_socket_->SetSendBufferSize(1024, callback2.callback()); | 
|  | callback2.WaitForResult(); | 
|  | EXPECT_NE(net::OK, callback2.result()->code); | 
|  |  | 
|  | // Should fail because the socket hasn't been bound. | 
|  | TestCallback callback3; | 
|  | udp_socket_->SetReceiveBufferSize(2048, callback3.callback()); | 
|  | callback3.WaitForResult(); | 
|  | EXPECT_NE(net::OK, callback3.result()->code); | 
|  |  | 
|  | TestCallbackWithAddress callback4; | 
|  | udp_socket_->Bind(GetLocalHostWithAnyPort(), callback4.callback()); | 
|  | callback4.WaitForResult(); | 
|  | EXPECT_EQ(net::OK, callback4.result()->code); | 
|  | EXPECT_NE(0u, callback4.net_address()->ipv4->port); | 
|  |  | 
|  | // Should fail because the socket has been bound. | 
|  | TestCallback callback5; | 
|  | udp_socket_->AllowAddressReuse(callback5.callback()); | 
|  | callback5.WaitForResult(); | 
|  | EXPECT_NE(net::OK, callback5.result()->code); | 
|  |  | 
|  | TestCallback callback6; | 
|  | udp_socket_->SetSendBufferSize(1024, callback6.callback()); | 
|  | callback6.WaitForResult(); | 
|  | EXPECT_EQ(net::OK, callback6.result()->code); | 
|  |  | 
|  | TestCallback callback7; | 
|  | udp_socket_->SetReceiveBufferSize(2048, callback7.callback()); | 
|  | callback7.WaitForResult(); | 
|  | EXPECT_EQ(net::OK, callback7.result()->code); | 
|  |  | 
|  | TestCallbackWithUint32 callback8; | 
|  | udp_socket_->NegotiateMaxPendingSendRequests(0, callback8.callback()); | 
|  | callback8.WaitForResult(); | 
|  | EXPECT_GT(callback8.result(), 0u); | 
|  |  | 
|  | TestCallbackWithUint32 callback9; | 
|  | udp_socket_->NegotiateMaxPendingSendRequests(16, callback9.callback()); | 
|  | callback9.WaitForResult(); | 
|  | EXPECT_GT(callback9.result(), 0u); | 
|  | } | 
|  |  | 
|  | TEST_F(UDPSocketAppTest, TestReadWrite) { | 
|  | TestCallbackWithAddress callback1; | 
|  | udp_socket_->Bind(GetLocalHostWithAnyPort(), callback1.callback()); | 
|  | callback1.WaitForResult(); | 
|  | ASSERT_EQ(net::OK, callback1.result()->code); | 
|  | ASSERT_NE(0u, callback1.net_address()->ipv4->port); | 
|  |  | 
|  | NetAddressPtr server_addr = callback1.net_address().Clone(); | 
|  |  | 
|  | UDPSocketPtr client_socket; | 
|  | network_service_->CreateUDPSocket(GetProxy(&client_socket)); | 
|  |  | 
|  | TestCallbackWithAddress callback2; | 
|  | client_socket->Bind(GetLocalHostWithAnyPort(), callback2.callback()); | 
|  | callback2.WaitForResult(); | 
|  | ASSERT_EQ(net::OK, callback2.result()->code); | 
|  | ASSERT_NE(0u, callback2.net_address()->ipv4->port); | 
|  |  | 
|  | NetAddressPtr client_addr = callback2.net_address().Clone(); | 
|  |  | 
|  | const size_t kDatagramCount = 6; | 
|  | const size_t kDatagramSize = 255; | 
|  | udp_socket_->ReceiveMore(kDatagramCount); | 
|  |  | 
|  | for (size_t i = 0; i < kDatagramCount; ++i) { | 
|  | TestCallback callback; | 
|  | client_socket->SendTo( | 
|  | server_addr.Clone(), | 
|  | CreateTestMessage(static_cast<uint8_t>(i), kDatagramSize), | 
|  | callback.callback()); | 
|  | callback.WaitForResult(); | 
|  | EXPECT_EQ(255, callback.result()->code); | 
|  | } | 
|  |  | 
|  | WaitForReceiveResults(kDatagramCount); | 
|  | for (size_t i = 0; i < kDatagramCount; ++i) { | 
|  | scoped_ptr<ReceiveResult> result(GetReceiveResults()->front()); | 
|  | GetReceiveResults()->pop(); | 
|  |  | 
|  | EXPECT_EQ(static_cast<int>(kDatagramSize), result->result->code); | 
|  | EXPECT_TRUE(result->addr.Equals(client_addr)); | 
|  | EXPECT_TRUE(result->data.Equals( | 
|  | CreateTestMessage(static_cast<uint8_t>(i), kDatagramSize))); | 
|  | } | 
|  | } | 
|  |  | 
|  | TEST_F(UDPSocketAppTest, TestUDPSocketWrapper) { | 
|  | UDPSocketWrapper udp_socket(udp_socket_.Pass(), 4, 4); | 
|  |  | 
|  | TestCallbackWithAddress callback1; | 
|  | udp_socket.Bind(GetLocalHostWithAnyPort(), callback1.callback()); | 
|  | callback1.WaitForResult(); | 
|  | ASSERT_EQ(net::OK, callback1.result()->code); | 
|  | ASSERT_NE(0u, callback1.net_address()->ipv4->port); | 
|  |  | 
|  | NetAddressPtr server_addr = callback1.net_address().Clone(); | 
|  |  | 
|  | UDPSocketPtr raw_client_socket; | 
|  | network_service_->CreateUDPSocket(GetProxy(&raw_client_socket)); | 
|  | UDPSocketWrapper client_socket(raw_client_socket.Pass(), 4, 4); | 
|  |  | 
|  | TestCallbackWithAddress callback2; | 
|  | client_socket.Bind(GetLocalHostWithAnyPort(), callback2.callback()); | 
|  | callback2.WaitForResult(); | 
|  | ASSERT_EQ(net::OK, callback2.result()->code); | 
|  | ASSERT_NE(0u, callback2.net_address()->ipv4->port); | 
|  |  | 
|  | NetAddressPtr client_addr = callback2.net_address().Clone(); | 
|  |  | 
|  | const size_t kDatagramCount = 16; | 
|  | const size_t kDatagramSize = 255; | 
|  |  | 
|  | for (size_t i = 1; i < kDatagramCount; ++i) { | 
|  | scoped_ptr<TestCallback[]> send_callbacks(new TestCallback[i]); | 
|  | scoped_ptr<TestReceiveCallback[]> receive_callbacks( | 
|  | new TestReceiveCallback[i]); | 
|  |  | 
|  | for (size_t j = 0; j < i; ++j) { | 
|  | client_socket.SendTo( | 
|  | server_addr.Clone(), | 
|  | CreateTestMessage(static_cast<uint8_t>(j), kDatagramSize), | 
|  | send_callbacks[j].callback()); | 
|  |  | 
|  | udp_socket.ReceiveFrom(receive_callbacks[j].callback()); | 
|  | } | 
|  |  | 
|  | receive_callbacks[i - 1].WaitForResult(); | 
|  |  | 
|  | for (size_t j = 0; j < i; ++j) { | 
|  | EXPECT_EQ(static_cast<int>(kDatagramSize), | 
|  | receive_callbacks[j].result()->code); | 
|  | EXPECT_TRUE(receive_callbacks[j].src_addr().Equals(client_addr)); | 
|  | EXPECT_TRUE(receive_callbacks[j].data().Equals( | 
|  | CreateTestMessage(static_cast<uint8_t>(j), kDatagramSize))); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | }  // namespace service | 
|  | }  // namespace mojo |