Merge pull request #12571 from ZehMatt/network/update-2

Refactor more network code
This commit is contained in:
Duncan 2020-08-05 15:55:24 +01:00 committed by GitHub
commit 3533d1734f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 228 additions and 189 deletions

View File

@ -11,6 +11,9 @@
#include "../common.h"
#include <cstring>
#include <type_traits>
template<size_t size> struct ByteSwapT
{
};
@ -58,6 +61,22 @@ template<> struct ByteSwapT<8>
template<typename T> static T ByteSwapBE(const T& value)
{
using ByteSwap = ByteSwapT<sizeof(T)>;
typename ByteSwap::UIntType result = ByteSwap::SwapBE(reinterpret_cast<const typename ByteSwap::UIntType&>(value));
return *reinterpret_cast<T*>(&result);
using UIntType = typename ByteSwap::UIntType;
if constexpr (std::is_enum_v<T> || std::is_integral_v<T>)
{
auto result = ByteSwap::SwapBE(static_cast<const UIntType>(value));
return static_cast<T>(result);
}
else
{
// Complex type, reinterpret_cast is not safe for this case.
// Create a temporary of size(T) as unsigned type via copy instead.
UIntType temp;
std::memcpy(&temp, &value, sizeof(T));
auto result = ByteSwap::SwapBE(temp);
T res;
std::memcpy(&res, &result, sizeof(T));
return res;
}
}

View File

@ -721,7 +721,7 @@ const char* NetworkBase::FormatChat(NetworkPlayer* fromplayer, const char* text)
return formatted;
}
void NetworkBase::SendPacketToClients(NetworkPacket& packet, bool front, bool gameCmd)
void NetworkBase::SendPacketToClients(const NetworkPacket& packet, bool front, bool gameCmd)
{
for (auto& client_connection : client_connection_list)
{
@ -742,7 +742,8 @@ void NetworkBase::SendPacketToClients(NetworkPacket& packet, bool front, bool ga
continue;
}
}
client_connection->QueuePacket(NetworkPacket::Duplicate(packet), front);
auto packetCopy = packet;
client_connection->QueuePacket(std::move(packetCopy), front);
}
}
@ -1207,16 +1208,16 @@ void NetworkBase::Client_Send_RequestGameState(uint32_t tick)
}
log_verbose("Requesting gamestate from server for tick %u", tick);
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::RequestGameState) << tick;
NetworkPacket packet(NetworkCommand::RequestGameState);
packet << tick;
_serverConnection->QueuePacket(std::move(packet));
}
void NetworkBase::Client_Send_TOKEN()
{
log_verbose("requesting token");
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Token);
NetworkPacket packet(NetworkCommand::Token);
_serverConnection->AuthStatus = NETWORK_AUTH_REQUESTED;
_serverConnection->QueuePacket(std::move(packet));
}
@ -1224,15 +1225,14 @@ void NetworkBase::Client_Send_TOKEN()
void NetworkBase::Client_Send_AUTH(
const std::string& name, const std::string& password, const std::string& pubkey, const std::vector<uint8_t>& signature)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Auth);
packet->WriteString(network_get_version().c_str());
packet->WriteString(name.c_str());
packet->WriteString(password.c_str());
packet->WriteString(pubkey.c_str());
NetworkPacket packet(NetworkCommand::Auth);
packet.WriteString(network_get_version().c_str());
packet.WriteString(name.c_str());
packet.WriteString(password.c_str());
packet.WriteString(pubkey.c_str());
assert(signature.size() <= static_cast<size_t>(UINT32_MAX));
*packet << static_cast<uint32_t>(signature.size());
packet->Write(signature.data(), signature.size());
packet << static_cast<uint32_t>(signature.size());
packet.Write(signature.data(), signature.size());
_serverConnection->AuthStatus = NETWORK_AUTH_REQUESTED;
_serverConnection->QueuePacket(std::move(packet));
}
@ -1240,21 +1240,21 @@ void NetworkBase::Client_Send_AUTH(
void NetworkBase::Client_Send_MAPREQUEST(const std::vector<std::string>& objects)
{
log_verbose("client requests %u objects", uint32_t(objects.size()));
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::MapRequest) << static_cast<uint32_t>(objects.size());
NetworkPacket packet(NetworkCommand::MapRequest);
packet << static_cast<uint32_t>(objects.size());
for (const auto& object : objects)
{
log_verbose("client requests object %s", object.c_str());
packet->Write(reinterpret_cast<const uint8_t*>(object.c_str()), 8);
packet.Write(reinterpret_cast<const uint8_t*>(object.c_str()), 8);
}
_serverConnection->QueuePacket(std::move(packet));
}
void NetworkBase::Server_Send_TOKEN(NetworkConnection& connection)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Token) << static_cast<uint32_t>(connection.Challenge.size());
packet->Write(connection.Challenge.data(), connection.Challenge.size());
NetworkPacket packet(NetworkCommand::Token);
packet << static_cast<uint32_t>(connection.Challenge.size());
packet.Write(connection.Challenge.data(), connection.Challenge.size());
connection.QueuePacket(std::move(packet));
}
@ -1265,9 +1265,8 @@ void NetworkBase::Server_Send_OBJECTS_LIST(
if (objects.empty())
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::ObjectsList) << static_cast<uint32_t>(0)
<< static_cast<uint32_t>(objects.size());
NetworkPacket packet(NetworkCommand::ObjectsList);
packet << static_cast<uint32_t>(0) << static_cast<uint32_t>(objects.size());
connection.QueuePacket(std::move(packet));
}
@ -1277,13 +1276,12 @@ void NetworkBase::Server_Send_OBJECTS_LIST(
{
const auto* object = objects[i];
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::ObjectsList) << static_cast<uint32_t>(i)
<< static_cast<uint32_t>(objects.size());
NetworkPacket packet(NetworkCommand::ObjectsList);
packet << static_cast<uint32_t>(i) << static_cast<uint32_t>(objects.size());
log_verbose("Object %.8s (checksum %x)", object->ObjectEntry.name, object->ObjectEntry.checksum);
packet->Write(reinterpret_cast<const uint8_t*>(object->ObjectEntry.name), 8);
*packet << object->ObjectEntry.checksum << object->ObjectEntry.flags;
packet.Write(reinterpret_cast<const uint8_t*>(object->ObjectEntry.name), 8);
packet << object->ObjectEntry.checksum << object->ObjectEntry.flags;
connection.QueuePacket(std::move(packet));
}
@ -1292,8 +1290,8 @@ void NetworkBase::Server_Send_OBJECTS_LIST(
void NetworkBase::Server_Send_SCRIPTS(NetworkConnection& connection) const
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Scripts);
NetworkPacket packet(NetworkCommand::Scripts);
# ifdef ENABLE_SCRIPTING
using namespace OpenRCT2::Scripting;
@ -1310,18 +1308,18 @@ void NetworkBase::Server_Send_SCRIPTS(NetworkConnection& connection) const
}
log_verbose("Server sends %u scripts", pluginsToSend.size());
*packet << static_cast<uint32_t>(pluginsToSend.size());
packet << static_cast<uint32_t>(pluginsToSend.size());
for (const auto& plugin : pluginsToSend)
{
const auto& metadata = plugin->GetMetadata();
log_verbose("Script %s", metadata.Name.c_str());
const auto& code = plugin->GetCode();
*packet << static_cast<uint32_t>(code.size());
packet->Write(reinterpret_cast<const uint8_t*>(code.c_str()), code.size());
packet << static_cast<uint32_t>(code.size());
packet.Write(reinterpret_cast<const uint8_t*>(code.c_str()), code.size());
}
# else
*packet << static_cast<uint32_t>(0);
packet << static_cast<uint32_t>(0);
# endif
connection.QueuePacket(std::move(packet));
}
@ -1330,9 +1328,7 @@ void NetworkBase::Client_Send_HEARTBEAT(NetworkConnection& connection) const
{
log_verbose("Sending heartbeat");
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Heartbeat);
NetworkPacket packet(NetworkCommand::Heartbeat);
connection.QueuePacket(std::move(packet));
}
@ -1364,11 +1360,11 @@ void NetworkBase::Server_Send_AUTH(NetworkConnection& connection)
{
new_playerid = connection.Player->Id;
}
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Auth) << static_cast<uint32_t>(connection.AuthStatus) << new_playerid;
NetworkPacket packet(NetworkCommand::Auth);
packet << static_cast<uint32_t>(connection.AuthStatus) << new_playerid;
if (connection.AuthStatus == NETWORK_AUTH_BADVERSION)
{
packet->WriteString(network_get_version().c_str());
packet.WriteString(network_get_version().c_str());
}
connection.QueuePacket(std::move(packet));
if (connection.AuthStatus != NETWORK_AUTH_OK && connection.AuthStatus != NETWORK_AUTH_REQUIREPASSWORD)
@ -1408,16 +1404,16 @@ void NetworkBase::Server_Send_MAP(NetworkConnection* connection)
for (size_t i = 0; i < out_size; i += chunksize)
{
size_t datasize = std::min(chunksize, out_size - i);
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Map) << static_cast<uint32_t>(out_size) << static_cast<uint32_t>(i);
packet->Write(&header[i], datasize);
NetworkPacket packet(NetworkCommand::Map);
packet << static_cast<uint32_t>(out_size) << static_cast<uint32_t>(i);
packet.Write(&header[i], datasize);
if (connection)
{
connection->QueuePacket(std::move(packet));
}
else
{
SendPacketToClients(*packet);
SendPacketToClients(packet);
}
}
free(header);
@ -1478,22 +1474,20 @@ uint8_t* NetworkBase::save_for_network(size_t& out_size, const std::vector<const
void NetworkBase::Client_Send_CHAT(const char* text)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Chat);
packet->WriteString(text);
NetworkPacket packet(NetworkCommand::Chat);
packet.WriteString(text);
_serverConnection->QueuePacket(std::move(packet));
}
void NetworkBase::Server_Send_CHAT(const char* text, const std::vector<uint8_t>& playerIds)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Chat);
packet->WriteString(text);
NetworkPacket packet(NetworkCommand::Chat);
packet.WriteString(text);
if (playerIds.empty())
{
// Empty players / default value means send to all players
SendPacketToClients(*packet);
SendPacketToClients(packet);
}
else
{
@ -1502,7 +1496,7 @@ void NetworkBase::Server_Send_CHAT(const char* text, const std::vector<uint8_t>&
auto conn = GetPlayerConnection(playerId);
if (conn != nullptr && !conn->IsDisconnected)
{
conn->QueuePacket(NetworkPacket::Duplicate(*packet));
conn->QueuePacket(packet);
}
}
}
@ -1510,7 +1504,7 @@ void NetworkBase::Server_Send_CHAT(const char* text, const std::vector<uint8_t>&
void NetworkBase::Client_Send_GAME_ACTION(const GameAction* action)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
NetworkPacket packet(NetworkCommand::GameAction);
uint32_t networkId = 0;
networkId = ++_actionId;
@ -1525,26 +1519,26 @@ void NetworkBase::Client_Send_GAME_ACTION(const GameAction* action)
DataSerialiser stream(true);
action->Serialise(stream);
*packet << static_cast<uint32_t>(NetworkCommand::GameAction) << gCurrentTicks << action->GetType() << stream;
packet << gCurrentTicks << action->GetType() << stream;
_serverConnection->QueuePacket(std::move(packet));
}
void NetworkBase::Server_Send_GAME_ACTION(const GameAction* action)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
NetworkPacket packet(NetworkCommand::GameAction);
DataSerialiser stream(true);
action->Serialise(stream);
*packet << static_cast<uint32_t>(NetworkCommand::GameAction) << gCurrentTicks << action->GetType() << stream;
packet << gCurrentTicks << action->GetType() << stream;
SendPacketToClients(*packet);
SendPacketToClients(packet);
}
void NetworkBase::Server_Send_TICK()
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Tick) << gCurrentTicks << scenario_rand_state().s0;
NetworkPacket packet(NetworkCommand::Tick);
packet << gCurrentTicks << scenario_rand_state().s0;
uint32_t flags = 0;
// Simple counter which limits how often a sprite checksum gets sent.
// This can get somewhat expensive, so we don't want to push it every tick in release,
@ -1558,75 +1552,72 @@ void NetworkBase::Server_Send_TICK()
}
// Send flags always, so we can understand packet structure on the other end,
// and allow for some expansion.
*packet << flags;
packet << flags;
if (flags & NETWORK_TICK_FLAG_CHECKSUMS)
{
rct_sprite_checksum checksum = sprite_checksum();
packet->WriteString(checksum.ToString().c_str());
packet.WriteString(checksum.ToString().c_str());
}
SendPacketToClients(*packet);
SendPacketToClients(packet);
}
void NetworkBase::Server_Send_PLAYERINFO(int32_t playerId)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::PlayerInfo) << gCurrentTicks;
NetworkPacket packet(NetworkCommand::PlayerInfo);
packet << gCurrentTicks;
auto* player = GetPlayerByID(playerId);
if (player == nullptr)
return;
player->Write(*packet);
SendPacketToClients(*packet);
player->Write(packet);
SendPacketToClients(packet);
}
void NetworkBase::Server_Send_PLAYERLIST()
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::PlayerList) << gCurrentTicks << static_cast<uint8_t>(player_list.size());
NetworkPacket packet(NetworkCommand::PlayerList);
packet << gCurrentTicks << static_cast<uint8_t>(player_list.size());
for (auto& player : player_list)
{
player->Write(*packet);
player->Write(packet);
}
SendPacketToClients(*packet);
SendPacketToClients(packet);
}
void NetworkBase::Client_Send_PING()
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Ping);
NetworkPacket packet(NetworkCommand::Ping);
_serverConnection->QueuePacket(std::move(packet));
}
void NetworkBase::Server_Send_PING()
{
last_ping_sent_time = platform_get_ticks();
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Ping);
NetworkPacket packet(NetworkCommand::Ping);
for (auto& client_connection : client_connection_list)
{
client_connection->PingTime = platform_get_ticks();
}
SendPacketToClients(*packet, true);
SendPacketToClients(packet, true);
}
void NetworkBase::Server_Send_PINGLIST()
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::PingList) << static_cast<uint8_t>(player_list.size());
NetworkPacket packet(NetworkCommand::PingList);
packet << static_cast<uint8_t>(player_list.size());
for (auto& player : player_list)
{
*packet << player->Id << player->Ping;
packet << player->Id << player->Ping;
}
SendPacketToClients(*packet);
SendPacketToClients(packet);
}
void NetworkBase::Server_Send_SETDISCONNECTMSG(NetworkConnection& connection, const char* msg)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::DisconnectMessage);
packet->WriteString(msg);
NetworkPacket packet(NetworkCommand::DisconnectMessage);
packet.WriteString(msg);
connection.QueuePacket(std::move(packet));
}
@ -1646,8 +1637,7 @@ json_t* NetworkBase::GetServerInfoAsJson() const
void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::GameInfo);
NetworkPacket packet(NetworkCommand::GameInfo);
# ifndef DISABLE_HTTP
json_t* obj = GetServerInfoAsJson();
@ -1658,8 +1648,8 @@ void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection)
json_object_set_new(jsonProvider, "website", json_string(gConfigNetwork.provider_website.c_str()));
json_object_set_new(obj, "provider", jsonProvider);
packet->WriteString(json_dumps(obj, 0));
*packet << _serverState.gamestateSnapshotsEnabled;
packet.WriteString(json_dumps(obj, 0));
packet << _serverState.gamestateSnapshotsEnabled;
json_decref(obj);
# endif
@ -1668,39 +1658,37 @@ void NetworkBase::Server_Send_GAMEINFO(NetworkConnection& connection)
void NetworkBase::Server_Send_SHOWERROR(NetworkConnection& connection, rct_string_id title, rct_string_id message)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::ShowError) << title << message;
NetworkPacket packet(NetworkCommand::ShowError);
packet << title << message;
connection.QueuePacket(std::move(packet));
}
void NetworkBase::Server_Send_GROUPLIST(NetworkConnection& connection)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::GroupList) << static_cast<uint8_t>(group_list.size()) << default_group;
NetworkPacket packet(NetworkCommand::GroupList);
packet << static_cast<uint8_t>(group_list.size()) << default_group;
for (auto& group : group_list)
{
group->Write(*packet);
group->Write(packet);
}
connection.QueuePacket(std::move(packet));
}
void NetworkBase::Server_Send_EVENT_PLAYER_JOINED(const char* playerName)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Event);
*packet << static_cast<uint16_t>(SERVER_EVENT_PLAYER_JOINED);
packet->WriteString(playerName);
SendPacketToClients(*packet);
NetworkPacket packet(NetworkCommand::Event);
packet << static_cast<uint16_t>(SERVER_EVENT_PLAYER_JOINED);
packet.WriteString(playerName);
SendPacketToClients(packet);
}
void NetworkBase::Server_Send_EVENT_PLAYER_DISCONNECTED(const char* playerName, const char* reason)
{
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::Event);
*packet << static_cast<uint16_t>(SERVER_EVENT_PLAYER_DISCONNECTED);
packet->WriteString(playerName);
packet->WriteString(reason);
SendPacketToClients(*packet);
NetworkPacket packet(NetworkCommand::Event);
packet << static_cast<uint16_t>(SERVER_EVENT_PLAYER_DISCONNECTED);
packet.WriteString(playerName);
packet.WriteString(reason);
SendPacketToClients(packet);
}
bool NetworkBase::ProcessConnection(NetworkConnection& connection)
@ -1748,11 +1736,8 @@ bool NetworkBase::ProcessConnection(NetworkConnection& connection)
void NetworkBase::ProcessPacket(NetworkConnection& connection, NetworkPacket& packet)
{
std::underlying_type<NetworkCommand>::type command;
packet >> command;
const auto& handlerList = GetMode() == NETWORK_MODE_SERVER ? server_command_handlers : client_command_handlers;
auto it = handlerList.find(static_cast<NetworkCommand>(command));
auto it = handlerList.find(packet.GetCommand());
if (it != handlerList.end())
{
auto commandHandler = it->second;
@ -2240,11 +2225,11 @@ void NetworkBase::Server_Handle_REQUEST_GAMESTATE(NetworkConnection& connection,
dataSize = snapshotMemory.GetLength() - bytesSent;
}
std::unique_ptr<NetworkPacket> gameStateChunk(NetworkPacket::Allocate());
*gameStateChunk << static_cast<uint32_t>(NetworkCommand::GameState) << tick << length << bytesSent << dataSize;
gameStateChunk->Write(static_cast<const uint8_t*>(snapshotMemory.GetData()) + bytesSent, dataSize);
NetworkPacket packetGameStateChunk(NetworkCommand::GameState);
packetGameStateChunk << tick << length << bytesSent << dataSize;
packetGameStateChunk.Write(static_cast<const uint8_t*>(snapshotMemory.GetData()) + bytesSent, dataSize);
connection.QueuePacket(std::move(gameStateChunk));
connection.QueuePacket(std::move(packetGameStateChunk));
bytesSent += dataSize;
}
@ -2662,7 +2647,7 @@ void NetworkBase::Client_Handle_MAP([[maybe_unused]] NetworkConnection& connecti
{
uint32_t size, offset;
packet >> size >> offset;
int32_t chunksize = static_cast<int32_t>(packet.Size - packet.BytesRead);
int32_t chunksize = static_cast<int32_t>(packet.Header.Size - packet.BytesRead);
if (chunksize <= 0)
{
return;
@ -2937,7 +2922,7 @@ void NetworkBase::Client_Handle_GAME_ACTION([[maybe_unused]] NetworkConnection&
packet >> tick >> actionType;
MemoryStream stream;
size_t size = packet.Size - packet.BytesRead;
const size_t size = packet.Header.Size - packet.BytesRead;
stream.WriteArray(packet.Read(size), size);
stream.SetPosition(0);
@ -3027,7 +3012,7 @@ void NetworkBase::Server_Handle_GAME_ACTION(NetworkConnection& connection, Netwo
}
DataSerialiser stream(false);
size_t size = packet.Size - packet.BytesRead;
const size_t size = packet.Header.Size - packet.BytesRead;
stream.GetStream().WriteArray(packet.Read(size), size);
stream.GetStream().SetPosition(0);
@ -3208,8 +3193,7 @@ void NetworkBase::Client_Handle_EVENT([[maybe_unused]] NetworkConnection& connec
void NetworkBase::Client_Send_GAMEINFO()
{
log_verbose("requesting gameinfo");
std::unique_ptr<NetworkPacket> packet(NetworkPacket::Allocate());
*packet << static_cast<uint32_t>(NetworkCommand::GameInfo);
NetworkPacket packet(NetworkCommand::GameInfo);
_serverConnection->QueuePacket(std::move(packet));
}

View File

@ -113,7 +113,7 @@ public: // Client
void ProcessPlayerInfo();
void ProcessDisconnectedClients();
static const char* FormatChat(NetworkPlayer* fromplayer, const char* text);
void SendPacketToClients(NetworkPacket& packet, bool front = false, bool gameCmd = false);
void SendPacketToClients(const NetworkPacket& packet, bool front = false, bool gameCmd = false);
bool CheckSRAND(uint32_t tick, uint32_t srand0);
bool CheckDesynchronizaton();
void RequestStateSnapshot();

View File

@ -18,6 +18,7 @@
# include "network.h"
constexpr size_t NETWORK_DISCONNECT_REASON_BUFFER_SIZE = 256;
constexpr size_t NetworkBufferSize = 1024;
NetworkConnection::NetworkConnection()
{
@ -31,47 +32,61 @@ NetworkConnection::~NetworkConnection()
int32_t NetworkConnection::ReadPacket()
{
if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Size))
size_t bytesRead = 0;
// Read packet header.
auto& header = InboundPacket.Header;
if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header))
{
// read packet size
void* buffer = &(reinterpret_cast<char*>(&InboundPacket.Size))[InboundPacket.BytesTransferred];
size_t bufferLength = sizeof(InboundPacket.Size) - InboundPacket.BytesTransferred;
size_t readBytes;
NETWORK_READPACKET status = Socket->ReceiveData(buffer, bufferLength, &readBytes);
const size_t missingLength = sizeof(header) - InboundPacket.BytesTransferred;
uint8_t* buffer = reinterpret_cast<uint8_t*>(&InboundPacket.Header);
NETWORK_READPACKET status = Socket->ReceiveData(buffer, missingLength, &bytesRead);
if (status != NETWORK_READPACKET_SUCCESS)
{
return status;
}
InboundPacket.BytesTransferred += readBytes;
if (InboundPacket.BytesTransferred == sizeof(InboundPacket.Size))
InboundPacket.BytesTransferred += bytesRead;
if (InboundPacket.BytesTransferred < sizeof(InboundPacket.Header))
{
InboundPacket.Size = Convert::NetworkToHost(InboundPacket.Size);
if (InboundPacket.Size == 0) // Can't have a size 0 packet
{
return NETWORK_READPACKET_DISCONNECTED;
}
InboundPacket.Data->resize(InboundPacket.Size);
// If still not enough data for header, keep waiting.
return NETWORK_READPACKET_MORE_DATA;
}
// Normalise values.
header.Size = Convert::NetworkToHost(header.Size);
header.Id = ByteSwapBE(header.Id);
// NOTE: For compatibility reasons for the master server we need to remove sizeof(Header.Id) from the size.
// Previously the Id field was not part of the header rather part of the body.
header.Size -= sizeof(header.Id);
// Fall-through: Read rest of packet.
}
else
// Read packet body.
{
// read packet data
if (InboundPacket.Data->capacity() > 0)
const size_t missingLength = header.Size - (InboundPacket.BytesTransferred - sizeof(header));
uint8_t buffer[NetworkBufferSize];
if (missingLength > 0)
{
void* buffer = &InboundPacket.GetData()[InboundPacket.BytesTransferred - sizeof(InboundPacket.Size)];
size_t bufferLength = sizeof(InboundPacket.Size) + InboundPacket.Size - InboundPacket.BytesTransferred;
size_t readBytes;
NETWORK_READPACKET status = Socket->ReceiveData(buffer, bufferLength, &readBytes);
NETWORK_READPACKET status = Socket->ReceiveData(buffer, std::min(missingLength, NetworkBufferSize), &bytesRead);
if (status != NETWORK_READPACKET_SUCCESS)
{
return status;
}
InboundPacket.BytesTransferred += readBytes;
InboundPacket.BytesTransferred += bytesRead;
InboundPacket.Write(buffer, bytesRead);
}
if (InboundPacket.BytesTransferred == sizeof(InboundPacket.Size) + InboundPacket.Size)
if (InboundPacket.Data.size() == header.Size)
{
// Received complete packet.
_lastPacketTime = platform_get_ticks();
RecordPacketStats(InboundPacket, false);
@ -79,26 +94,34 @@ int32_t NetworkConnection::ReadPacket()
return NETWORK_READPACKET_SUCCESS;
}
}
return NETWORK_READPACKET_MORE_DATA;
}
bool NetworkConnection::SendPacket(NetworkPacket& packet)
{
uint16_t sizen = Convert::HostToNetwork(packet.Size);
std::vector<uint8_t> tosend;
tosend.reserve(sizeof(sizen) + packet.Size);
tosend.insert(tosend.end(), reinterpret_cast<uint8_t*>(&sizen), reinterpret_cast<uint8_t*>(&sizen) + sizeof(sizen));
tosend.insert(tosend.end(), packet.Data->begin(), packet.Data->end());
auto header = packet.Header;
const void* buffer = &tosend[packet.BytesTransferred];
size_t bufferSize = tosend.size() - packet.BytesTransferred;
size_t sent = Socket->SendData(buffer, bufferSize);
std::vector<uint8_t> buffer;
buffer.reserve(sizeof(header) + header.Size);
// NOTE: For compatibility reasons for the master server we need to add sizeof(Header.Id) to the size.
// Previously the Id field was not part of the header rather part of the body.
header.Size += sizeof(header.Id);
header.Size = Convert::HostToNetwork(header.Size);
header.Id = ByteSwapBE(header.Id);
buffer.insert(buffer.end(), reinterpret_cast<uint8_t*>(&header), reinterpret_cast<uint8_t*>(&header) + sizeof(header));
buffer.insert(buffer.end(), packet.Data.begin(), packet.Data.end());
size_t bufferSize = buffer.size() - packet.BytesTransferred;
size_t sent = Socket->SendData(buffer.data() + packet.BytesTransferred, bufferSize);
if (sent > 0)
{
packet.BytesTransferred += sent;
}
bool sendComplete = packet.BytesTransferred == tosend.size();
bool sendComplete = packet.BytesTransferred == buffer.size();
if (sendComplete)
{
RecordPacketStats(packet, true);
@ -106,15 +129,15 @@ bool NetworkConnection::SendPacket(NetworkPacket& packet)
return sendComplete;
}
void NetworkConnection::QueuePacket(std::unique_ptr<NetworkPacket> packet, bool front)
void NetworkConnection::QueuePacket(NetworkPacket&& packet, bool front)
{
if (AuthStatus == NETWORK_AUTH_OK || !packet->CommandRequiresAuth())
if (AuthStatus == NETWORK_AUTH_OK || !packet.CommandRequiresAuth())
{
packet->Size = static_cast<uint16_t>(packet->Data->size());
packet.Header.Size = static_cast<uint16_t>(packet.Data.size());
if (front)
{
// If the first packet was already partially sent add new packet to second position
if (!_outboundPackets.empty() && _outboundPackets.front()->BytesTransferred > 0)
if (!_outboundPackets.empty() && _outboundPackets.front().BytesTransferred > 0)
{
auto it = _outboundPackets.begin();
it++; // Second position
@ -134,9 +157,9 @@ void NetworkConnection::QueuePacket(std::unique_ptr<NetworkPacket> packet, bool
void NetworkConnection::SendQueuedPackets()
{
while (!_outboundPackets.empty() && SendPacket(*_outboundPackets.front()))
while (!_outboundPackets.empty() && SendPacket(_outboundPackets.front()))
{
_outboundPackets.remove(_outboundPackets.front());
_outboundPackets.pop_front();
}
}

View File

@ -16,7 +16,7 @@
# include "NetworkTypes.h"
# include "Socket.h"
# include <list>
# include <deque>
# include <memory>
# include <vector>
@ -41,7 +41,13 @@ public:
~NetworkConnection();
int32_t ReadPacket();
void QueuePacket(std::unique_ptr<NetworkPacket> packet, bool front = false);
void QueuePacket(NetworkPacket&& packet, bool front = false);
void QueuePacket(const NetworkPacket& packet, bool front = false)
{
auto copy = packet;
return QueuePacket(std::move(copy), front);
}
void SendQueuedPackets();
void ResetLastPacketTime();
bool ReceivedPacketRecently();
@ -51,7 +57,7 @@ public:
void SetLastDisconnectReason(const rct_string_id string_id, void* args = nullptr);
private:
std::list<std::unique_ptr<NetworkPacket>> _outboundPackets;
std::deque<NetworkPacket> _outboundPackets;
uint32_t _lastPacketTime = 0;
utf8* _lastDisconnectReason = nullptr;

View File

@ -15,35 +15,31 @@
# include <memory>
std::unique_ptr<NetworkPacket> NetworkPacket::Allocate()
NetworkPacket::NetworkPacket(NetworkCommand id)
: Header{ 0, id }
{
return std::make_unique<NetworkPacket>();
}
std::unique_ptr<NetworkPacket> NetworkPacket::Duplicate(NetworkPacket& packet)
{
return std::make_unique<NetworkPacket>(packet);
}
uint8_t* NetworkPacket::GetData()
{
return &(*Data)[0];
return Data.data();
}
const uint8_t* NetworkPacket::GetData() const
{
return Data.data();
}
NetworkCommand NetworkPacket::GetCommand() const
{
if (Data->size() < sizeof(uint32_t))
return NetworkCommand::Invalid;
const uint32_t commandId = ByteSwapBE(*reinterpret_cast<uint32_t*>(&(*Data)[0]));
return static_cast<NetworkCommand>(commandId);
return Header.Id;
}
void NetworkPacket::Clear()
{
BytesTransferred = 0;
BytesRead = 0;
Data->clear();
Data.clear();
}
bool NetworkPacket::CommandRequiresAuth()
@ -63,9 +59,10 @@ bool NetworkPacket::CommandRequiresAuth()
}
}
void NetworkPacket::Write(const uint8_t* bytes, size_t size)
void NetworkPacket::Write(const void* bytes, size_t size)
{
Data->insert(Data->end(), bytes, bytes + size);
const uint8_t* src = reinterpret_cast<const uint8_t*>(bytes);
Data.insert(Data.end(), src, src + size);
}
void NetworkPacket::WriteString(const utf8* string)
@ -75,7 +72,7 @@ void NetworkPacket::WriteString(const utf8* string)
const uint8_t* NetworkPacket::Read(size_t size)
{
if (BytesRead + size > NetworkPacket::Size)
if (BytesRead + size > Header.Size)
{
return nullptr;
}
@ -91,7 +88,7 @@ const utf8* NetworkPacket::ReadString()
{
char* str = reinterpret_cast<char*>(&GetData()[BytesRead]);
char* strend = str;
while (BytesRead < Size && *strend != 0)
while (BytesRead < Header.Size && *strend != 0)
{
BytesRead++;
strend++;

View File

@ -16,18 +16,23 @@
#include <memory>
#include <vector>
class NetworkPacket final
#pragma pack(push, 1)
struct PacketHeader
{
public:
uint16_t Size = 0;
std::shared_ptr<std::vector<uint8_t>> Data = std::make_shared<std::vector<uint8_t>>();
size_t BytesTransferred = 0;
size_t BytesRead = 0;
NetworkCommand Id = NetworkCommand::Invalid;
};
static_assert(sizeof(PacketHeader) == 6);
#pragma pack(pop)
static std::unique_ptr<NetworkPacket> Allocate();
static std::unique_ptr<NetworkPacket> Duplicate(NetworkPacket& packet);
struct NetworkPacket final
{
NetworkPacket() = default;
NetworkPacket(NetworkCommand id);
uint8_t* GetData();
const uint8_t* GetData() const;
NetworkCommand GetCommand() const;
void Clear();
@ -36,12 +41,12 @@ public:
const uint8_t* Read(size_t size);
const utf8* ReadString();
void Write(const uint8_t* bytes, size_t size);
void Write(const void* bytes, size_t size);
void WriteString(const utf8* string);
template<typename T> NetworkPacket& operator>>(T& value)
{
if (BytesRead + sizeof(value) > Size)
if (BytesRead + sizeof(value) > Header.Size)
{
value = T{};
}
@ -58,8 +63,7 @@ public:
template<typename T> NetworkPacket& operator<<(T value)
{
T swapped = ByteSwapBE(value);
uint8_t* bytes = reinterpret_cast<uint8_t*>(&swapped);
Data->insert(Data->end(), bytes, bytes + sizeof(value));
Write(&swapped, sizeof(T));
return *this;
}
@ -68,4 +72,10 @@ public:
Write(static_cast<const uint8_t*>(data.GetStream().GetData()), data.GetStream().GetLength());
return *this;
}
public:
PacketHeader Header{};
std::vector<uint8_t> Data;
size_t BytesTransferred = 0;
size_t BytesRead = 0;
};

View File

@ -17,7 +17,7 @@
#include <string>
#include <unordered_map>
class NetworkPacket;
struct NetworkPacket;
class NetworkPlayer final
{