From 4af089b9be8f56418cc96447b3ca26f037f7888b Mon Sep 17 00:00:00 2001 From: Rubidium Date: Sun, 17 Mar 2024 19:11:55 +0100 Subject: [PATCH] Feature: console command to change authorized keys --- src/console_cmds.cpp | 98 ++++++++++++++++++++++++++++++++++ src/network/network_func.h | 1 + src/network/network_server.cpp | 12 +++++ src/network/network_server.h | 1 + 4 files changed, 112 insertions(+) diff --git a/src/console_cmds.cpp b/src/console_cmds.cpp index 3b797005cc..10e6810f14 100644 --- a/src/console_cmds.cpp +++ b/src/console_cmds.cpp @@ -1956,6 +1956,101 @@ DEF_CONSOLE_CMD(ConCompanyPassword) return true; } +/** All the known authorized keys with their name. */ +static std::vector *>> _console_cmd_authorized_keys{ + { "rcon", &_settings_client.network.rcon_authorized_keys }, + { "server", &_settings_client.network.server_authorized_keys }, +}; + +/** + * Simple helper to find the location of the given authorized key in the authorized keys. + * @param authorized_keys The keys to look through. + * @param authorized_key The key to look for. + * @return The iterator to the location of the authorized key, or \c authorized_keys.end(). + */ +static auto FindKey(std::vector *authorized_keys, std::string_view authorized_key) +{ + return std::find_if(authorized_keys->begin(), authorized_keys->end(), [authorized_key](auto &value) { return StrEqualsIgnoreCase(value, authorized_key); }); +} + +DEF_CONSOLE_CMD(ConNetworkAuthorizedKey) +{ + if (argc <= 2) { + IConsolePrint(CC_HELP, "List and update authorized keys. Usage: 'authorized_key list [type]|add [type] [key]|remove [type] [key]'."); + IConsolePrint(CC_HELP, " list: list all the authorized keys of the given type."); + IConsolePrint(CC_HELP, " add: add the given key to the authorized keys of the given type."); + IConsolePrint(CC_HELP, " remove: remove the given key from the authorized keys of the given type; use 'all' to remove all authorized keys."); + IConsolePrint(CC_HELP, "Instead of a key, use 'client:' to add/remove the key of that given client."); + + std::string buffer; + for (auto [name, _] : _console_cmd_authorized_keys) fmt::format_to(std::back_inserter(buffer), ", {}", name); + IConsolePrint(CC_HELP, "The supported types are: all{}.", buffer); + return true; + } + + bool valid_type = false; ///< Whether a valid type was given. + + for (auto [name, authorized_keys] : _console_cmd_authorized_keys) { + if (!StrEqualsIgnoreCase(argv[2], name) && !StrEqualsIgnoreCase(argv[2], "all")) continue; + + valid_type = true; + + if (StrEqualsIgnoreCase(argv[1], "list")) { + IConsolePrint(CC_WHITE, "The authorized keys for {} are:", name); + for (auto &authorized_key : *authorized_keys) IConsolePrint(CC_INFO, " {}", authorized_key); + continue; + } + + if (argc <= 3) { + IConsolePrint(CC_ERROR, "You must enter the key."); + return false; + } + + std::string authorized_key = argv[3]; + if (StrStartsWithIgnoreCase(authorized_key, "client:")) { + std::string id_string(authorized_key.substr(7)); + authorized_key = NetworkGetPublicKeyOfClient(static_cast(std::stoi(id_string))); + if (authorized_key.empty()) { + IConsolePrint(CC_ERROR, "You must enter a valid client id; see 'clients'."); + return false; + } + } + + auto iter = FindKey(authorized_keys, authorized_key); + + if (StrEqualsIgnoreCase(argv[1], "add")) { + if (iter == authorized_keys->end()) { + authorized_keys->push_back(authorized_key); + IConsolePrint(CC_INFO, "Added {} to {}.", authorized_key, name); + } else { + IConsolePrint(CC_WARNING, "Not added {} to {} as it already exists.", authorized_key, name); + } + continue; + } + + if (StrEqualsIgnoreCase(argv[1], "remove")) { + if (iter != authorized_keys->end()) { + authorized_keys->erase(iter); + IConsolePrint(CC_INFO, "Removed {} from {}.", authorized_key, name); + } else { + IConsolePrint(CC_WARNING, "Not removed {} from {} as it does not exist.", authorized_key, name); + } + continue; + } + + IConsolePrint(CC_WARNING, "No valid action was given."); + return false; + } + + if (!valid_type) { + IConsolePrint(CC_WARNING, "No valid type was given."); + return false; + } + + return true; +} + + /* Content downloading only is available with ZLIB */ #if defined(WITH_ZLIB) #include "network/network_content.h" @@ -2723,6 +2818,9 @@ void IConsoleStdLibRegister() IConsole::CmdRegister("pause", ConPauseGame, ConHookServerOrNoNetwork); IConsole::CmdRegister("unpause", ConUnpauseGame, ConHookServerOrNoNetwork); + IConsole::CmdRegister("authorized_key", ConNetworkAuthorizedKey, ConHookServerOnly); + IConsole::AliasRegister("ak", "authorized_key %+"); + IConsole::CmdRegister("company_pw", ConCompanyPassword, ConHookNeedNetwork); IConsole::AliasRegister("company_password", "company_pw %+"); diff --git a/src/network/network_func.h b/src/network/network_func.h index 37a4a81fd6..66b4660169 100644 --- a/src/network/network_func.h +++ b/src/network/network_func.h @@ -61,6 +61,7 @@ bool NetworkCompanyIsPassworded(CompanyID company_id); uint NetworkMaxCompaniesAllowed(); bool NetworkMaxCompaniesReached(); void NetworkPrintClients(); +std::string_view NetworkGetPublicKeyOfClient(ClientID client_id); void NetworkHandlePauseChange(PauseMode prev_mode, PauseMode changed_mode); /*** Commands ran by the server ***/ diff --git a/src/network/network_server.cpp b/src/network/network_server.cpp index 6166376b11..39bd2aa540 100644 --- a/src/network/network_server.cpp +++ b/src/network/network_server.cpp @@ -2243,6 +2243,18 @@ void NetworkPrintClients() } } +/** + * Get the public key of the client with the given id. + * @param client_id The id of the client. + * @return View of the public key, which is empty when the client does not exist. + */ +std::string_view NetworkGetPublicKeyOfClient(ClientID client_id) +{ + auto socket = NetworkClientSocket::GetByClientID(client_id); + return socket == nullptr ? "" : socket->GetPeerPublicKey(); +} + + /** * Perform all the server specific administration of a new company. * @param c The newly created company; can't be nullptr. diff --git a/src/network/network_server.h b/src/network/network_server.h index d41ef4756c..6b75a123f9 100644 --- a/src/network/network_server.h +++ b/src/network/network_server.h @@ -121,6 +121,7 @@ public: } const std::string &GetClientIP(); + std::string_view GetPeerPublicKey() const { return this->peer_public_key; } static ServerNetworkGameSocketHandler *GetByClientID(ClientID client_id); };