OpenRCT2/src/openrct2/network/TcpSocket.cpp

549 lines
15 KiB
C++
Raw Normal View History

/*****************************************************************************
* Copyright (c) 2014-2018 OpenRCT2 developers
*
* For a complete list of all authors, please refer to contributors.md
* Interested in contributing? Visit https://github.com/OpenRCT2/OpenRCT2
*
* OpenRCT2 is licensed under the GNU General Public License version 3.
*****************************************************************************/
#ifndef DISABLE_NETWORK
2018-07-21 16:17:06 +02:00
# include <chrono>
# include <cmath>
# include <cstring>
# include <future>
# include <string>
# include <thread>
// clang-format off
2017-12-03 23:45:43 +01:00
// MSVC: include <math.h> here otherwise PI gets defined twice
#include <cmath>
#ifdef _WIN32
// winsock2 must be included before windows.h
#include <winsock2.h>
#include <ws2tcpip.h>
#define LAST_SOCKET_ERROR() WSAGetLastError()
#undef EWOULDBLOCK
#define EWOULDBLOCK WSAEWOULDBLOCK
#ifndef SHUT_RD
#define SHUT_RD SD_RECEIVE
#endif
#ifndef SHUT_RDWR
#define SHUT_RDWR SD_BOTH
#endif
#define FLAG_NO_PIPE 0
#else
2018-01-04 00:36:33 +01:00
#include <cerrno>
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <fcntl.h>
2017-01-05 23:43:19 +01:00
#include "../common.h"
using SOCKET = int32_t;
#define SOCKET_ERROR -1
#define INVALID_SOCKET -1
#define LAST_SOCKET_ERROR() errno
#define closesocket close
#define ioctlsocket ioctl
2017-06-12 19:00:15 +02:00
#if defined(__linux__)
#define FLAG_NO_PIPE MSG_NOSIGNAL
#else
#define FLAG_NO_PIPE 0
#endif // defined(__linux__)
#endif // _WIN32
// clang-format on
2018-07-21 16:17:06 +02:00
# include "TcpSocket.h"
// Upper bound on how long TcpSocket::Connect() keeps polling a pending non-blocking connect
// before giving up with "Connection timed out.".
constexpr auto CONNECT_TIMEOUT = std::chrono::milliseconds(3000);
2018-07-21 16:17:06 +02:00
# ifdef _WIN32
2018-06-22 23:02:47 +02:00
// Whether WSAStartup() has succeeded; set by InitialiseWSA() and cleared by DisposeWSA().
static bool _wsaInitialised = false;
2018-07-21 16:17:06 +02:00
# endif
2017-02-08 13:53:00 +01:00
class TcpSocket;

/**
 * Exception thrown for socket-level failures such as address resolution,
 * socket creation, bind/listen, and connect errors.
 * Derives from std::runtime_error so callers can catch it as a std::exception.
 */
class SocketException : public std::runtime_error
{
public:
    explicit SocketException(const std::string& message)
        : runtime_error{ message }
    {
    }
};
class TcpSocket final : public ITcpSocket
{
private:
2018-06-22 23:02:47 +02:00
SOCKET_STATUS _status = SOCKET_STATUS_CLOSED;
uint16_t _listeningPort = 0;
SOCKET _socket = INVALID_SOCKET;
2018-06-22 23:02:47 +02:00
std::string _hostName;
std::future<void> _connectFuture;
std::string _error;
public:
TcpSocket() = default;
~TcpSocket() override
{
if (_connectFuture.valid())
{
_connectFuture.wait();
}
CloseSocket();
}
SOCKET_STATUS GetStatus() override
{
return _status;
}
2018-06-22 23:02:47 +02:00
const char* GetError() override
{
return _error.empty() ? nullptr : _error.c_str();
}
void Listen(uint16_t port) override
{
Listen(nullptr, port);
}
2018-06-22 23:02:47 +02:00
void Listen(const char* address, uint16_t port) override
{
if (_status != SOCKET_STATUS_CLOSED)
{
2018-01-02 20:23:22 +01:00
throw std::runtime_error("Socket not closed.");
}
2018-06-22 23:02:47 +02:00
sockaddr_storage ss{};
int32_t ss_len;
if (!ResolveAddress(address, port, &ss, &ss_len))
{
throw SocketException("Unable to resolve address.");
}
// Create the listening socket
_socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP);
if (_socket == INVALID_SOCKET)
{
throw SocketException("Unable to create socket.");
}
// Turn off IPV6_V6ONLY so we can accept both v4 and v6 connections
int32_t value = 0;
if (setsockopt(_socket, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&value, sizeof(value)) != 0)
{
log_error("IPV6_V6ONLY failed. %d", LAST_SOCKET_ERROR());
}
value = 1;
if (setsockopt(_socket, SOL_SOCKET, SO_REUSEADDR, (const char*)&value, sizeof(value)) != 0)
{
log_error("SO_REUSEADDR failed. %d", LAST_SOCKET_ERROR());
}
try
{
// Bind to address:port and listen
2018-06-22 23:02:47 +02:00
if (bind(_socket, (sockaddr*)&ss, ss_len) != 0)
{
throw SocketException("Unable to bind to socket.");
}
if (listen(_socket, SOMAXCONN) != 0)
{
throw SocketException("Unable to listen on socket.");
}
if (!SetNonBlocking(_socket, true))
{
throw SocketException("Failed to set non-blocking mode.");
}
}
2018-06-22 23:02:47 +02:00
catch (const std::exception&)
{
CloseSocket();
throw;
}
_listeningPort = port;
_status = SOCKET_STATUS_LISTENING;
}
2018-12-17 12:58:12 +01:00
std::unique_ptr<ITcpSocket> Accept() override
{
if (_status != SOCKET_STATUS_LISTENING)
{
2018-01-02 20:23:22 +01:00
throw std::runtime_error("Socket not listening.");
}
2018-06-22 23:02:47 +02:00
struct sockaddr_storage client_addr
{
};
2016-10-12 15:44:20 +02:00
socklen_t client_len = sizeof(struct sockaddr_storage);
2018-12-17 12:58:12 +01:00
std::unique_ptr<ITcpSocket> tcpSocket;
2018-06-22 23:02:47 +02:00
SOCKET socket = accept(_socket, (struct sockaddr*)&client_addr, &client_len);
if (socket == INVALID_SOCKET)
{
if (LAST_SOCKET_ERROR() != EWOULDBLOCK)
{
log_error("Failed to accept client.");
}
}
else
{
if (!SetNonBlocking(socket, true))
{
closesocket(socket);
log_error("Failed to set non-blocking mode.");
}
else
{
char hostName[NI_MAXHOST];
int32_t rc = getnameinfo(
(struct sockaddr*)&client_addr, client_len, hostName, sizeof(hostName), nullptr, 0,
2016-11-13 20:17:49 +01:00
NI_NUMERICHOST | NI_NUMERICSERV);
SetTCPNoDelay(socket, true);
2016-10-12 23:54:13 +02:00
if (rc == 0)
{
2018-12-17 12:58:12 +01:00
tcpSocket = std::unique_ptr<ITcpSocket>(new TcpSocket(socket, hostName));
}
else
{
2018-12-17 12:58:12 +01:00
tcpSocket = std::unique_ptr<ITcpSocket>(new TcpSocket(socket, ""));
}
}
}
return tcpSocket;
}
2018-06-22 23:02:47 +02:00
void Connect(const char* address, uint16_t port) override
{
if (_status != SOCKET_STATUS_CLOSED)
{
2018-01-02 20:23:22 +01:00
throw std::runtime_error("Socket not closed.");
}
2016-11-13 20:17:49 +01:00
try
{
// Resolve address
_status = SOCKET_STATUS_RESOLVING;
2018-06-22 23:02:47 +02:00
sockaddr_storage ss{};
int32_t ss_len;
if (!ResolveAddress(address, port, &ss, &ss_len))
{
throw SocketException("Unable to resolve address.");
}
_status = SOCKET_STATUS_CONNECTING;
_socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP);
if (_socket == INVALID_SOCKET)
{
throw SocketException("Unable to create socket.");
}
SetTCPNoDelay(_socket, true);
if (!SetNonBlocking(_socket, true))
{
throw SocketException("Failed to set non-blocking mode.");
}
// Connect
2018-06-22 23:02:47 +02:00
int32_t connectResult = connect(_socket, (sockaddr*)&ss, ss_len);
if (connectResult != SOCKET_ERROR || (LAST_SOCKET_ERROR() != EINPROGRESS && LAST_SOCKET_ERROR() != EWOULDBLOCK))
{
throw SocketException("Failed to connect.");
}
auto connectStartTime = std::chrono::system_clock::now();
int32_t error = 0;
socklen_t len = sizeof(error);
2018-06-22 23:02:47 +02:00
if (getsockopt(_socket, SOL_SOCKET, SO_ERROR, (char*)&error, &len) != 0)
{
throw SocketException("getsockopt failed with error: " + std::to_string(LAST_SOCKET_ERROR()));
}
if (error != 0)
{
throw SocketException("Connection failed: " + std::to_string(error));
}
do
{
// Sleep for a bit
std::this_thread::sleep_for(std::chrono::milliseconds(100));
fd_set writeFD;
FD_ZERO(&writeFD);
2018-07-21 16:17:06 +02:00
# pragma warning(push)
# pragma warning(disable : 4548) // expression before comma has no effect; expected expression with side-effect
FD_SET(_socket, &writeFD);
2018-07-21 16:17:06 +02:00
# pragma warning(pop)
2018-06-22 23:02:47 +02:00
timeval timeout{};
timeout.tv_sec = 0;
timeout.tv_usec = 0;
if (select((int32_t)(_socket + 1), nullptr, &writeFD, nullptr, &timeout) > 0)
{
error = 0;
2017-01-12 18:36:05 +01:00
len = sizeof(error);
if (getsockopt(_socket, SOL_SOCKET, SO_ERROR, (char*)&error, &len) != 0)
{
throw SocketException("getsockopt failed with error: " + std::to_string(LAST_SOCKET_ERROR()));
}
if (error == 0)
{
_status = SOCKET_STATUS_CONNECTED;
return;
}
}
} while ((std::chrono::system_clock::now() - connectStartTime) < CONNECT_TIMEOUT);
// Connection request timed out
throw SocketException("Connection timed out.");
}
2018-06-22 23:02:47 +02:00
catch (const std::exception&)
{
CloseSocket();
throw;
}
}
2018-06-22 23:02:47 +02:00
void ConnectAsync(const char* address, uint16_t port) override
{
if (_status != SOCKET_STATUS_CLOSED)
{
2018-01-02 20:23:22 +01:00
throw std::runtime_error("Socket not closed.");
}
auto saddress = std::string(address);
std::promise<void> barrier;
_connectFuture = barrier.get_future();
2018-06-22 23:02:47 +02:00
auto thread = std::thread(
[this, saddress, port](std::promise<void> barrier2) -> void {
try
{
Connect(saddress.c_str(), port);
}
catch (const std::exception& ex)
{
_error = std::string(ex.what());
}
barrier2.set_value();
},
std::move(barrier));
thread.detach();
}
void Disconnect() override
{
if (_status == SOCKET_STATUS_CONNECTED)
{
shutdown(_socket, SHUT_RDWR);
}
}
2018-06-22 23:02:47 +02:00
size_t SendData(const void* buffer, size_t size) override
{
if (_status != SOCKET_STATUS_CONNECTED)
{
2018-01-02 20:23:22 +01:00
throw std::runtime_error("Socket not connected.");
}
size_t totalSent = 0;
do
{
2018-06-22 23:02:47 +02:00
const char* bufferStart = (const char*)buffer + totalSent;
size_t remainingSize = size - totalSent;
int32_t sentBytes = send(_socket, bufferStart, (int32_t)remainingSize, FLAG_NO_PIPE);
if (sentBytes == SOCKET_ERROR)
{
2016-06-10 00:04:02 +02:00
return totalSent;
}
totalSent += sentBytes;
} while (totalSent < size);
2016-06-10 00:04:02 +02:00
return totalSent;
}
2018-06-22 23:02:47 +02:00
NETWORK_READPACKET ReceiveData(void* buffer, size_t size, size_t* sizeReceived) override
{
if (_status != SOCKET_STATUS_CONNECTED)
{
2018-01-02 20:23:22 +01:00
throw std::runtime_error("Socket not connected.");
}
2018-06-22 23:02:47 +02:00
int32_t readBytes = recv(_socket, (char*)buffer, (int32_t)size, 0);
2016-07-24 22:01:14 +02:00
if (readBytes == 0)
{
*sizeReceived = 0;
return NETWORK_READPACKET_DISCONNECTED;
}
else if (readBytes == SOCKET_ERROR)
{
*sizeReceived = 0;
2018-07-21 16:17:06 +02:00
# ifndef _WIN32
// Removing the check for EAGAIN and instead relying on the values being the same allows turning on of
// -Wlogical-op warning.
// This is not true on Windows, see:
// * https://msdn.microsoft.com/en-us/library/windows/desktop/ms737828(v=vs.85).aspx
// * https://msdn.microsoft.com/en-us/library/windows/desktop/ms741580(v=vs.85).aspx
// * https://msdn.microsoft.com/en-us/library/windows/desktop/ms740668(v=vs.85).aspx
2018-06-22 23:02:47 +02:00
static_assert(
EWOULDBLOCK == EAGAIN,
"Portability note: your system has different values for EWOULDBLOCK "
"and EAGAIN, please extend the condition below");
2018-07-21 16:17:06 +02:00
# endif // _WIN32
if (LAST_SOCKET_ERROR() != EWOULDBLOCK)
{
return NETWORK_READPACKET_DISCONNECTED;
}
else
{
return NETWORK_READPACKET_NO_DATA;
}
}
else
{
*sizeReceived = readBytes;
return NETWORK_READPACKET_SUCCESS;
}
}
void Close() override
{
if (_connectFuture.valid())
{
_connectFuture.wait();
}
CloseSocket();
}
2018-06-22 23:02:47 +02:00
const char* GetHostName() const override
2016-10-12 15:44:20 +02:00
{
return _hostName.empty() ? nullptr : _hostName.c_str();
2016-10-12 15:44:20 +02:00
}
private:
explicit TcpSocket(SOCKET socket, const std::string& hostName)
{
_socket = socket;
_hostName = hostName;
2017-02-08 13:53:00 +01:00
_status = SOCKET_STATUS_CONNECTED;
}
void CloseSocket()
{
if (_socket != INVALID_SOCKET)
{
closesocket(_socket);
_socket = INVALID_SOCKET;
}
_status = SOCKET_STATUS_CLOSED;
}
2018-06-22 23:02:47 +02:00
bool ResolveAddress(const char* address, uint16_t port, sockaddr_storage* ss, int32_t* ss_len)
{
std::string serviceName = std::to_string(port);
addrinfo hints = {};
hints.ai_family = AF_UNSPEC;
if (address == nullptr)
{
hints.ai_flags = AI_PASSIVE;
}
2018-06-22 23:02:47 +02:00
addrinfo* result = nullptr;
2017-09-10 05:48:29 +02:00
int errorcode = getaddrinfo(address, serviceName.c_str(), &hints, &result);
if (errorcode != 0)
{
log_error("Resolving address failed: Code %d.", errorcode);
log_error("Resolution error message: %s.", gai_strerror(errorcode));
return false;
}
if (result == nullptr)
{
return false;
}
else
{
std::memcpy(ss, result->ai_addr, result->ai_addrlen);
*ss_len = (int32_t)result->ai_addrlen;
2016-06-04 00:29:25 +02:00
freeaddrinfo(result);
return true;
}
}
static bool SetNonBlocking(SOCKET socket, bool on)
{
2018-07-21 16:17:06 +02:00
# ifdef _WIN32
u_long nonBlocking = on;
return ioctlsocket(socket, FIONBIO, &nonBlocking) == 0;
2018-07-21 16:17:06 +02:00
# else
int32_t flags = fcntl(socket, F_GETFL, 0);
return fcntl(socket, F_SETFL, on ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK)) == 0;
2018-07-21 16:17:06 +02:00
# endif
}
static bool SetTCPNoDelay(SOCKET socket, bool enabled)
{
return setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (const char*)&enabled, sizeof(enabled)) == 0;
}
};
2018-12-17 12:58:12 +01:00
std::unique_ptr<ITcpSocket> CreateTcpSocket()
{
2018-12-17 12:58:12 +01:00
return std::make_unique<TcpSocket>();
}
2017-02-08 13:53:00 +01:00
/**
 * Initialises Winsock (version 2.2) once per process; a no-op that reports success on
 * non-Windows platforms.
 * @return true if sockets are usable, false if WSAStartup() failed.
 */
bool InitialiseWSA()
{
# ifdef _WIN32
    if (_wsaInitialised)
    {
        return true;
    }

    log_verbose("Initialising WSA");
    WSADATA wsaData;
    const bool startupFailed = WSAStartup(MAKEWORD(2, 2), &wsaData) != 0;
    if (startupFailed)
    {
        log_error("Unable to initialise winsock.");
        return false;
    }
    _wsaInitialised = true;
    return true;
# else
    return true;
# endif
}
/** Releases Winsock if InitialiseWSA() started it; does nothing on non-Windows platforms. */
void DisposeWSA()
{
# ifdef _WIN32
    if (!_wsaInitialised)
    {
        return;
    }
    WSACleanup();
    _wsaInitialised = false;
# endif
}
namespace Convert
{
    /** Converts a 16-bit value from host byte order to network (big-endian) byte order. */
    uint16_t HostToNetwork(uint16_t value)
    {
        const uint16_t networkOrder = htons(value);
        return networkOrder;
    }

    /** Converts a 16-bit value from network (big-endian) byte order to host byte order. */
    uint16_t NetworkToHost(uint16_t value)
    {
        const uint16_t hostOrder = ntohs(value);
        return hostOrder;
    }
} // namespace Convert
2017-02-08 13:53:00 +01:00
#endif