refactor network, create ITcpSocket

Abstracts all socket code into a new class TcpSocket which is only exposed by a light interface, ITcpSocket. This now means that platform specific headers like winsock2.h and sys/socket.h do not have to be included in OpenRCT2 header files reducing include load and other issues.
This commit is contained in:
Ted John 2016-06-01 22:58:21 +01:00
parent 14de1cd5eb
commit 8dfbabbd07
11 changed files with 674 additions and 486 deletions

View File

@ -84,13 +84,13 @@
<ClCompile Include="src\network\http.cpp" />
<ClCompile Include="src\network\network.cpp" />
<ClCompile Include="src\network\NetworkAction.cpp" />
<ClCompile Include="src\network\NetworkAddress.cpp" />
<ClCompile Include="src\network\NetworkConnection.cpp" />
<ClCompile Include="src\network\NetworkGroup.cpp" />
<ClCompile Include="src\network\NetworkKey.cpp" />
<ClCompile Include="src\network\NetworkPacket.cpp" />
<ClCompile Include="src\network\NetworkPlayer.cpp" />
<ClCompile Include="src\network\NetworkUser.cpp" />
<ClCompile Include="src\network\TcpSocket.cpp" />
<ClCompile Include="src\network\twitch.cpp" />
<ClCompile Include="src\object.c" />
<ClCompile Include="src\object_list.c" />
@ -365,13 +365,13 @@
<ClInclude Include="src\management\research.h" />
<ClInclude Include="src\network\http.h" />
<ClInclude Include="src\network\NetworkAction.h" />
<ClInclude Include="src\network\NetworkAddress.h" />
<ClInclude Include="src\network\NetworkConnection.h" />
<ClInclude Include="src\network\NetworkGroup.h" />
<ClInclude Include="src\network\NetworkPacket.h" />
<ClInclude Include="src\network\NetworkPlayer.h" />
<ClInclude Include="src\network\NetworkTypes.h" />
<ClInclude Include="src\network\NetworkUser.h" />
<ClInclude Include="src\network\TcpSocket.h" />
<ClInclude Include="src\network\twitch.h" />
<ClInclude Include="src\network\network.h" />
<ClInclude Include="src\network\NetworkKey.h" />

View File

@ -19,23 +19,26 @@
#include "../common.h"
#include <exception>
#include <string>
class Exception : public std::exception
{
public:
Exception() : Exception(nullptr) { }
Exception(const char * message) : std::exception()
Exception(const char * message) : Exception(std::string(message)) { }
Exception(const std::string &message) : std::exception()
{
_message = message;
}
virtual ~Exception() { }
const char * what() const throw() override { return _message; }
const char * GetMessage() const { return _message; }
const char * GetMsg() const { return _message; }
const char * what() const throw() override { return _message.c_str(); }
const char * GetMessage() const { return _message.c_str(); }
const char * GetMsg() const { return _message.c_str(); }
private:
const char * _message;
std::string _message;
};

View File

@ -1,127 +0,0 @@
#pragma region Copyright (c) 2014-2016 OpenRCT2 Developers
/*****************************************************************************
* OpenRCT2, an open source clone of Roller Coaster Tycoon 2.
*
* OpenRCT2 is the work of many authors, a full list can be found in contributors.md
* For more information, visit https://github.com/OpenRCT2/OpenRCT2
*
* OpenRCT2 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* A full copy of the GNU General Public License can be found in licence.txt
*****************************************************************************/
#pragma endregion
#ifndef DISABLE_NETWORK
#include <string>
#include <SDL_thread.h>
#include "NetworkAddress.h"
NetworkAddress::NetworkAddress()
{
_result = std::make_shared<ResolveResult>();
_result->status = RESOLVE_NONE;
_resolveMutex = SDL_CreateMutex();
}
NetworkAddress::~NetworkAddress()
{
SDL_DestroyMutex(_resolveMutex);
}
void NetworkAddress::Resolve(const char * host, uint16 port)
{
SDL_LockMutex(_resolveMutex);
{
// Create a new result store
_result = std::make_shared<ResolveResult>();
_result->status = RESOLVE_INPROGRESS;
// Create a new request
auto req = new ResolveRequest();
req->Host = std::string(host == nullptr ? "" : host);
req->Port = port;;
req->Result = _result;
// Resolve synchronously
ResolveWorker(req);
}
SDL_UnlockMutex(_resolveMutex);
}
void NetworkAddress::ResolveAsync(const char * host, uint16 port)
{
SDL_LockMutex(_resolveMutex);
{
// Create a new result store
_result = std::make_shared<ResolveResult>();
_result->status = RESOLVE_INPROGRESS;
// Create a new request
auto req = new ResolveRequest();
req->Host = std::string(host);
req->Port = port;
req->Result = _result;
// Spin off a worker thread for resolving the address
SDL_CreateThread([](void * pointer) -> int
{
ResolveWorker((ResolveRequest *)pointer);
return 0;
}, 0, req);
}
SDL_UnlockMutex(_resolveMutex);
}
NetworkAddress::RESOLVE_STATUS NetworkAddress::GetResult(sockaddr_storage * ss, int * ss_len)
{
SDL_LockMutex(_resolveMutex);
{
const ResolveResult * result = _result.get();
if (result->status == RESOLVE_OK)
{
*ss = result->ss;
*ss_len = result->ss_len;
}
return result->status;
}
SDL_UnlockMutex(_resolveMutex);
}
void NetworkAddress::ResolveWorker(ResolveRequest * req)
{
// Resolve the address
const char * nodeName = req->Host.c_str();
std::string serviceName = std::to_string(req->Port);
addrinfo hints = { 0 };
hints.ai_family = AF_UNSPEC;
if (req->Host.empty())
{
hints.ai_flags = AI_PASSIVE;
nodeName = nullptr;
}
addrinfo * result;
getaddrinfo(nodeName, serviceName.c_str(), &hints, &result);
// Store the result
ResolveResult * resolveResult = req->Result.get();
if (result != nullptr)
{
resolveResult->status = RESOLVE_OK;
memcpy(&resolveResult->ss, result->ai_addr, result->ai_addrlen);
resolveResult->ss_len = result->ai_addrlen;
freeaddrinfo(result);
}
else
{
resolveResult->status = RESOLVE_FAILED;
}
delete req;
}
#endif

View File

@ -1,71 +0,0 @@
#pragma region Copyright (c) 2014-2016 OpenRCT2 Developers
/*****************************************************************************
* OpenRCT2, an open source clone of Roller Coaster Tycoon 2.
*
* OpenRCT2 is the work of many authors, a full list can be found in contributors.md
* For more information, visit https://github.com/OpenRCT2/OpenRCT2
*
* OpenRCT2 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* A full copy of the GNU General Public License can be found in licence.txt
*****************************************************************************/
#pragma endregion
#pragma once
#include <memory>
#include <SDL_mutex.h>
#include "NetworkTypes.h"
#include "../common.h"
class NetworkAddress final
{
public:
enum RESOLVE_STATUS
{
RESOLVE_NONE,
RESOLVE_INPROGRESS,
RESOLVE_OK,
RESOLVE_FAILED
};
NetworkAddress();
~NetworkAddress();
void Resolve(const char * host, uint16 port);
void ResolveAsync(const char * host, uint16 port);
RESOLVE_STATUS GetResult(sockaddr_storage * ss, int * ss_len);
private:
struct ResolveResult
{
RESOLVE_STATUS status;
sockaddr_storage ss;
int ss_len;
};
struct ResolveRequest
{
std::string Host;
uint16 Port;
std::shared_ptr<ResolveResult> Result;
};
/**
* Store for the async result. A new store is created for every request.
* Old requests simply write to an old store that will then be
* automatically deleted by std::shared_ptr.
*/
std::shared_ptr<ResolveResult> _result;
/**
* Mutex so synchronoise the requests.
*/
SDL_mutex * _resolveMutex;
static void ResolveWorker(ResolveRequest * req);
};

View File

@ -16,6 +16,7 @@
#ifndef DISABLE_NETWORK
#include "Network.h"
#include "NetworkConnection.h"
#include "../core/String.hpp"
#include <SDL.h>
@ -34,10 +35,7 @@ NetworkConnection::NetworkConnection()
NetworkConnection::~NetworkConnection()
{
if (Socket != INVALID_SOCKET)
{
closesocket(Socket);
}
delete Socket;
if (_lastDisconnectReason)
{
delete[] _lastDisconnectReason;
@ -49,22 +47,19 @@ int NetworkConnection::ReadPacket()
if (InboundPacket.transferred < sizeof(InboundPacket.size))
{
// read packet size
int readBytes = recv(Socket, &((char*)&InboundPacket.size)[InboundPacket.transferred], sizeof(InboundPacket.size) - InboundPacket.transferred, 0);
if (readBytes == SOCKET_ERROR || readBytes == 0)
void * buffer = &((char*)&InboundPacket.size)[InboundPacket.transferred];
size_t bufferLength = sizeof(InboundPacket.size) - InboundPacket.transferred;
size_t readBytes;
NETWORK_READPACKET status = Socket->ReceiveData(buffer, bufferLength, &readBytes);
if (status != NETWORK_READPACKET_SUCCESS)
{
if (LAST_SOCKET_ERROR() != EWOULDBLOCK && LAST_SOCKET_ERROR() != EAGAIN)
{
return NETWORK_READPACKET_DISCONNECTED;
}
else
{
return NETWORK_READPACKET_NO_DATA;
}
return status;
}
InboundPacket.transferred += readBytes;
if (InboundPacket.transferred == sizeof(InboundPacket.size))
{
InboundPacket.size = ntohs(InboundPacket.size);
InboundPacket.size = Convert::NetworkToHost(InboundPacket.size);
if (InboundPacket.size == 0) // Can't have a size 0 packet
{
return NETWORK_READPACKET_DISCONNECTED;
@ -77,21 +72,15 @@ int NetworkConnection::ReadPacket()
// read packet data
if (InboundPacket.data->capacity() > 0)
{
int readBytes = recv(Socket,
(char*)&InboundPacket.GetData()[InboundPacket.transferred - sizeof(InboundPacket.size)],
sizeof(InboundPacket.size) + InboundPacket.size - InboundPacket.transferred,
0);
if (readBytes == SOCKET_ERROR || readBytes == 0)
void * buffer = &InboundPacket.GetData()[InboundPacket.transferred - sizeof(InboundPacket.size)];
size_t bufferLength = sizeof(InboundPacket.size) + InboundPacket.size - InboundPacket.transferred;
size_t readBytes;
NETWORK_READPACKET status = Socket->ReceiveData(buffer, bufferLength, &readBytes);
if (status != NETWORK_READPACKET_SUCCESS)
{
if (LAST_SOCKET_ERROR() != EWOULDBLOCK && LAST_SOCKET_ERROR() != EAGAIN)
{
return NETWORK_READPACKET_DISCONNECTED;
}
else
{
return NETWORK_READPACKET_NO_DATA;
}
return status;
}
InboundPacket.transferred += readBytes;
}
if (InboundPacket.transferred == sizeof(InboundPacket.size) + InboundPacket.size)
@ -105,25 +94,23 @@ int NetworkConnection::ReadPacket()
bool NetworkConnection::SendPacket(NetworkPacket& packet)
{
uint16 sizen = htons(packet.size);
uint16 sizen = Convert::HostToNetwork(packet.size);
std::vector<uint8> tosend;
tosend.reserve(sizeof(sizen) + packet.size);
tosend.insert(tosend.end(), (uint8*)&sizen, (uint8*)&sizen + sizeof(sizen));
tosend.insert(tosend.end(), packet.data->begin(), packet.data->end());
while (true)
const void * buffer = &tosend[packet.transferred];
size_t bufferSize = tosend.size() - packet.transferred;
if (Socket->SendData(buffer, bufferSize))
{
int sentBytes = send(Socket, (const char*)&tosend[packet.transferred], tosend.size() - packet.transferred, FLAG_NO_PIPE);
if (sentBytes == SOCKET_ERROR)
{
return false;
}
packet.transferred += sentBytes;
if (packet.transferred == tosend.size())
{
return true;
}
packet.transferred += bufferSize;
return true;
}
else
{
return false;
}
return false;
}
void NetworkConnection::QueuePacket(std::unique_ptr<NetworkPacket> packet, bool front)
@ -150,27 +137,6 @@ void NetworkConnection::SendQueuedPackets()
}
}
bool NetworkConnection::SetTCPNoDelay(bool on)
{
return setsockopt(Socket, IPPROTO_TCP, TCP_NODELAY, (const char*)&on, sizeof(on)) == 0;
}
bool NetworkConnection::SetNonBlocking(bool on)
{
return SetNonBlocking(Socket, on);
}
bool NetworkConnection::SetNonBlocking(SOCKET socket, bool on)
{
#ifdef __WINDOWS__
u_long nonblocking = on;
return ioctlsocket(socket, FIONBIO, &nonblocking) == 0;
#else
int flags = fcntl(socket, F_GETFL, 0);
return fcntl(socket, F_SETFL, on ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK)) == 0;
#endif
}
void NetworkConnection::ResetLastPacketTime()
{
_lastPacketTime = SDL_GetTicks();

View File

@ -19,17 +19,20 @@
#include <list>
#include <memory>
#include <vector>
#include "../common.h"
#include "NetworkTypes.h"
#include "NetworkKey.h"
#include "NetworkPacket.h"
#include "../common.h"
#include "TcpSocket.h"
class NetworkPlayer;
class NetworkConnection
{
public:
SOCKET Socket = INVALID_SOCKET;
ITcpSocket * Socket = nullptr;
NetworkPacket InboundPacket;
NETWORK_AUTH AuthStatus = NETWORK_AUTH_NONE;
NetworkPlayer * Player = nullptr;
@ -43,8 +46,6 @@ public:
int ReadPacket();
void QueuePacket(std::unique_ptr<NetworkPacket> packet, bool front = false);
void SendQueuedPackets();
bool SetTCPNoDelay(bool on);
bool SetNonBlocking(bool on);
void ResetLastPacketTime();
bool ReceivedPacketRecently();
@ -52,8 +53,6 @@ public:
void SetLastDisconnectReason(const utf8 * src);
void SetLastDisconnectReason(const rct_string_id string_id, void * args = nullptr);
static bool SetNonBlocking(SOCKET socket, bool on);
private:
std::list<std::unique_ptr<NetworkPacket>> _outboundPackets;
uint32 _lastPacketTime;

View File

@ -20,52 +20,11 @@
#include <SDL_platform.h>
#ifndef DISABLE_NETWORK
#ifdef __WINDOWS__
// winsock2 must be included before windows.h
#include <winsock2.h>
#include <ws2tcpip.h>
#define LAST_SOCKET_ERROR() WSAGetLastError()
#undef EWOULDBLOCK
#define EWOULDBLOCK WSAEWOULDBLOCK
#ifndef SHUT_RD
#define SHUT_RD SD_RECEIVE
#endif
#ifndef SHUT_RDWR
#define SHUT_RDWR SD_BOTH
#endif
#define FLAG_NO_PIPE 0
#else
#include <errno.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <fcntl.h>
typedef int SOCKET;
#define SOCKET_ERROR -1
#define INVALID_SOCKET -1
#define LAST_SOCKET_ERROR() errno
#define closesocket close
#define ioctlsocket ioctl
#if defined(__LINUX__)
#define FLAG_NO_PIPE MSG_NOSIGNAL
#else
#define FLAG_NO_PIPE 0
#endif // defined(__LINUX__)
#endif // __WINDOWS__
#include "../common.h"
#endif
enum NETWORK_READPACKET
{
NETWORK_READPACKET_SUCCESS,
NETWORK_READPACKET_NO_DATA,
NETWORK_READPACKET_MORE_DATA,
NETWORK_READPACKET_DISCONNECTED
};
enum NETWORK_AUTH
{
NETWORK_AUTH_NONE,

462
src/network/TcpSocket.cpp Normal file
View File

@ -0,0 +1,462 @@
#pragma region Copyright (c) 2014-2016 OpenRCT2 Developers
/*****************************************************************************
* OpenRCT2, an open source clone of Roller Coaster Tycoon 2.
*
* OpenRCT2 is the work of many authors, a full list can be found in contributors.md
* For more information, visit https://github.com/OpenRCT2/OpenRCT2
*
* OpenRCT2 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* A full copy of the GNU General Public License can be found in licence.txt
*****************************************************************************/
#pragma endregion
#ifndef DISABLE_NETWORK
// MSVC: include <math.h> here otherwise PI gets defined twice
#include <math.h>
#include <SDL_platform.h>
#include <SDL_thread.h>
#include <SDL_timer.h>
#ifdef __WINDOWS__
// winsock2 must be included before windows.h
#include <winsock2.h>
#include <ws2tcpip.h>
#define LAST_SOCKET_ERROR() WSAGetLastError()
#undef EWOULDBLOCK
#define EWOULDBLOCK WSAEWOULDBLOCK
#ifndef SHUT_RD
#define SHUT_RD SD_RECEIVE
#endif
#ifndef SHUT_RDWR
#define SHUT_RDWR SD_BOTH
#endif
#define FLAG_NO_PIPE 0
#else
#include <errno.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <fcntl.h>
typedef int SOCKET;
#define SOCKET_ERROR -1
#define INVALID_SOCKET -1
#define LAST_SOCKET_ERROR() errno
#define closesocket close
#define ioctlsocket ioctl
#if defined(__LINUX__)
#define FLAG_NO_PIPE MSG_NOSIGNAL
#else
#define FLAG_NO_PIPE 0
#endif // defined(__LINUX__)
#endif // __WINDOWS__
#include "../core/Exception.hpp"
#include "TcpSocket.h"
constexpr uint32 CONNECT_TIMEOUT_MS = 3000;
class TcpSocket;
class SocketException : public Exception
{
public:
SocketException(const char * message) : Exception(message) { }
SocketException(const std::string &message) : Exception(message) { }
};
struct ConnectRequest
{
TcpSocket * TcpSocket;
std::string Address;
uint16 Port;
};
class TcpSocket : public ITcpSocket
{
private:
SOCKET_STATUS _status = SOCKET_STATUS_CLOSED;
uint16 _listeningPort = 0;
SOCKET _socket = INVALID_SOCKET;
SDL_mutex * _connectMutex = nullptr;
std::string _error;
public:
TcpSocket()
{
_connectMutex = SDL_CreateMutex();
}
~TcpSocket() override
{
SDL_LockMutex(_connectMutex);
{
CloseSocket();
}
SDL_UnlockMutex(_connectMutex);
SDL_DestroyMutex(_connectMutex);
}
SOCKET_STATUS GetStatus() override
{
return _status;
}
const char * GetError() override
{
return _error.empty() ? nullptr : _error.c_str();
}
void Listen(uint16 port) override
{
Listen(nullptr, port);
}
void Listen(const char * address, uint16 port) override
{
if (_status != SOCKET_STATUS_CLOSED)
{
throw Exception("Socket not closed.");
}
sockaddr_storage ss;
int ss_len;
if (!ResolveAddress(address, port, &ss, &ss_len))
{
throw SocketException("Unable to resolve address.");
}
// Create the listening socket
_socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP);
if (_socket == INVALID_SOCKET)
{
throw SocketException("Unable to create socket.");
}
// Turn off IPV6_V6ONLY so we can accept both v4 and v6 connections
int value = 0;
if (setsockopt(_socket, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&value, sizeof(value)) != 0)
{
log_error("IPV6_V6ONLY failed. %d", LAST_SOCKET_ERROR());
}
try
{
// Bind to address:port and listen
if (bind(_socket, (sockaddr *)&ss, ss_len) != 0)
{
throw SocketException("Unable to bind to socket.");
}
if (listen(_socket, SOMAXCONN) != 0)
{
throw SocketException("Unable to listen on socket.");
}
if (!SetNonBlocking(_socket, true))
{
throw SocketException("Failed to set non-blocking mode.");
}
}
catch (Exception ex)
{
CloseSocket();
throw;
}
_listeningPort = port;
_status = SOCKET_STATUS_LISTENING;
}
ITcpSocket * Accept() override
{
if (_status != SOCKET_STATUS_LISTENING)
{
throw Exception("Socket not listening.");
}
ITcpSocket * tcpSocket = nullptr;
SOCKET socket = accept(_socket, nullptr, nullptr);
if (socket == INVALID_SOCKET)
{
if (LAST_SOCKET_ERROR() != EWOULDBLOCK)
{
log_error("Failed to accept client.");
}
}
else
{
if (!SetNonBlocking(socket, true))
{
closesocket(socket);
log_error("Failed to set non-blocking mode.");
}
else
{
SetTCPNoDelay(socket, true);
tcpSocket = new TcpSocket(socket);
}
}
return tcpSocket;
}
void Connect(const char * address, uint16 port) override
{
if (_status != SOCKET_STATUS_CLOSED)
{
throw Exception("Socket not closed.");
}
try
{
// Resolve address
_status = SOCKET_STATUS_RESOLVING;
sockaddr_storage ss;
int ss_len;
if (!ResolveAddress(address, port, &ss, &ss_len))
{
throw SocketException("Unable to resolve address.");
}
_status = SOCKET_STATUS_CONNECTING;
_socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP);
if (_socket == INVALID_SOCKET)
{
throw SocketException("Unable to create socket.");
}
SetTCPNoDelay(_socket, true);
if (!SetNonBlocking(_socket, true))
{
throw SocketException("Failed to set non-blocking mode.");
}
// Connect
uint32 connectStartTick;
int connectResult = connect(_socket, (sockaddr *)&ss, ss_len);
if (connectResult != SOCKET_ERROR || (LAST_SOCKET_ERROR() != EINPROGRESS &&
LAST_SOCKET_ERROR() != EWOULDBLOCK))
{
throw SocketException("Failed to connect.");
}
connectStartTick = SDL_GetTicks();
int error = 0;
socklen_t len = sizeof(error);
if (getsockopt(_socket, SOL_SOCKET, SO_ERROR, (char *)&error, &len) != 0)
{
throw SocketException("getsockopt failed with error: " + std::to_string(LAST_SOCKET_ERROR()));
}
if (error != 0)
{
throw SocketException("Connection failed: " + std::to_string(error));
}
do
{
// Sleep for a bit
SDL_Delay(100);
fd_set writeFD;
FD_ZERO(&writeFD);
FD_SET(_socket, &writeFD);
timeval timeout;
timeout.tv_sec = 0;
timeout.tv_usec = 0;
if (select(_socket + 1, nullptr, &writeFD, nullptr, &timeout) > 0)
{
error = 0;
socklen_t len = sizeof(error);
if (getsockopt(_socket, SOL_SOCKET, SO_ERROR, (char*)&error, &len) != 0)
{
throw SocketException("getsockopt failed with error: " + std::to_string(LAST_SOCKET_ERROR()));
}
if (error == 0)
{
_status = SOCKET_STATUS_CONNECTED;
return;
}
}
} while (!SDL_TICKS_PASSED(SDL_GetTicks(), connectStartTick + CONNECT_TIMEOUT_MS));
// Connection request timed out
throw SocketException("Connection timed out.");
}
catch (Exception ex)
{
CloseSocket();
throw;
}
}
void ConnectAsync(const char * address, uint16 port) override
{
if (_status != SOCKET_STATUS_CLOSED)
{
throw Exception("Socket not closed.");
}
if (SDL_TryLockMutex(_connectMutex) == 0)
{
// Spin off a worker thread for resolving the address
auto req = new ConnectRequest();
req->TcpSocket = this;
req->Address = std::string(address);
req->Port = port;
SDL_CreateThread([](void * pointer) -> int
{
auto req = (ConnectRequest *)pointer;
try
{
req->TcpSocket->Connect(req->Address.c_str(), req->Port);
}
catch (Exception ex)
{
req->TcpSocket->_error = std::string(ex.GetMsg());
}
delete req;
SDL_UnlockMutex(req->TcpSocket->_connectMutex);
return 0;
}, 0, req);
}
}
void Disconnect() override
{
if (_status == SOCKET_STATUS_CONNECTED)
{
shutdown(_socket, SHUT_RDWR);
}
}
bool SendData(const void * buffer, size_t size) override
{
if (_status != SOCKET_STATUS_CONNECTED)
{
throw Exception("Socket not connected.");
}
size_t totalSent = 0;
do
{
int sentBytes = send(_socket, (const char *)buffer, (int)size, FLAG_NO_PIPE);
if (sentBytes == SOCKET_ERROR)
{
return false;
}
totalSent += sentBytes;
} while (totalSent < size);
return true;
}
NETWORK_READPACKET ReceiveData(void * buffer, size_t size, size_t * sizeReceived) override
{
if (_status != SOCKET_STATUS_CONNECTED)
{
throw Exception("Socket not connected.");
}
int readBytes = recv(_socket, (char *)buffer, size, 0);
if (readBytes == SOCKET_ERROR || readBytes <= 0)
{
*sizeReceived = 0;
if (LAST_SOCKET_ERROR() != EWOULDBLOCK && LAST_SOCKET_ERROR() != EAGAIN)
{
return NETWORK_READPACKET_DISCONNECTED;
}
else
{
return NETWORK_READPACKET_NO_DATA;
}
}
else
{
*sizeReceived = readBytes;
return NETWORK_READPACKET_SUCCESS;
}
}
void Close()
{
SDL_LockMutex(_connectMutex);
{
CloseSocket();
}
SDL_UnlockMutex(_connectMutex);
}
private:
TcpSocket(SOCKET socket)
{
_socket = socket;
_status = SOCKET_STATUS_CONNECTED;
}
void CloseSocket()
{
if (_socket != INVALID_SOCKET)
{
closesocket(_socket);
_socket = INVALID_SOCKET;
}
_status = SOCKET_STATUS_CLOSED;
}
bool ResolveAddress(const char * address, uint16 port, sockaddr_storage * ss, int * ss_len)
{
std::string serviceName = std::to_string(port);
addrinfo hints = { 0 };
hints.ai_family = AF_UNSPEC;
if (address == nullptr)
{
hints.ai_flags = AI_PASSIVE;
}
addrinfo * result;
getaddrinfo(address, serviceName.c_str(), &hints, &result);
if (result == nullptr)
{
return false;
}
else
{
memcpy(ss, result->ai_addr, result->ai_addrlen);
*ss_len = result->ai_addrlen;
return true;
}
}
static bool SetNonBlocking(SOCKET socket, bool on)
{
#ifdef __WINDOWS__
u_long nonBlocking = on;
return ioctlsocket(socket, FIONBIO, &nonBlocking) == 0;
#else
int flags = fcntl(socket, F_GETFL, 0);
return fcntl(socket, F_SETFL, on ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK)) == 0;
#endif
}
static bool SetTCPNoDelay(SOCKET socket, bool enabled)
{
return setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (const char*)&enabled, sizeof(enabled)) == 0;
}
};
ITcpSocket * CreateTcpSocket()
{
return new TcpSocket();
}
#endif

63
src/network/TcpSocket.h Normal file
View File

@ -0,0 +1,63 @@
#pragma region Copyright (c) 2014-2016 OpenRCT2 Developers
/*****************************************************************************
* OpenRCT2, an open source clone of Roller Coaster Tycoon 2.
*
* OpenRCT2 is the work of many authors, a full list can be found in contributors.md
* For more information, visit https://github.com/OpenRCT2/OpenRCT2
*
* OpenRCT2 is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* A full copy of the GNU General Public License can be found in licence.txt
*****************************************************************************/
#pragma endregion
#pragma once
#include "../common.h"
enum SOCKET_STATUS
{
SOCKET_STATUS_CLOSED,
SOCKET_STATUS_RESOLVING,
SOCKET_STATUS_CONNECTING,
SOCKET_STATUS_CONNECTED,
SOCKET_STATUS_LISTENING,
};
enum NETWORK_READPACKET
{
NETWORK_READPACKET_SUCCESS,
NETWORK_READPACKET_NO_DATA,
NETWORK_READPACKET_MORE_DATA,
NETWORK_READPACKET_DISCONNECTED
};
/**
* Represents a TCP socket / connection or listener.
*/
interface ITcpSocket
{
public:
virtual ~ITcpSocket() { }
virtual SOCKET_STATUS GetStatus() abstract;
virtual const char * GetError() abstract;
virtual void Listen(uint16 port) abstract;
virtual void Listen(const char * address, uint16 port) abstract;
virtual ITcpSocket * Accept() abstract;
virtual void Connect(const char * address, uint16 port) abstract;
virtual void ConnectAsync(const char * address, uint16 port) abstract;
virtual bool SendData(const void * buffer, size_t size) abstract;
virtual NETWORK_READPACKET ReceiveData(void * buffer, size_t size, size_t * sizeReceived) abstract;
virtual void Disconnect() abstract;
virtual void Close() abstract;
};
ITcpSocket * CreateTcpSocket();

View File

@ -14,6 +14,15 @@
*****************************************************************************/
#pragma endregion
#include <SDL_platform.h>
#ifdef __WINDOWS__
// winsock2 must be included before windows.h
#include <winsock2.h>
#else
#include <arpa/inet.h>
#endif
extern "C" {
#include "../openrct2.h"
#include "../platform/platform.h"
@ -161,14 +170,16 @@ void Network::Close()
return;
}
if (mode == NETWORK_MODE_CLIENT) {
closesocket(server_connection.Socket);
} else
if (mode == NETWORK_MODE_SERVER) {
closesocket(listening_socket);
delete server_connection.Socket;
server_connection.Socket = nullptr;
} else if (mode == NETWORK_MODE_SERVER) {
delete listening_socket;
listening_socket = nullptr;
}
mode = NETWORK_MODE_NONE;
status = NETWORK_STATUS_NONE;
_lastConnectStatus = SOCKET_STATUS_CLOSED;
server_connection.AuthStatus = NETWORK_AUTH_NONE;
server_connection.InboundPacket.Clear();
server_connection.SetLastDisconnectReason(nullptr);
@ -199,14 +210,11 @@ bool Network::BeginClient(const char* host, unsigned short port)
if (!Init())
return false;
server_address.ResolveAsync(host, port);
status = NETWORK_STATUS_RESOLVING;
char str_resolving[256];
format_string(str_resolving, STR_MULTIPLAYER_RESOLVING, NULL);
window_network_status_open(str_resolving, []() -> void {
gNetwork.Close();
});
assert(server_connection.Socket == nullptr);
server_connection.Socket = CreateTcpSocket();
server_connection.Socket->ConnectAsync(host, port);
status = NETWORK_STATUS_CONNECTING;
_lastConnectStatus = SOCKET_STATUS_CLOSED;
BeginChatLog();
@ -272,43 +280,11 @@ bool Network::BeginServer(unsigned short port, const char* address)
_userManager.Load();
NetworkAddress networkaddress;
networkaddress.Resolve(address, port);
sockaddr_storage ss;
int ss_len;
networkaddress.GetResult(&ss, &ss_len);
log_verbose("Begin listening for clients");
listening_socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP);
if (listening_socket == INVALID_SOCKET) {
log_error("Unable to create socket.");
return false;
}
// Turn off IPV6_V6ONLY so we can accept both v4 and v6 connections
int value = 0;
if (setsockopt(listening_socket, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&value, sizeof(value)) != 0) {
log_error("IPV6_V6ONLY failed. %d", LAST_SOCKET_ERROR());
}
if (bind(listening_socket, (sockaddr *)&ss, ss_len) != 0) {
closesocket(listening_socket);
log_error("Unable to bind to socket.");
return false;
}
if (listen(listening_socket, SOMAXCONN) != 0) {
closesocket(listening_socket);
log_error("Unable to listen on socket.");
return false;
}
if (!NetworkConnection::SetNonBlocking(listening_socket, true)) {
closesocket(listening_socket);
log_error("Failed to set non-blocking mode.");
return false;
}
assert(listening_socket == nullptr);
listening_socket = CreateTcpSocket();
listening_socket->Listen(address, port);
ServerName = gConfigNetwork.server_name;
ServerDescription = gConfigNetwork.server_description;
@ -421,113 +397,74 @@ void Network::UpdateServer()
break;
}
SOCKET socket = accept(listening_socket, NULL, NULL);
if (socket == INVALID_SOCKET) {
if (LAST_SOCKET_ERROR() != EWOULDBLOCK) {
PrintError();
log_error("Failed to accept client.");
}
} else {
if (!NetworkConnection::SetNonBlocking(socket, true)) {
closesocket(socket);
log_error("Failed to set non-blocking mode.");
} else {
AddClient(socket);
}
ITcpSocket * tcpSocket = listening_socket->Accept();
if (tcpSocket != nullptr) {
AddClient(tcpSocket);
}
}
void Network::UpdateClient()
{
bool connectfailed = false;
switch(status){
case NETWORK_STATUS_RESOLVING:{
sockaddr_storage ss;
int ss_len;
NetworkAddress::RESOLVE_STATUS result = server_address.GetResult(&ss, &ss_len);
if (result == NetworkAddress::RESOLVE_OK) {
server_connection.Socket = socket(ss.ss_family, SOCK_STREAM, IPPROTO_TCP);
if (server_connection.Socket == INVALID_SOCKET) {
log_error("Unable to create socket.");
connectfailed = true;
break;
case NETWORK_STATUS_CONNECTING:
{
switch (server_connection.Socket->GetStatus()) {
case SOCKET_STATUS_RESOLVING:
{
if (_lastConnectStatus != SOCKET_STATUS_RESOLVING)
{
_lastConnectStatus = SOCKET_STATUS_RESOLVING;
char str_resolving[256];
format_string(str_resolving, STR_MULTIPLAYER_RESOLVING, NULL);
window_network_status_open(str_resolving, []() -> void {
gNetwork.Close();
});
}
server_connection.SetTCPNoDelay(true);
if (!server_connection.SetNonBlocking(true)) {
log_error("Failed to set non-blocking mode.");
connectfailed = true;
break;
}
if (connect(server_connection.Socket, (sockaddr *)&ss, ss_len) == SOCKET_ERROR &&
(LAST_SOCKET_ERROR() == EINPROGRESS || LAST_SOCKET_ERROR() == EWOULDBLOCK)
) {
break;
}
case SOCKET_STATUS_CONNECTING:
{
if (_lastConnectStatus != SOCKET_STATUS_CONNECTING)
{
_lastConnectStatus = SOCKET_STATUS_CONNECTING;
char str_connecting[256];
format_string(str_connecting, STR_MULTIPLAYER_CONNECTING, NULL);
window_network_status_open(str_connecting, []() -> void {
gNetwork.Close();
});
server_connect_time = SDL_GetTicks();
status = NETWORK_STATUS_CONNECTING;
} else {
log_error("connect() failed %d", LAST_SOCKET_ERROR());
connectfailed = true;
break;
}
} else if (result == NetworkAddress::RESOLVE_INPROGRESS) {
break;
} else {
log_error("Could not resolve address.");
connectfailed = true;
}
}break;
case NETWORK_STATUS_CONNECTING:{
int error = 0;
socklen_t len = sizeof(error);
int result = getsockopt(server_connection.Socket, SOL_SOCKET, SO_ERROR, (char*)&error, &len);
if (result != 0) {
log_error("getsockopt failed with error %d", LAST_SOCKET_ERROR());
break;
}
if (error != 0) {
log_error("Connection failed %d", error);
connectfailed = true;
case NETWORK_STATUS_CONNECTED:
{
status = NETWORK_STATUS_CONNECTED;
server_connection.ResetLastPacketTime();
Client_Send_TOKEN();
char str_authenticating[256];
format_string(str_authenticating, STR_MULTIPLAYER_AUTHENTICATING, NULL);
window_network_status_open(str_authenticating, []() -> void {
gNetwork.Close();
});
break;
}
if (SDL_TICKS_PASSED(SDL_GetTicks(), server_connect_time + 3000)) {
log_error("Connection timed out.");
connectfailed = true;
break;
}
fd_set writeFD;
FD_ZERO(&writeFD);
FD_SET(server_connection.Socket, &writeFD);
timeval timeout;
timeout.tv_sec = 0;
timeout.tv_usec = 0;
if (select(server_connection.Socket + 1, NULL, &writeFD, NULL, &timeout) > 0) {
error = 0;
socklen_t len = sizeof(error);
result = getsockopt(server_connection.Socket, SOL_SOCKET, SO_ERROR, (char*)&error, &len);
if (result != 0) {
log_error("getsockopt failed with error %d", LAST_SOCKET_ERROR());
break;
}
if (error == 0) {
status = NETWORK_STATUS_CONNECTED;
server_connection.ResetLastPacketTime();
Client_Send_TOKEN();
char str_authenticating[256];
format_string(str_authenticating, STR_MULTIPLAYER_AUTHENTICATING, NULL);
window_network_status_open(str_authenticating, []() -> void {
gNetwork.Close();
});
default:
{
const char * error = server_connection.Socket->GetError();
if (error != nullptr) {
Console::Error::WriteLine(error);
}
Close();
window_network_status_close();
window_error_open(STR_UNABLE_TO_CONNECT_TO_SERVER, STR_NONE);
break;
}
}break;
}
break;
}
case NETWORK_STATUS_CONNECTED:
{
if (!ProcessConnection(server_connection)) {
// Do not show disconnect message window when password window closed/canceled
if (server_connection.AuthStatus == NETWORK_AUTH_REQUIREPASSWORD) {
@ -560,11 +497,6 @@ void Network::UpdateClient()
}
break;
}
if (connectfailed) {
Close();
window_network_status_close();
window_error_open(STR_UNABLE_TO_CONNECT_TO_SERVER, STR_NONE);
}
}
@ -659,7 +591,7 @@ void Network::KickPlayer(int playerId)
char str_disconnect_msg[256];
format_string(str_disconnect_msg, STR_MULTIPLAYER_KICKED_REASON, NULL);
Server_Send_SETDISCONNECTMSG(*(*it), str_disconnect_msg);
shutdown((*it)->Socket, SHUT_RD);
(*it)->Socket->Disconnect();
(*it)->SendQueuedPackets();
break;
}
@ -674,7 +606,7 @@ void Network::SetPassword(const char* password)
void Network::ShutdownClient()
{
if (GetMode() == NETWORK_MODE_CLIENT) {
shutdown(server_connection.Socket, SHUT_RDWR);
server_connection.Socket->Disconnect();
}
}
@ -1035,8 +967,8 @@ void Network::Server_Send_AUTH(NetworkConnection& connection)
}
connection.QueuePacket(std::move(packet));
if (connection.AuthStatus != NETWORK_AUTH_OK && connection.AuthStatus != NETWORK_AUTH_REQUIREPASSWORD) {
shutdown(connection.Socket, SHUT_RD);
connection.SendQueuedPackets();
connection.Socket->Disconnect();
}
}
@ -1324,11 +1256,10 @@ void Network::ProcessGameCommandQueue()
}
}
void Network::AddClient(SOCKET socket)
void Network::AddClient(ITcpSocket * socket)
{
auto connection = std::unique_ptr<NetworkConnection>(new NetworkConnection); // change to make_unique in c++14
connection->Socket = socket;
connection->SetTCPNoDelay(true);
client_connection_list.push_back(std::move(connection));
}
@ -1442,21 +1373,6 @@ std::string Network::MakePlayerNameUnique(const std::string &name)
return new_name;
}
void Network::PrintError()
{
#ifdef __WINDOWS__
wchar_t *s = NULL;
FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL,
LAST_SOCKET_ERROR(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPWSTR)&s, 0, NULL);
fprintf(stderr, "%S\n", s);
LocalFree(s);
#else
char *s = strerror(LAST_SOCKET_ERROR());
fprintf(stderr, "%s\n", s);
#endif
}
void Network::Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket& packet)
{
utf8 keyPath[MAX_PATH];
@ -1471,7 +1387,7 @@ void Network::Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket&
if (!ok) {
log_error("Failed to load key %s", keyPath);
connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
return;
}
uint32 challenge_size;
@ -1486,7 +1402,7 @@ void Network::Client_Handle_TOKEN(NetworkConnection& connection, NetworkPacket&
if (!ok) {
log_error("Failed to sign server's challenge.");
connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
return;
}
// Don't keep private key in memory. There's no need and it may get leaked
@ -1505,37 +1421,37 @@ void Network::Client_Handle_AUTH(NetworkConnection& connection, NetworkPacket& p
break;
case NETWORK_AUTH_BADNAME:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_BAD_PLAYER_NAME);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
case NETWORK_AUTH_BADVERSION:
{
const char *version = packet.ReadString();
connection.SetLastDisconnectReason(STR_MULTIPLAYER_INCORRECT_SOFTWARE_VERSION, &version);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
}
case NETWORK_AUTH_BADPASSWORD:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_BAD_PASSWORD);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
case NETWORK_AUTH_VERIFICATIONFAILURE:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_VERIFICATION_FAILURE);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
case NETWORK_AUTH_FULL:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_SERVER_FULL);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
case NETWORK_AUTH_REQUIREPASSWORD:
window_network_status_open_password();
break;
case NETWORK_AUTH_UNKNOWN_KEY_DISALLOWED:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_UNKNOWN_KEY_DISALLOWED);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
default:
connection.SetLastDisconnectReason(STR_MULTIPLAYER_INCORRECT_SOFTWARE_VERSION);
shutdown(connection.Socket, SHUT_RDWR);
connection.Socket->Disconnect();
break;
}
}
@ -1964,6 +1880,19 @@ void Network::Client_Handle_GAMEINFO(NetworkConnection& connection, NetworkPacke
json_decref(root);
}
namespace Convert
{
uint16 HostToNetwork(uint16 value)
{
return htons(value);
}
uint16 NetworkToHost(uint16 value)
{
return ntohs(value);
}
}
int network_init()
{
return gNetwork.Init();

View File

@ -30,7 +30,6 @@ enum {
enum {
NETWORK_STATUS_NONE,
NETWORK_STATUS_READY,
NETWORK_STATUS_RESOLVING,
NETWORK_STATUS_CONNECTING,
NETWORK_STATUS_CONNECTED
};
@ -76,13 +75,13 @@ extern "C" {
#include <SDL.h>
#include "../core/Json.hpp"
#include "../core/Nullable.hpp"
#include "NetworkAddress.h"
#include "NetworkConnection.h"
#include "NetworkGroup.h"
#include "NetworkKey.h"
#include "NetworkPacket.h"
#include "NetworkPlayer.h"
#include "NetworkUser.h"
#include "TcpSocket.h"
class Network
{
@ -162,7 +161,7 @@ private:
bool ProcessConnection(NetworkConnection& connection);
void ProcessPacket(NetworkConnection& connection, NetworkPacket& packet);
void ProcessGameCommandQueue();
void AddClient(SOCKET socket);
void AddClient(ITcpSocket * socket);
void RemoveClient(std::unique_ptr<NetworkConnection>& connection);
NetworkPlayer* AddPlayer(const utf8 *name, const std::string &keyhash);
std::string MakePlayerNameUnique(const std::string &name);
@ -187,11 +186,11 @@ private:
int mode = NETWORK_MODE_NONE;
int status = NETWORK_STATUS_NONE;
NetworkAddress server_address;
bool wsa_initialized = false;
SOCKET listening_socket = INVALID_SOCKET;
ITcpSocket * listening_socket = nullptr;
unsigned short listening_port = 0;
NetworkConnection server_connection;
SOCKET_STATUS _lastConnectStatus;
uint32 last_tick_sent_time = 0;
uint32 last_ping_sent_time = 0;
uint32 server_tick = 0;
@ -242,6 +241,12 @@ private:
void Server_Handle_TOKEN(NetworkConnection& connection, NetworkPacket& packet);
};
namespace Convert
{
uint16 HostToNetwork(uint16 value);
uint16 NetworkToHost(uint16 value);
}
#endif // __cplusplus
#else /* DISABLE_NETWORK */
#define NETWORK_STREAM_ID "Multiplayer disabled"