diff --git a/src/command.cpp b/src/command.cpp index 001f2ee7d1..788c71989c 100644 --- a/src/command.cpp +++ b/src/command.cpp @@ -55,6 +55,8 @@ #include "viewport_cmd.h" #include "water_cmd.h" #include "waypoint_cmd.h" +#include "misc/endian_buffer.hpp" +#include "string_func.h" #include @@ -399,12 +401,38 @@ bool DoCommandP(Commands cmd, StringID err_message, CommandCallback *callback, T return DoCommandP(cmd, err_message, callback, true, false, tile, p1, p2, text); } +/** + * Toplevel network safe docommand function for the current company. Must not be called recursively. + * The callback is called when the command succeeded or failed. The parameters + * \a tile, \a p1, and \a p2 are from the #CommandProc function. The parameter \a cmd is the command to execute. + * + * @param cmd The command to execute (a CMD_* value) + * @param err_message Message prefix to show on error + * @param callback A callback function to call after the command is finished + * @param my_cmd indicator if the command is from a company or server (to display error messages for a user) + * @param tile The tile to perform a command on (see #CommandProc) + * @param p1 Additional data for the command (see #CommandProc) + * @param p2 Additional data for the command (see #CommandProc) + * @param text The text to pass + * @return \c true if the command succeeded, else \c false. + */ +bool InjectNetworkCommand(Commands cmd, StringID err_message, CommandCallback *callback, bool my_cmd, TileIndex tile, uint32 p1, uint32 p2, const std::string &text) +{ + return DoCommandP(cmd, err_message, callback, my_cmd, true, tile, p1, p2, text); +} + /** * Helper to deduplicate the code for returning. * @param cmd the command cost to return. */ #define return_dcpi(cmd) { _docommand_recursive = 0; return cmd; } +/** Helper to format command parameters into a hex string. */ +static std::string CommandParametersToHexString(TileIndex tile, uint32 p1, uint32 p2, const std::string &text) +{ + return FormatArrayAsHex(EndianBufferWriter<>::FromValue(std::make_tuple(tile, p1, p2, text))); +} + /*! * Helper function for the toplevel network safe docommand function for the current company. * @@ -482,7 +510,7 @@ CommandCost DoCommandPInternal(Commands cmd, StringID err_message, CommandCallba if (!_networking || _generating_world || network_command) { /* Log the failed command as well. Just to be able to be find * causes of desyncs due to bad command test implementations. */ - Debug(desync, 1, "cmdf: {:08x}; {:02x}; {:02x}; {:06x}; {:08x}; {:08x}; {:08x}; {:08x}; \"{}\" ({})", _date, _date_fract, (int)_current_company, tile, p1, p2, cmd, err_message, text, GetCommandName(cmd)); + Debug(desync, 1, "cmdf: {:08x}; {:02x}; {:02x}; {:08x}; {:08x}; {:06x}; {} ({})", _date, _date_fract, (int)_current_company, cmd, err_message, tile, CommandParametersToHexString(tile, p1, p2, text), GetCommandName(cmd)); } cur_company.Restore(); return_dcpi(res); @@ -502,7 +530,7 @@ CommandCost DoCommandPInternal(Commands cmd, StringID err_message, CommandCallba * reset the storages as we've not executed the command. */ return_dcpi(CommandCost()); } - Debug(desync, 1, "cmd: {:08x}; {:02x}; {:02x}; {:06x}; {:08x}; {:08x}; {:08x}; {:08x}; \"{}\" ({})", _date, _date_fract, (int)_current_company, tile, p1, p2, cmd, err_message, text, GetCommandName(cmd)); + Debug(desync, 1, "cmd: {:08x}; {:02x}; {:02x}; {:08x}; {:08x}; {:06x}; {} ({})", _date, _date_fract, (int)_current_company, cmd, err_message, tile, CommandParametersToHexString(tile, p1, p2, text), GetCommandName(cmd)); /* Actually try and execute the command. If no cost-type is given * use the construction one */ diff --git a/src/command_func.h b/src/command_func.h index 03bfc73f1b..0d000755c8 100644 --- a/src/command_func.h +++ b/src/command_func.h @@ -12,6 +12,7 @@ #include "command_type.h" #include "company_type.h" +#include /** * Define a default return value for a failed command. @@ -32,6 +33,9 @@ static const CommandCost CMD_ERROR = CommandCost(INVALID_STRING_ID); */ #define return_cmd_error(errcode) return CommandCost(errcode); +/** Storage buffer for serialized command data. */ +typedef std::vector CommandDataBuffer; + CommandCost DoCommand(DoCommandFlag flags, Commands cmd, TileIndex tile, uint32 p1, uint32 p2, const std::string &text = {}); CommandCost DoCommand(const CommandContainer *container, DoCommandFlag flags); @@ -41,9 +45,12 @@ bool DoCommandP(Commands cmd, CommandCallback *callback, TileIndex tile, uint32 bool DoCommandP(Commands cmd, TileIndex tile, uint32 p1, uint32 p2, const std::string &text = {}); bool DoCommandP(const CommandContainer *container, bool my_cmd = true, bool network_command = false); +bool InjectNetworkCommand(Commands cmd, StringID err_message, CommandCallback *callback, bool my_cmd, TileIndex tile, uint32 p1, uint32 p2, const std::string &text); + CommandCost DoCommandPInternal(Commands cmd, StringID err_message, CommandCallback *callback, bool my_cmd, bool estimate_only, bool network_command, TileIndex tile, uint32 p1, uint32 p2, const std::string &text); void NetworkSendCommand(Commands cmd, StringID err_message, CommandCallback *callback, CompanyID company, TileIndex tile, uint32 p1, uint32 p2, const std::string &text); +void NetworkSendCommand(Commands cmd, StringID err_message, CommandCallback *callback, CompanyID company, TileIndex location, const CommandDataBuffer &cmd_data); extern Money _additional_cash_required; diff --git a/src/command_type.h b/src/command_type.h index fa381ce133..3a15087b9a 100644 --- a/src/command_type.h +++ b/src/command_type.h @@ -424,11 +424,19 @@ enum CommandPauseLevel { typedef CommandCost CommandProc(DoCommandFlag flags, TileIndex tile, uint32 p1, uint32 p2, const std::string &text); +template struct CommandFunctionTraitHelper; +template +struct CommandFunctionTraitHelper { + using Args = std::tuple...>; +}; + /** Defines the traits of a command. */ template struct CommandTraits; #define DEF_CMD_TRAIT(cmd_, proc_, flags_, type_) \ template<> struct CommandTraits { \ + using Args = typename CommandFunctionTraitHelper::Args; \ + static constexpr Commands cmd = cmd_; \ static constexpr auto &proc = proc_; \ static constexpr CommandFlags flags = (CommandFlags)(flags_); \ static constexpr CommandType type = type_; \ diff --git a/src/core/span_type.hpp b/src/core/span_type.hpp index 614be84567..0df528816e 100644 --- a/src/core/span_type.hpp +++ b/src/core/span_type.hpp @@ -92,6 +92,8 @@ public: constexpr const_iterator cbegin() const noexcept { return const_iterator(first); } constexpr const_iterator cend() const noexcept { return const_iterator(last); } + constexpr reference operator[](size_type idx) const { return first[idx]; } + private: pointer first; pointer last; diff --git a/src/misc/CMakeLists.txt b/src/misc/CMakeLists.txt index ee2ca6a41c..24cde73e41 100644 --- a/src/misc/CMakeLists.txt +++ b/src/misc/CMakeLists.txt @@ -5,6 +5,7 @@ add_files( countedptr.hpp dbg_helpers.cpp dbg_helpers.h + endian_buffer.hpp fixedsizearray.hpp getoptdata.cpp getoptdata.h diff --git a/src/misc/endian_buffer.hpp b/src/misc/endian_buffer.hpp new file mode 100644 index 0000000000..c20d9a8b99 --- /dev/null +++ b/src/misc/endian_buffer.hpp @@ -0,0 +1,206 @@ +/* + * This file is part of OpenTTD. + * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2. + * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see . + */ + +/** @file endian_buffer.hpp Endian-aware buffer. */ + +#ifndef ENDIAN_BUFFER_HPP +#define ENDIAN_BUFFER_HPP + +#include +#include +#include "../core/span_type.hpp" +#include "../core/bitmath_func.hpp" + +struct StrongTypedefBase; + +/** + * Endian-aware buffer adapter that always writes values in little endian order. + * @note This class uses operator overloading (<<, just like streams) for writing + * as this allows providing custom operator overloads for more complex types + * like e.g. structs without needing to modify this class. + */ +template , typename Titer = typename std::back_insert_iterator> +class EndianBufferWriter { + /** Output iterator for the destination buffer. */ + Titer buffer; + +public: + EndianBufferWriter(Titer buffer) : buffer(buffer) {} + EndianBufferWriter(typename Titer::container_type &container) : buffer(std::back_inserter(container)) {} + + EndianBufferWriter &operator <<(const std::string &data) { return *this << std::string_view{ data }; } + EndianBufferWriter &operator <<(const char *data) { return *this << std::string_view{ data }; } + EndianBufferWriter &operator <<(std::string_view data) { this->Write(data); return *this; } + EndianBufferWriter &operator <<(bool data) { return *this << static_cast(data ? 1 : 0); } + + template + EndianBufferWriter &operator <<(const std::tuple &data) + { + this->WriteTuple(data, std::index_sequence_for{}); + return *this; + } + + template >, std::is_base_of>, int> = 0> + EndianBufferWriter &operator <<(const T data) + { + if constexpr (std::is_enum_v) { + this->Write(static_cast>(data)); + } else if constexpr (std::is_base_of_v) { + this->Write(data.value); + } else { + this->Write(data); + } + return *this; + } + + template > + static Tbuf FromValue(const Tvalue &data) + { + Tbuf buffer; + EndianBufferWriter writer{ buffer }; + writer << data; + return buffer; + } + +private: + /** Helper function to write a tuple to the buffer. */ + template + void WriteTuple(const Ttuple &values, std::index_sequence) { + ((*this << std::get(values)), ...); + } + + /** Write overload for string values. */ + void Write(std::string_view value) + { + for (auto c : value) { + this->buffer++ = c; + } + this->buffer++ = '\0'; + } + + /** Fundamental write function. */ + template + void Write(T value) + { + static_assert(sizeof(T) <= 8, "Value can't be larger than 8 bytes"); + + if constexpr (sizeof(T) > 1) { + this->buffer++ = GB(value, 0, 8); + this->buffer++ = GB(value, 8, 8); + if constexpr (sizeof(T) > 2) { + this->buffer++ = GB(value, 16, 8); + this->buffer++ = GB(value, 24, 8); + } + if constexpr (sizeof(T) > 4) { + this->buffer++ = GB(value, 32, 8); + this->buffer++ = GB(value, 40, 8); + this->buffer++ = GB(value, 48, 8); + this->buffer++ = GB(value, 56, 8); + } + } else { + this->buffer++ = value; + } + } +}; + +/** + * Endian-aware buffer adapter that always reads values in little endian order. + * @note This class uses operator overloading (>>, just like streams) for reading + * as this allows providing custom operator overloads for more complex types + * like e.g. structs without needing to modify this class. + */ +class EndianBufferReader { + /** Reference to storage buffer. */ + span buffer; + /** Current read position. */ + size_t read_pos = 0; + +public: + EndianBufferReader(span buffer) : buffer(buffer) {} + + void rewind() { this->read_pos = 0; } + + EndianBufferReader &operator >>(std::string &data) { data = this->ReadStr(); return *this; } + EndianBufferReader &operator >>(bool &data) { data = this->Read() != 0; return *this; } + + template + EndianBufferReader &operator >>(std::tuple &data) + { + this->ReadTuple(data, std::index_sequence_for{}); + return *this; + } + + template >, std::is_base_of>, int> = 0> + EndianBufferReader &operator >>(T &data) + { + if constexpr (std::is_enum_v) { + data = static_cast(this->Read>()); + } else if constexpr (std::is_base_of_v) { + data.value = this->Read(); + } else { + data = this->Read(); + } + return *this; + } + + template + static Tvalue ToValue(span buffer) + { + Tvalue result{}; + EndianBufferReader reader{ buffer }; + reader >> result; + return result; + } + +private: + /** Helper function to read a tuple from the buffer. */ + template + void ReadTuple(Ttuple &values, std::index_sequence) { + ((*this >> std::get(values)), ...); + } + + /** Read overload for string data. */ + std::string ReadStr() + { + std::string str; + while (this->read_pos < this->buffer.size()) { + char ch = this->Read(); + if (ch == '\0') break; + str.push_back(ch); + } + return str; + } + + /** Fundamental read function. */ + template + T Read() + { + static_assert(!std::is_const_v, "Can't read into const variables"); + static_assert(sizeof(T) <= 8, "Value can't be larger than 8 bytes"); + + if (read_pos + sizeof(T) > this->buffer.size()) return {}; + + T value = static_cast(this->buffer[this->read_pos++]); + if constexpr (sizeof(T) > 1) { + value += static_cast(this->buffer[this->read_pos++]) << 8; + } + if constexpr (sizeof(T) > 2) { + value += static_cast(this->buffer[this->read_pos++]) << 16; + value += static_cast(this->buffer[this->read_pos++]) << 24; + } + if constexpr (sizeof(T) > 4) { + value += static_cast(this->buffer[this->read_pos++]) << 32; + value += static_cast(this->buffer[this->read_pos++]) << 40; + value += static_cast(this->buffer[this->read_pos++]) << 48; + value += static_cast(this->buffer[this->read_pos++]) << 56; + } + + return value; + } +}; + +#endif /* ENDIAN_BUFFER_HPP */ diff --git a/src/network/core/packet.cpp b/src/network/core/packet.cpp index e106d5787f..ec0919757f 100644 --- a/src/network/core/packet.cpp +++ b/src/network/core/packet.cpp @@ -185,6 +185,17 @@ void Packet::Send_string(const std::string_view data) this->buffer.emplace_back('\0'); } +/** + * Copy a sized byte buffer into the packet. + * @param data The data to send. + */ +void Packet::Send_buffer(const std::vector &data) +{ + assert(this->CanWriteToPacket(sizeof(uint16) + data.size())); + this->Send_uint16((uint16)data.size()); + this->buffer.insert(this->buffer.end(), data.begin(), data.end()); +} + /** * Send as many of the bytes as possible in the packet. This can mean * that it is possible that not all bytes are sent. To cope with this @@ -366,6 +377,23 @@ uint64 Packet::Recv_uint64() return n; } +/** + * Extract a sized byte buffer from the packet. + * @return The extracted buffer. + */ +std::vector Packet::Recv_buffer() +{ + uint16 size = this->Recv_uint16(); + if (size == 0 || !this->CanReadFromPacket(size, true)) return {}; + + std::vector data; + while (size-- > 0) { + data.push_back(this->buffer[this->pos++]); + } + + return data; +} + /** * Reads characters (bytes) from the packet until it finds a '\0', or reaches a * maximum of \c length characters. diff --git a/src/network/core/packet.h b/src/network/core/packet.h index 277ff8bba1..04a232e1ca 100644 --- a/src/network/core/packet.h +++ b/src/network/core/packet.h @@ -72,6 +72,7 @@ public: void Send_uint32(uint32 data); void Send_uint64(uint64 data); void Send_string(const std::string_view data); + void Send_buffer(const std::vector &data); size_t Send_bytes (const byte *begin, const byte *end); /* Reading/receiving of packets */ @@ -87,6 +88,7 @@ public: uint16 Recv_uint16(); uint32 Recv_uint32(); uint64 Recv_uint64(); + std::vector Recv_buffer(); std::string Recv_string(size_t length, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK); size_t RemainingBytesToTransfer() const; diff --git a/src/network/network.cpp b/src/network/network.cpp index 13f2fe52a4..8194f34d07 100644 --- a/src/network/network.cpp +++ b/src/network/network.cpp @@ -35,6 +35,7 @@ #include "../core/pool_func.hpp" #include "../gfx_func.h" #include "../error.h" +#include "../misc_cmd.h" #include #include #include @@ -1064,8 +1065,8 @@ void NetworkGameLoop() while (f != nullptr && !feof(f)) { if (_date == next_date && _date_fract == next_date_fract) { if (cp != nullptr) { - NetworkSendCommand(cp->cmd, cp->err_msg, nullptr, cp->company, cp->tile, cp->p1, cp->p2, cp->text); - Debug(desync, 0, "Injecting: {:08x}; {:02x}; {:02x}; {:06x}; {:08x}; {:08x}; {:08x}; \"{}\" ({})", _date, _date_fract, (int)_current_company, cp->tile, cp->p1, cp->p2, cp->cmd, cp->text, GetCommandName(cp->cmd)); + NetworkSendCommand(cp->cmd, cp->err_msg, nullptr, cp->company, cp->data); + Debug(desync, 0, "Injecting: {:08x}; {:02x}; {:02x}; {:08x}; {:06x}; {} ({})", _date, _date_fract, (int)_current_company, cp->cmd, cp->tile, FormatArrayAsHex(cp->data), GetCommandName(cp->cmd)); delete cp; cp = nullptr; } @@ -1104,15 +1105,21 @@ void NetworkGameLoop() cp = new CommandPacket(); int company; uint cmd; - char buffer[128]; - int ret = sscanf(p, "%x; %x; %x; %x; %x; %x; %x; %x; \"%127[^\"]\"", &next_date, &next_date_fract, &company, &cp->tile, &cp->p1, &cp->p2, &cmd, &cp->err_msg, buffer); - cp->text = buffer; - /* There are 8 pieces of data to read, however the last is a - * string that might or might not exist. Ignore it if that - * string misses because in 99% of the time it's not used. */ - assert(ret == 9 || ret == 8); + char buffer[256]; + int ret = sscanf(p, "%x; %x; %x; %x; %x; %x; %255s", &next_date, &next_date_fract, &company, &cmd, &cp->err_msg, &cp->tile, buffer); + assert(ret == 6); cp->company = (CompanyID)company; cp->cmd = (Commands)cmd; + + /* Parse command data. */ + std::vector args; + size_t arg_len = strlen(buffer); + for (size_t i = 0; i + 1 < arg_len; i += 2) { + byte e = 0; + std::from_chars(buffer + i, buffer + i + 1, e, 16); + args.emplace_back(e); + } + cp->data = args; } else if (strncmp(p, "join: ", 6) == 0) { /* Manually insert a pause when joining; this way the client can join at the exact right time. */ int ret = sscanf(p + 6, "%x; %x", &next_date, &next_date_fract); @@ -1121,8 +1128,7 @@ void NetworkGameLoop() cp = new CommandPacket(); cp->company = COMPANY_SPECTATOR; cp->cmd = CMD_PAUSE; - cp->p1 = PM_PAUSED_NORMAL; - cp->p2 = 1; + cp->data = EndianBufferWriter<>::FromValue(CommandTraits::Args{ 0, PM_PAUSED_NORMAL, 1, "" }); _ddc_fastforward = false; } else if (strncmp(p, "sync: ", 6) == 0) { int ret = sscanf(p + 6, "%x; %x; %x; %x", &next_date, &next_date_fract, &sync_state[0], &sync_state[1]); diff --git a/src/network/network_admin.cpp b/src/network/network_admin.cpp index 99f803e24e..4711cdf046 100644 --- a/src/network/network_admin.cpp +++ b/src/network/network_admin.cpp @@ -630,10 +630,7 @@ NetworkRecvStatus ServerNetworkAdminSocketHandler::SendCmdLogging(ClientID clien p->Send_uint32(client_id); p->Send_uint8 (cp->company); p->Send_uint16(cp->cmd); - p->Send_uint32(cp->p1); - p->Send_uint32(cp->p2); - p->Send_uint32(cp->tile); - p->Send_string(cp->text); + p->Send_buffer(cp->data); p->Send_uint32(cp->frame); this->SendPacket(p); diff --git a/src/network/network_command.cpp b/src/network/network_command.cpp index 0fae6bcbf0..472d5e60e2 100644 --- a/src/network/network_command.cpp +++ b/src/network/network_command.cpp @@ -15,18 +15,41 @@ #include "../company_func.h" #include "../settings_type.h" #include "../airport_cmd.h" +#include "../aircraft_cmd.h" +#include "../autoreplace_cmd.h" +#include "../company_cmd.h" #include "../depot_cmd.h" #include "../dock_cmd.h" +#include "../economy_cmd.h" +#include "../engine_cmd.h" +#include "../goal_cmd.h" #include "../group_cmd.h" #include "../industry_cmd.h" +#include "../landscape_cmd.h" +#include "../misc_cmd.h" +#include "../news_cmd.h" +#include "../object_cmd.h" +#include "../order_cmd.h" #include "../rail_cmd.h" #include "../road_cmd.h" +#include "../roadveh_cmd.h" +#include "../settings_cmd.h" +#include "../signs_cmd.h" +#include "../station_cmd.h" +#include "../story_cmd.h" +#include "../subsidy_cmd.h" #include "../terraform_cmd.h" +#include "../timetable_cmd.h" #include "../town_cmd.h" #include "../train_cmd.h" +#include "../tree_cmd.h" #include "../tunnelbridge_cmd.h" #include "../vehicle_cmd.h" +#include "../viewport_cmd.h" +#include "../water_cmd.h" +#include "../waypoint_cmd.h" #include "../script/script_cmd.h" +#include #include "../safeguards.h" @@ -62,6 +85,23 @@ static CommandCallback * const _callback_table[] = { /* 0x1B */ CcAddVehicleNewGroup, }; +/* Helpers to generate the command dispatch table from the command traits. */ + +template static CommandDataBuffer SanitizeCmdStrings(const CommandDataBuffer &data); +template static void UnpackNetworkCommand(const CommandPacket *cp); +struct CommandDispatch { + CommandDataBuffer(*Sanitize)(const CommandDataBuffer &); + void (*Unpack)(const CommandPacket *); +}; + +template +inline constexpr auto MakeDispatchTable(std::integer_sequence) noexcept +{ + return std::array{{ { &SanitizeCmdStrings(i)>, &UnpackNetworkCommand(i)> }... }}; +} +static constexpr auto _cmd_dispatch = MakeDispatchTable(std::make_integer_sequence, CMD_END>{}); + + /** * Append a CommandPacket at the end of the queue. * @param p The packet to append to the queue. @@ -148,16 +188,29 @@ static CommandQueue _local_execution_queue; * @param text The text to pass */ void NetworkSendCommand(Commands cmd, StringID err_message, CommandCallback *callback, CompanyID company, TileIndex tile, uint32 p1, uint32 p2, const std::string &text) +{ + auto data = EndianBufferWriter::FromValue(std::make_tuple(tile, p1, p2, text)); + NetworkSendCommand(cmd, err_message, callback, company, tile, data); +} + +/** + * Prepare a DoCommand to be send over the network + * @param cmd The command to execute (a CMD_* value) + * @param err_message Message prefix to show on error + * @param callback A callback function to call after the command is finished + * @param company The company that wants to send the command + * @param location Location of the command (e.g. for error message position) + * @param cmd_data The command proc arguments. + */ +void NetworkSendCommand(Commands cmd, StringID err_message, CommandCallback *callback, CompanyID company, TileIndex location, const CommandDataBuffer &cmd_data) { CommandPacket c; c.company = company; - c.tile = tile; - c.p1 = p1; - c.p2 = p2; c.cmd = cmd; c.err_msg = err_message; c.callback = callback; - c.text = text; + c.tile = location; + c.data = cmd_data; if (_network_server) { /* If we are the server, we queue the command in our 'special' queue. @@ -220,7 +273,7 @@ void NetworkExecuteLocalCommandQueue() /* We can execute this command */ _current_company = cp->company; - DoCommandP(cp, cp->my_cmd, true); + _cmd_dispatch[cp->cmd].Unpack(cp); queue.Pop(); delete cp; @@ -311,11 +364,8 @@ const char *NetworkGameSocketHandler::ReceiveCommand(Packet *p, CommandPacket *c if (!IsValidCommand(cp->cmd)) return "invalid command"; if (GetCommandFlags(cp->cmd) & CMD_OFFLINE) return "single-player only command"; cp->err_msg = p->Recv_uint16(); - - cp->p1 = p->Recv_uint32(); - cp->p2 = p->Recv_uint32(); cp->tile = p->Recv_uint32(); - cp->text = p->Recv_string(NETWORK_COMPANY_NAME_LENGTH, (!_network_server && GetCommandFlags(cp->cmd) & CMD_STR_CTRL) != 0 ? SVS_ALLOW_CONTROL_CODE | SVS_REPLACE_WITH_QUESTION_MARK : SVS_REPLACE_WITH_QUESTION_MARK); + cp->data = _cmd_dispatch[cp->cmd].Sanitize(p->Recv_buffer()); byte callback = p->Recv_uint8(); if (callback >= lengthof(_callback_table)) return "invalid callback"; @@ -331,13 +381,11 @@ const char *NetworkGameSocketHandler::ReceiveCommand(Packet *p, CommandPacket *c */ void NetworkGameSocketHandler::SendCommand(Packet *p, const CommandPacket *cp) { - p->Send_uint8 (cp->company); + p->Send_uint8(cp->company); p->Send_uint16(cp->cmd); p->Send_uint16(cp->err_msg); - p->Send_uint32(cp->p1); - p->Send_uint32(cp->p2); p->Send_uint32(cp->tile); - p->Send_string(cp->text); + p->Send_buffer(cp->data); byte callback = 0; while (callback < lengthof(_callback_table) && _callback_table[callback] != cp->callback) { @@ -350,3 +398,58 @@ void NetworkGameSocketHandler::SendCommand(Packet *p, const CommandPacket *cp) } p->Send_uint8 (callback); } + +/** + * Insert a client ID into the command data in a command packet. + * @param cp Command packet to modify. + * @param client_id Client id to insert. + */ +void NetworkReplaceCommandClientId(CommandPacket &cp, ClientID client_id) +{ + /* Unpack command parameters. */ + auto params = EndianBufferReader::ToValue>(cp.data); + + /* Insert client id. */ + std::get<2>(params) = client_id; + + /* Repack command parameters. */ + cp.data = EndianBufferWriter::FromValue(params); +} + + +/** Validate a single string argument coming from network. */ +template +static inline void SanitizeSingleStringHelper([[maybe_unused]] CommandFlags cmd_flags, T &data) +{ + if constexpr (std::is_same_v) { + data = StrMakeValid(data.substr(0, NETWORK_COMPANY_NAME_LENGTH), (!_network_server && cmd_flags & CMD_STR_CTRL) != 0 ? SVS_ALLOW_CONTROL_CODE | SVS_REPLACE_WITH_QUESTION_MARK : SVS_REPLACE_WITH_QUESTION_MARK); + } +} + +/** Helper function to perform validation on command data strings. */ +template +static inline void SanitizeStringsHelper(CommandFlags cmd_flags, Ttuple &values, std::index_sequence) +{ + ((SanitizeSingleStringHelper(cmd_flags, std::get(values))), ...); +} + +/** + * Validate and sanitize strings in command data. + * @tparam Tcmd Command this data belongs to. + * @param data Command data. + * @return Sanitized command data. + */ +template +CommandDataBuffer SanitizeCmdStrings(const CommandDataBuffer &data) +{ + auto args = EndianBufferReader::ToValue::Args>(data); + SanitizeStringsHelper(CommandTraits::flags, args, std::make_index_sequence::Args>>{}); + return EndianBufferWriter::FromValue(args); +} + +template +void UnpackNetworkCommand(const CommandPacket *cp) +{ + auto args = EndianBufferReader::ToValue::Args>(cp->data); + std::apply(&InjectNetworkCommand, std::tuple_cat(std::make_tuple(Tcmd, cp->err_msg, cp->callback, cp->my_cmd), args)); +} diff --git a/src/network/network_internal.h b/src/network/network_internal.h index 25240da5d7..58c99867c2 100644 --- a/src/network/network_internal.h +++ b/src/network/network_internal.h @@ -15,6 +15,8 @@ #include "core/tcp_game.h" #include "../command_type.h" +#include "../command_func.h" +#include "../misc/endian_buffer.hpp" #ifdef RANDOM_DEBUG /** @@ -104,19 +106,26 @@ void UpdateNetworkGameWindow(); /** * Everything we need to know about a command to be able to execute it. */ -struct CommandPacket : CommandContainer { +struct CommandPacket { /** Make sure the pointer is nullptr. */ - CommandPacket() : next(nullptr), company(INVALID_COMPANY), frame(0), my_cmd(false) {} + CommandPacket() : next(nullptr), company(INVALID_COMPANY), frame(0), my_cmd(false), tile(0) {} CommandPacket *next; ///< the next command packet (if in queue) CompanyID company; ///< company that is executing the command uint32 frame; ///< the frame in which this packet is executed bool my_cmd; ///< did the command originate from "me" + + Commands cmd; ///< command being executed. + StringID err_msg; ///< string ID of error message to use. + CommandCallback *callback; ///< any callback function executed upon successful completion of the command. + TileIndex tile; ///< location of the command (for e.g. error message or effect display). + CommandDataBuffer data; ///< command parameters. }; void NetworkDistributeCommands(); void NetworkExecuteLocalCommandQueue(); void NetworkFreeLocalCommandQueue(); void NetworkSyncCommandQueue(NetworkClientSocket *cs); +void NetworkReplaceCommandClientId(CommandPacket &cp, ClientID client_id); void ShowNetworkError(StringID error_string); void NetworkTextMessage(NetworkAction action, TextColour colour, bool self_send, const std::string &name, const std::string &str = "", int64 data = 0, const std::string &data_str = ""); diff --git a/src/network/network_server.cpp b/src/network/network_server.cpp index 50b46dfe10..967ad40a89 100644 --- a/src/network/network_server.cpp +++ b/src/network/network_server.cpp @@ -24,6 +24,7 @@ #include "../genworld.h" #include "../company_func.h" #include "../company_gui.h" +#include "../company_cmd.h" #include "../roadveh.h" #include "../order_backup.h" #include "../core/pool_func.hpp" @@ -1048,14 +1049,15 @@ NetworkRecvStatus ServerNetworkGameSocketHandler::Receive_CLIENT_COMMAND(Packet * to match the company in the packet. If it doesn't, the client has done * something pretty naughty (or a bug), and will be kicked */ - if (!(cp.cmd == CMD_COMPANY_CTRL && cp.p1 == 0 && ci->client_playas == COMPANY_NEW_COMPANY) && ci->client_playas != cp.company) { + uint32 company_p1 = cp.cmd == CMD_COMPANY_CTRL ? std::get<1>(EndianBufferReader::ToValue::Args>(cp.data)) : 0; + if (!(cp.cmd == CMD_COMPANY_CTRL && company_p1 == 0 && ci->client_playas == COMPANY_NEW_COMPANY) && ci->client_playas != cp.company) { IConsolePrint(CC_WARNING, "Kicking client #{} (IP: {}) due to calling a command as another company {}.", ci->client_playas + 1, this->GetClientIP(), cp.company + 1); return this->SendError(NETWORK_ERROR_COMPANY_MISMATCH); } if (cp.cmd == CMD_COMPANY_CTRL) { - if (cp.p1 != 0 || cp.company != COMPANY_SPECTATOR) { + if (company_p1 != 0 || cp.company != COMPANY_SPECTATOR) { return this->SendError(NETWORK_ERROR_CHEATER); } @@ -1066,7 +1068,7 @@ NetworkRecvStatus ServerNetworkGameSocketHandler::Receive_CLIENT_COMMAND(Packet } } - if (GetCommandFlags(cp.cmd) & CMD_CLIENT_ID) cp.p2 = this->client_id; + if (GetCommandFlags(cp.cmd) & CMD_CLIENT_ID) NetworkReplaceCommandClientId(cp, this->client_id); this->incoming_queue.Append(&cp); return NETWORK_RECV_STATUS_OKAY; diff --git a/src/string.cpp b/src/string.cpp index d027cb7bff..aeac4fe84f 100644 --- a/src/string.cpp +++ b/src/string.cpp @@ -19,6 +19,7 @@ #include #include /* required for tolower() */ #include +#include #ifdef _MSC_VER #include // required by vsnprintf implementation for MSVC @@ -160,6 +161,23 @@ char *CDECL str_fmt(const char *str, ...) return p; } +/** + * Format a byte array into a continuous hex string. + * @param data Array to format + * @return Converted string. + */ +std::string FormatArrayAsHex(span data) +{ + std::ostringstream ss; + ss << std::uppercase << std::setfill('0') << std::setw(2) << std::hex; + + for (auto b : data) { + ss << b; + } + + return ss.str(); +} + /** * Scan the string for old values of SCC_ENCODED and fix it to * it's new, static value. diff --git a/src/string_func.h b/src/string_func.h index 0cbf26d6b2..a5d3499c76 100644 --- a/src/string_func.h +++ b/src/string_func.h @@ -28,6 +28,7 @@ #include #include "core/bitmath_func.hpp" +#include "core/span_type.hpp" #include "string_type.h" char *strecat(char *dst, const char *src, const char *last) NOACCESS(3); @@ -39,6 +40,8 @@ int CDECL vseprintf(char *str, const char *last, const char *format, va_list ap) char *CDECL str_fmt(const char *str, ...) WARN_FORMAT(1, 2); +std::string FormatArrayAsHex(span data); + void StrMakeValidInPlace(char *str, const char *last, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK) NOACCESS(2); [[nodiscard]] std::string StrMakeValid(const std::string &str, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK); void StrMakeValidInPlace(char *str, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK);