James Robinson | 646469d | 2014-10-03 15:33:28 -0700 | [diff] [blame] | 1 | // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style license that can be |
| 3 | // found in the LICENSE file. |
| 4 | |
| 5 | #include "tools/android/forwarder2/socket.h" |
| 6 | |
| 7 | #include <arpa/inet.h> |
| 8 | #include <fcntl.h> |
| 9 | #include <netdb.h> |
| 10 | #include <netinet/in.h> |
| 11 | #include <stdio.h> |
| 12 | #include <string.h> |
| 13 | #include <sys/socket.h> |
| 14 | #include <sys/types.h> |
| 15 | #include <unistd.h> |
| 16 | |
| 17 | #include "base/logging.h" |
| 18 | #include "base/posix/eintr_wrapper.h" |
| 19 | #include "base/safe_strerror_posix.h" |
| 20 | #include "tools/android/common/net.h" |
| 21 | #include "tools/android/forwarder2/common.h" |
| 22 | |
namespace {

// Passed to WaitForEvent() to request an unbounded wait.
const int kNoTimeout = -1;
// Upper bound on how long Connect() waits for a non-blocking connect().
const int kConnectTimeOut = 10;  // Seconds.

// Returns true when |family| is one of the TCP/IP families handled by the
// TCP code paths (AF_INET or AF_INET6); false for anything else.
bool FamilyIsTCP(int family) {
  switch (family) {
    case AF_INET:
    case AF_INET6:
      return true;
    default:
      return false;
  }
}

}  // namespace
| 31 | |
| 32 | namespace forwarder2 { |
| 33 | |
| 34 | bool Socket::BindUnix(const std::string& path) { |
| 35 | errno = 0; |
| 36 | if (!InitUnixSocket(path) || !BindAndListen()) { |
| 37 | Close(); |
| 38 | return false; |
| 39 | } |
| 40 | return true; |
| 41 | } |
| 42 | |
| 43 | bool Socket::BindTcp(const std::string& host, int port) { |
| 44 | errno = 0; |
| 45 | if (!InitTcpSocket(host, port) || !BindAndListen()) { |
| 46 | Close(); |
| 47 | return false; |
| 48 | } |
| 49 | return true; |
| 50 | } |
| 51 | |
| 52 | bool Socket::ConnectUnix(const std::string& path) { |
| 53 | errno = 0; |
| 54 | if (!InitUnixSocket(path) || !Connect()) { |
| 55 | Close(); |
| 56 | return false; |
| 57 | } |
| 58 | return true; |
| 59 | } |
| 60 | |
| 61 | bool Socket::ConnectTcp(const std::string& host, int port) { |
| 62 | errno = 0; |
| 63 | if (!InitTcpSocket(host, port) || !Connect()) { |
| 64 | Close(); |
| 65 | return false; |
| 66 | } |
| 67 | return true; |
| 68 | } |
| 69 | |
| 70 | Socket::Socket() |
| 71 | : socket_(-1), |
| 72 | port_(0), |
| 73 | socket_error_(false), |
| 74 | family_(AF_INET), |
| 75 | addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)), |
| 76 | addr_len_(sizeof(sockaddr)) { |
| 77 | memset(&addr_, 0, sizeof(addr_)); |
| 78 | } |
| 79 | |
| 80 | Socket::~Socket() { |
| 81 | Close(); |
| 82 | } |
| 83 | |
| 84 | void Socket::Shutdown() { |
| 85 | if (!IsClosed()) { |
| 86 | PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR)); |
| 87 | } |
| 88 | } |
| 89 | |
| 90 | void Socket::Close() { |
| 91 | if (!IsClosed()) { |
| 92 | CloseFD(socket_); |
| 93 | socket_ = -1; |
| 94 | } |
| 95 | } |
| 96 | |
| 97 | bool Socket::InitSocketInternal() { |
| 98 | socket_ = socket(family_, SOCK_STREAM, 0); |
| 99 | if (socket_ < 0) { |
| 100 | PLOG(ERROR) << "socket"; |
| 101 | return false; |
| 102 | } |
| 103 | tools::DisableNagle(socket_); |
| 104 | int reuse_addr = 1; |
| 105 | setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr, |
| 106 | sizeof(reuse_addr)); |
| 107 | if (!SetNonBlocking()) |
| 108 | return false; |
| 109 | return true; |
| 110 | } |
| 111 | |
| 112 | bool Socket::SetNonBlocking() { |
| 113 | const int flags = fcntl(socket_, F_GETFL); |
| 114 | if (flags < 0) { |
| 115 | PLOG(ERROR) << "fcntl"; |
| 116 | return false; |
| 117 | } |
| 118 | if (flags & O_NONBLOCK) |
| 119 | return true; |
| 120 | if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) { |
| 121 | PLOG(ERROR) << "fcntl"; |
| 122 | return false; |
| 123 | } |
| 124 | return true; |
| 125 | } |
| 126 | |
| 127 | bool Socket::InitUnixSocket(const std::string& path) { |
| 128 | static const size_t kPathMax = sizeof(addr_.addr_un.sun_path); |
| 129 | // For abstract sockets we need one extra byte for the leading zero. |
| 130 | if (path.size() + 2 /* '\0' */ > kPathMax) { |
| 131 | LOG(ERROR) << "The provided path is too big to create a unix " |
| 132 | << "domain socket: " << path; |
| 133 | return false; |
| 134 | } |
| 135 | family_ = PF_UNIX; |
| 136 | addr_.addr_un.sun_family = family_; |
| 137 | // Copied from net/socket/unix_domain_socket_posix.cc |
| 138 | // Convert the path given into abstract socket name. It must start with |
| 139 | // the '\0' character, so we are adding it. |addr_len| must specify the |
| 140 | // length of the structure exactly, as potentially the socket name may |
| 141 | // have '\0' characters embedded (although we don't support this). |
| 142 | // Note that addr_.addr_un.sun_path is already zero initialized. |
| 143 | memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size()); |
| 144 | addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1; |
| 145 | addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un); |
| 146 | return InitSocketInternal(); |
| 147 | } |
| 148 | |
| 149 | bool Socket::InitTcpSocket(const std::string& host, int port) { |
| 150 | port_ = port; |
| 151 | if (host.empty()) { |
| 152 | // Use localhost: INADDR_LOOPBACK |
| 153 | family_ = AF_INET; |
| 154 | addr_.addr4.sin_family = family_; |
| 155 | addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); |
| 156 | } else if (!Resolve(host)) { |
| 157 | return false; |
| 158 | } |
| 159 | CHECK(FamilyIsTCP(family_)) << "Invalid socket family."; |
| 160 | if (family_ == AF_INET) { |
| 161 | addr_.addr4.sin_port = htons(port_); |
| 162 | addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4); |
| 163 | addr_len_ = sizeof(addr_.addr4); |
| 164 | } else if (family_ == AF_INET6) { |
| 165 | addr_.addr6.sin6_port = htons(port_); |
| 166 | addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6); |
| 167 | addr_len_ = sizeof(addr_.addr6); |
| 168 | } |
| 169 | return InitSocketInternal(); |
| 170 | } |
| 171 | |
| 172 | bool Socket::BindAndListen() { |
| 173 | errno = 0; |
| 174 | if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 || |
| 175 | HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) { |
| 176 | PLOG(ERROR) << "bind/listen"; |
| 177 | SetSocketError(); |
| 178 | return false; |
| 179 | } |
| 180 | if (port_ == 0 && FamilyIsTCP(family_)) { |
| 181 | SockAddr addr; |
| 182 | memset(&addr, 0, sizeof(addr)); |
| 183 | socklen_t addrlen = 0; |
| 184 | sockaddr* addr_ptr = NULL; |
| 185 | uint16* port_ptr = NULL; |
| 186 | if (family_ == AF_INET) { |
| 187 | addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4); |
| 188 | port_ptr = &addr.addr4.sin_port; |
| 189 | addrlen = sizeof(addr.addr4); |
| 190 | } else if (family_ == AF_INET6) { |
| 191 | addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6); |
| 192 | port_ptr = &addr.addr6.sin6_port; |
| 193 | addrlen = sizeof(addr.addr6); |
| 194 | } |
| 195 | errno = 0; |
| 196 | if (getsockname(socket_, addr_ptr, &addrlen) != 0) { |
| 197 | PLOG(ERROR) << "getsockname"; |
| 198 | SetSocketError(); |
| 199 | return false; |
| 200 | } |
| 201 | port_ = ntohs(*port_ptr); |
| 202 | } |
| 203 | return true; |
| 204 | } |
| 205 | |
| 206 | bool Socket::Accept(Socket* new_socket) { |
| 207 | DCHECK(new_socket != NULL); |
| 208 | if (!WaitForEvent(READ, kNoTimeout)) { |
| 209 | SetSocketError(); |
| 210 | return false; |
| 211 | } |
| 212 | errno = 0; |
| 213 | int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL)); |
| 214 | if (new_socket_fd < 0) { |
| 215 | SetSocketError(); |
| 216 | return false; |
| 217 | } |
| 218 | tools::DisableNagle(new_socket_fd); |
| 219 | new_socket->socket_ = new_socket_fd; |
| 220 | if (!new_socket->SetNonBlocking()) |
| 221 | return false; |
| 222 | return true; |
| 223 | } |
| 224 | |
| 225 | bool Socket::Connect() { |
| 226 | DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK); |
| 227 | errno = 0; |
| 228 | if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 && |
| 229 | errno != EINPROGRESS) { |
| 230 | SetSocketError(); |
| 231 | return false; |
| 232 | } |
| 233 | // Wait for connection to complete, or receive a notification. |
| 234 | if (!WaitForEvent(WRITE, kConnectTimeOut)) { |
| 235 | SetSocketError(); |
| 236 | return false; |
| 237 | } |
| 238 | int socket_errno; |
| 239 | socklen_t opt_len = sizeof(socket_errno); |
| 240 | if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) { |
| 241 | PLOG(ERROR) << "getsockopt()"; |
| 242 | SetSocketError(); |
| 243 | return false; |
| 244 | } |
| 245 | if (socket_errno != 0) { |
| 246 | LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno); |
| 247 | SetSocketError(); |
| 248 | return false; |
| 249 | } |
| 250 | return true; |
| 251 | } |
| 252 | |
// Resolves |host| with getaddrinfo() and copies the first returned address
// into |addr_|, setting |family_| to that address's family. The port is left
// for the caller to fill in.
// Returns false (after logging and calling SetSocketError()) if resolution
// fails.
bool Socket::Resolve(const std::string& host) {
  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags |= AI_CANONNAME;

  struct addrinfo* res = NULL;
  const int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
  if (errcode != 0) {
    // getaddrinfo() reports failure through its return value, not errno.
    // |res| is not set on failure and must NOT be passed to freeaddrinfo().
    LOG(ERROR) << "getaddrinfo(" << host << "): " << gai_strerror(errcode);
    errno = 0;
    SetSocketError();
    return false;
  }

  family_ = res->ai_family;
  switch (res->ai_family) {
    case AF_INET:
      memcpy(&addr_.addr4,
             reinterpret_cast<sockaddr_in*>(res->ai_addr),
             sizeof(sockaddr_in));
      break;
    case AF_INET6:
      memcpy(&addr_.addr6,
             reinterpret_cast<sockaddr_in6*>(res->ai_addr),
             sizeof(sockaddr_in6));
      break;
  }
  freeaddrinfo(res);
  return true;
}
| 284 | |
| 285 | int Socket::GetPort() { |
| 286 | if (!FamilyIsTCP(family_)) { |
| 287 | LOG(ERROR) << "Can't call GetPort() on an unix domain socket."; |
| 288 | return 0; |
| 289 | } |
| 290 | return port_; |
| 291 | } |
| 292 | |
| 293 | int Socket::ReadNumBytes(void* buffer, size_t num_bytes) { |
| 294 | int bytes_read = 0; |
| 295 | int ret = 1; |
| 296 | while (bytes_read < num_bytes && ret > 0) { |
| 297 | ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read); |
| 298 | if (ret >= 0) |
| 299 | bytes_read += ret; |
| 300 | } |
| 301 | return bytes_read; |
| 302 | } |
| 303 | |
| 304 | void Socket::SetSocketError() { |
| 305 | socket_error_ = true; |
| 306 | DCHECK_NE(EAGAIN, errno); |
| 307 | DCHECK_NE(EWOULDBLOCK, errno); |
| 308 | Close(); |
| 309 | } |
| 310 | |
| 311 | int Socket::Read(void* buffer, size_t buffer_size) { |
| 312 | if (!WaitForEvent(READ, kNoTimeout)) { |
| 313 | SetSocketError(); |
| 314 | return 0; |
| 315 | } |
| 316 | int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size)); |
| 317 | if (ret < 0) { |
| 318 | PLOG(ERROR) << "read"; |
| 319 | SetSocketError(); |
| 320 | } |
| 321 | return ret; |
| 322 | } |
| 323 | |
| 324 | int Socket::NonBlockingRead(void* buffer, size_t buffer_size) { |
| 325 | DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK); |
| 326 | int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size)); |
| 327 | if (ret < 0) { |
| 328 | PLOG(ERROR) << "read"; |
| 329 | SetSocketError(); |
| 330 | } |
| 331 | return ret; |
| 332 | } |
| 333 | |
| 334 | int Socket::Write(const void* buffer, size_t count) { |
| 335 | if (!WaitForEvent(WRITE, kNoTimeout)) { |
| 336 | SetSocketError(); |
| 337 | return 0; |
| 338 | } |
| 339 | int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL)); |
| 340 | if (ret < 0) { |
| 341 | PLOG(ERROR) << "send"; |
| 342 | SetSocketError(); |
| 343 | } |
| 344 | return ret; |
| 345 | } |
| 346 | |
| 347 | int Socket::NonBlockingWrite(const void* buffer, size_t count) { |
| 348 | DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK); |
| 349 | int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL)); |
| 350 | if (ret < 0) { |
| 351 | PLOG(ERROR) << "send"; |
| 352 | SetSocketError(); |
| 353 | } |
| 354 | return ret; |
| 355 | } |
| 356 | |
| 357 | int Socket::WriteString(const std::string& buffer) { |
| 358 | return WriteNumBytes(buffer.c_str(), buffer.size()); |
| 359 | } |
| 360 | |
| 361 | void Socket::AddEventFd(int event_fd) { |
| 362 | Event event; |
| 363 | event.fd = event_fd; |
| 364 | event.was_fired = false; |
| 365 | events_.push_back(event); |
| 366 | } |
| 367 | |
| 368 | bool Socket::DidReceiveEventOnFd(int fd) const { |
| 369 | for (size_t i = 0; i < events_.size(); ++i) |
| 370 | if (events_[i].fd == fd) |
| 371 | return events_[i].was_fired; |
| 372 | return false; |
| 373 | } |
| 374 | |
| 375 | bool Socket::DidReceiveEvent() const { |
| 376 | for (size_t i = 0; i < events_.size(); ++i) |
| 377 | if (events_[i].was_fired) |
| 378 | return true; |
| 379 | return false; |
| 380 | } |
| 381 | |
| 382 | int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) { |
| 383 | int bytes_written = 0; |
| 384 | int ret = 1; |
| 385 | while (bytes_written < num_bytes && ret > 0) { |
| 386 | ret = Write(static_cast<const char*>(buffer) + bytes_written, |
| 387 | num_bytes - bytes_written); |
| 388 | if (ret >= 0) |
| 389 | bytes_written += ret; |
| 390 | } |
| 391 | return bytes_written; |
| 392 | } |
| 393 | |
| 394 | bool Socket::WaitForEvent(EventType type, int timeout_secs) { |
| 395 | if (socket_ == -1) |
| 396 | return true; |
| 397 | DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK); |
| 398 | fd_set read_fds; |
| 399 | fd_set write_fds; |
| 400 | FD_ZERO(&read_fds); |
| 401 | FD_ZERO(&write_fds); |
| 402 | if (type == READ) |
| 403 | FD_SET(socket_, &read_fds); |
| 404 | else |
| 405 | FD_SET(socket_, &write_fds); |
| 406 | for (size_t i = 0; i < events_.size(); ++i) |
| 407 | FD_SET(events_[i].fd, &read_fds); |
| 408 | timeval tv = {}; |
| 409 | timeval* tv_ptr = NULL; |
| 410 | if (timeout_secs > 0) { |
| 411 | tv.tv_sec = timeout_secs; |
| 412 | tv.tv_usec = 0; |
| 413 | tv_ptr = &tv; |
| 414 | } |
| 415 | int max_fd = socket_; |
| 416 | for (size_t i = 0; i < events_.size(); ++i) |
| 417 | if (events_[i].fd > max_fd) |
| 418 | max_fd = events_[i].fd; |
| 419 | if (HANDLE_EINTR( |
| 420 | select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) { |
| 421 | PLOG(ERROR) << "select"; |
| 422 | return false; |
| 423 | } |
| 424 | bool event_was_fired = false; |
| 425 | for (size_t i = 0; i < events_.size(); ++i) { |
| 426 | if (FD_ISSET(events_[i].fd, &read_fds)) { |
| 427 | events_[i].was_fired = true; |
| 428 | event_was_fired = true; |
| 429 | } |
| 430 | } |
| 431 | return !event_was_fired; |
| 432 | } |
| 433 | |
| 434 | // static |
| 435 | pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) { |
| 436 | Socket socket; |
| 437 | if (!socket.ConnectUnix(path)) |
| 438 | return -1; |
| 439 | ucred ucred; |
| 440 | socklen_t len = sizeof(ucred); |
| 441 | if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) { |
| 442 | CHECK_NE(ENOPROTOOPT, errno); |
| 443 | return -1; |
| 444 | } |
| 445 | return ucred.pid; |
| 446 | } |
| 447 | |
| 448 | } // namespace forwarder2 |