diff --git a/src/network/core/core.h b/src/network/core/core.h index 712687022a..a9d12b1f11 100644 --- a/src/network/core/core.h +++ b/src/network/core/core.h @@ -13,6 +13,7 @@ #define NETWORK_CORE_CORE_H #include "../../newgrf_config.h" +#include "../network_crypto.h" #include "config.h" bool NetworkCoreInitialize(); @@ -43,6 +44,11 @@ class NetworkSocketHandler { private: bool has_quit; ///< Whether the current client has quit/send a bad packet +protected: + friend struct Packet; + std::unique_ptr receive_encryption_handler; ///< The handler for decrypting received packets. + std::unique_ptr send_encryption_handler; ///< The handler for encrypting sent packets. + public: /** Create a new unbound socket */ NetworkSocketHandler() { this->has_quit = false; } diff --git a/src/network/core/packet.cpp b/src/network/core/packet.cpp index d2fe0b4cde..974630c63d 100644 --- a/src/network/core/packet.cpp +++ b/src/network/core/packet.cpp @@ -48,7 +48,14 @@ Packet::Packet(NetworkSocketHandler *cs, size_t limit, size_t initial_read_size) Packet::Packet(NetworkSocketHandler *cs, PacketType type, size_t limit) : pos(0), limit(limit), cs(cs) { /* Allocate space for the the size so we can write that in just before sending the packet. */ - this->Send_uint16(0); + size_t size = EncodedLengthOfPacketSize(); + if (cs != nullptr && cs->send_encryption_handler != nullptr) { + /* Allocate some space for the message authentication code of the encryption. */ + size += cs->send_encryption_handler->MACSize(); + } + assert(this->CanWriteToPacket(size)); + this->buffer.resize(size, 0); + this->Send_uint8(type); } @@ -64,6 +71,13 @@ void Packet::PrepareToSend() this->buffer[0] = GB(this->Size(), 0, 8); this->buffer[1] = GB(this->Size(), 8, 8); + if (cs != nullptr && cs->send_encryption_handler != nullptr) { + size_t offset = EncodedLengthOfPacketSize(); + size_t mac_size = cs->send_encryption_handler->MACSize(); + size_t message_offset = offset + mac_size; + cs->send_encryption_handler->Encrypt(std::span(&this->buffer[offset], mac_size), std::span(&this->buffer[message_offset], this->buffer.size() - message_offset)); + } + this->pos = 0; // We start reading from here this->buffer.shrink_to_fit(); } @@ -259,11 +273,21 @@ bool Packet::ParsePacketSize() /** * Prepares the packet so it can be read + * @return True when the packet was valid, otherwise false. */ -void Packet::PrepareToRead() +bool Packet::PrepareToRead() { /* Put the position on the right place */ this->pos = static_cast(EncodedLengthOfPacketSize()); + + if (cs == nullptr || cs->receive_encryption_handler == nullptr) return true; + + size_t mac_size = cs->receive_encryption_handler->MACSize(); + if (this->buffer.size() <= pos + mac_size) return false; + + bool valid = cs->receive_encryption_handler->Decrypt(std::span(&this->buffer[pos], mac_size), std::span(&this->buffer[pos + mac_size], this->buffer.size() - pos - mac_size)); + this->pos += static_cast(mac_size); + return valid; } /** @@ -273,7 +297,9 @@ void Packet::PrepareToRead() PacketType Packet::GetPacketType() const { assert(this->Size() >= EncodedLengthOfPacketSize() + EncodedLengthOfPacketType()); - return static_cast(buffer[EncodedLengthOfPacketSize()]); + size_t offset = EncodedLengthOfPacketSize(); + if (cs != nullptr && cs->send_encryption_handler != nullptr) offset += cs->send_encryption_handler->MACSize(); + return static_cast(buffer[offset]); } /** diff --git a/src/network/core/packet.h b/src/network/core/packet.h index 2631173cce..839f8d4740 100644 --- a/src/network/core/packet.h +++ b/src/network/core/packet.h @@ -74,7 +74,7 @@ public: bool HasPacketSizeData() const; bool ParsePacketSize(); size_t Size() const; - void PrepareToRead(); + [[nodiscard]] bool PrepareToRead(); PacketType GetPacketType() const; bool CanReadFromPacket(size_t bytes_to_read, bool close_connection = false); diff --git a/src/network/core/tcp.cpp b/src/network/core/tcp.cpp index b01b8cd075..8bd7b44f2f 100644 --- a/src/network/core/tcp.cpp +++ b/src/network/core/tcp.cpp @@ -188,7 +188,11 @@ std::unique_ptr NetworkTCPSocketHandler::ReceivePacket() } } - p.PrepareToRead(); + if (!p.PrepareToRead()) { + Debug(net, 0, "Invalid packet received (too small / decryption error)"); + this->CloseConnection(); + return nullptr; + } return std::move(this->packet_recv); } diff --git a/src/network/core/tcp_game.cpp b/src/network/core/tcp_game.cpp index 9e179362bc..5b7ef1dfca 100644 --- a/src/network/core/tcp_game.cpp +++ b/src/network/core/tcp_game.cpp @@ -85,7 +85,7 @@ NetworkRecvStatus NetworkGameSocketHandler::HandlePacket(Packet &p) case PACKET_SERVER_AUTH_REQUEST: return this->Receive_SERVER_AUTH_REQUEST(p); case PACKET_SERVER_NEED_COMPANY_PASSWORD: return this->Receive_SERVER_NEED_COMPANY_PASSWORD(p); case PACKET_CLIENT_AUTH_RESPONSE: return this->Receive_CLIENT_AUTH_RESPONSE(p); - case PACKET_SERVER_AUTH_COMPLETED: return this->Receive_SERVER_AUTH_COMPLETED(p); + case PACKET_SERVER_ENABLE_ENCRYPTION: return this->Receive_SERVER_ENABLE_ENCRYPTION(p); case PACKET_CLIENT_COMPANY_PASSWORD: return this->Receive_CLIENT_COMPANY_PASSWORD(p); case PACKET_SERVER_WELCOME: return this->Receive_SERVER_WELCOME(p); case PACKET_CLIENT_GETMAP: return this->Receive_CLIENT_GETMAP(p); @@ -168,7 +168,7 @@ NetworkRecvStatus NetworkGameSocketHandler::Receive_CLIENT_IDENTIFY(Packet &) { NetworkRecvStatus NetworkGameSocketHandler::Receive_SERVER_AUTH_REQUEST(Packet &) { return this->ReceiveInvalidPacket(PACKET_SERVER_AUTH_REQUEST); } NetworkRecvStatus NetworkGameSocketHandler::Receive_SERVER_NEED_COMPANY_PASSWORD(Packet &) { return this->ReceiveInvalidPacket(PACKET_SERVER_NEED_COMPANY_PASSWORD); } NetworkRecvStatus NetworkGameSocketHandler::Receive_CLIENT_AUTH_RESPONSE(Packet &) { return this->ReceiveInvalidPacket(PACKET_CLIENT_AUTH_RESPONSE); } -NetworkRecvStatus NetworkGameSocketHandler::Receive_SERVER_AUTH_COMPLETED(Packet &) { return this->ReceiveInvalidPacket(PACKET_SERVER_AUTH_COMPLETED); } +NetworkRecvStatus NetworkGameSocketHandler::Receive_SERVER_ENABLE_ENCRYPTION(Packet &) { return this->ReceiveInvalidPacket(PACKET_SERVER_ENABLE_ENCRYPTION); } NetworkRecvStatus NetworkGameSocketHandler::Receive_CLIENT_COMPANY_PASSWORD(Packet &) { return this->ReceiveInvalidPacket(PACKET_CLIENT_COMPANY_PASSWORD); } NetworkRecvStatus NetworkGameSocketHandler::Receive_SERVER_WELCOME(Packet &) { return this->ReceiveInvalidPacket(PACKET_SERVER_WELCOME); } NetworkRecvStatus NetworkGameSocketHandler::Receive_CLIENT_GETMAP(Packet &) { return this->ReceiveInvalidPacket(PACKET_CLIENT_GETMAP); } diff --git a/src/network/core/tcp_game.h b/src/network/core/tcp_game.h index 2e5f33dfdb..cf1b8b92b2 100644 --- a/src/network/core/tcp_game.h +++ b/src/network/core/tcp_game.h @@ -15,7 +15,6 @@ #include "os_abstraction.h" #include "tcp.h" #include "../network_type.h" -#include "../network_crypto.h" #include "../../core/pool_type.hpp" #include @@ -60,7 +59,7 @@ enum PacketGameType : uint8_t { /* After the join step, the first perform game authentication and enabling encryption. */ PACKET_SERVER_AUTH_REQUEST, ///< The server requests the client to authenticate using a number of methods. PACKET_CLIENT_AUTH_RESPONSE, ///< The client responds to the authentication request. - PACKET_SERVER_AUTH_COMPLETED, ///< The server indicates the authentication is completed. + PACKET_SERVER_ENABLE_ENCRYPTION, ///< The server tells that authentication has completed and requests to enable encryption with the keys of the last \c PACKET_CLIENT_AUTH_RESPONSE. /* After the authentication is done, the next step is identification. */ PACKET_CLIENT_IDENTIFY, ///< Client telling the server the client's name and requested company. @@ -244,10 +243,10 @@ protected: virtual NetworkRecvStatus Receive_CLIENT_AUTH_RESPONSE(Packet &p); /** - * Indication to the client that authentication has completed. + * Indication to the client that authentication is complete and encryption has to be used from here on forward. * @param p The packet that was just received. */ - virtual NetworkRecvStatus Receive_SERVER_AUTH_COMPLETED(Packet &p); + virtual NetworkRecvStatus Receive_SERVER_ENABLE_ENCRYPTION(Packet &p); /** * Send a password to the server to authorize diff --git a/src/network/core/udp.cpp b/src/network/core/udp.cpp index 04324e0098..c7acbd21a1 100644 --- a/src/network/core/udp.cpp +++ b/src/network/core/udp.cpp @@ -137,7 +137,10 @@ void NetworkUDPSocketHandler::ReceivePackets() Debug(net, 1, "Received a packet with mismatching size from {}", address.GetAddressAsString()); continue; } - p.PrepareToRead(); + if (!p.PrepareToRead()) { + Debug(net, 1, "Invalid packet received (too small / decryption error)"); + continue; + } /* Handle the packet */ this->HandleUDPPacket(p, address); diff --git a/src/network/network_client.cpp b/src/network/network_client.cpp index f3fd5fe6f0..698b2975dd 100644 --- a/src/network/network_client.cpp +++ b/src/network/network_client.cpp @@ -707,7 +707,7 @@ NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_ERROR(Packet &p NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_CHECK_NEWGRFS(Packet &p) { - if (this->status != STATUS_AUTHENTICATED) return NETWORK_RECV_STATUS_MALFORMED_PACKET; + if (this->status != STATUS_ENCRYPTED) return NETWORK_RECV_STATUS_MALFORMED_PACKET; uint grf_count = p.Recv_uint8(); NetworkRecvStatus ret = NETWORK_RECV_STATUS_OKAY; @@ -775,16 +775,18 @@ NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_AUTH_REQUEST(Pa } } -NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_AUTH_COMPLETED(Packet &) +NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_ENABLE_ENCRYPTION(Packet &) { if (this->status != STATUS_AUTH_GAME || this->authentication_handler == nullptr) return NETWORK_RECV_STATUS_MALFORMED_PACKET; - Debug(net, 9, "Client::Receive_SERVER_AUTH_COMPLETED()"); + Debug(net, 9, "Client::Receive_SERVER_ENABLE_ENCRYPTION()"); + this->receive_encryption_handler = this->authentication_handler->CreateServerToClientEncryptionHandler(); + this->send_encryption_handler = this->authentication_handler->CreateClientToServerEncryptionHandler(); this->authentication_handler = nullptr; - Debug(net, 9, "Client::status = AUTHENTICATED"); - this->status = STATUS_AUTHENTICATED; + Debug(net, 9, "Client::status = ENCRYPTED"); + this->status = STATUS_ENCRYPTED; return this->SendIdentify(); } @@ -798,7 +800,7 @@ class CompanyPasswordRequest : public NetworkAuthenticationPasswordRequest { NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_NEED_COMPANY_PASSWORD(Packet &p) { - if (this->status < STATUS_AUTHENTICATED || this->status >= STATUS_AUTH_COMPANY) return NETWORK_RECV_STATUS_MALFORMED_PACKET; + if (this->status < STATUS_ENCRYPTED || this->status >= STATUS_AUTH_COMPANY) return NETWORK_RECV_STATUS_MALFORMED_PACKET; Debug(net, 9, "Client::status = AUTH_COMPANY"); this->status = STATUS_AUTH_COMPANY; @@ -819,7 +821,7 @@ NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_NEED_COMPANY_PA NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_WELCOME(Packet &p) { - if (this->status < STATUS_AUTHENTICATED || this->status >= STATUS_AUTHORIZED) return NETWORK_RECV_STATUS_MALFORMED_PACKET; + if (this->status < STATUS_ENCRYPTED || this->status >= STATUS_AUTHORIZED) return NETWORK_RECV_STATUS_MALFORMED_PACKET; Debug(net, 9, "Client::status = AUTHORIZED"); this->status = STATUS_AUTHORIZED; diff --git a/src/network/network_client.h b/src/network/network_client.h index ac80393651..8d68b8d8ba 100644 --- a/src/network/network_client.h +++ b/src/network/network_client.h @@ -25,7 +25,7 @@ private: STATUS_INACTIVE, ///< The client is not connected nor active. STATUS_JOIN, ///< We are trying to join a server. STATUS_AUTH_GAME, ///< Last action was requesting game (server) password. - STATUS_AUTHENTICATED, ///< The game authentication has completed. + STATUS_ENCRYPTED, ///< The game authentication has completed and from here on the connection to the server is encrypted. STATUS_NEWGRFS_CHECK, ///< Last action was checking NewGRFs. STATUS_AUTH_COMPANY, ///< Last action was requesting company password. STATUS_AUTHORIZED, ///< The client is authorized at the server. @@ -47,7 +47,7 @@ protected: NetworkRecvStatus Receive_SERVER_ERROR(Packet &p) override; NetworkRecvStatus Receive_SERVER_CLIENT_INFO(Packet &p) override; NetworkRecvStatus Receive_SERVER_AUTH_REQUEST(Packet &p) override; - NetworkRecvStatus Receive_SERVER_AUTH_COMPLETED(Packet &p) override; + NetworkRecvStatus Receive_SERVER_ENABLE_ENCRYPTION(Packet &p) override; NetworkRecvStatus Receive_SERVER_NEED_COMPANY_PASSWORD(Packet &p) override; NetworkRecvStatus Receive_SERVER_WELCOME(Packet &p) override; NetworkRecvStatus Receive_SERVER_WAIT(Packet &p) override; diff --git a/src/network/network_server.cpp b/src/network/network_server.cpp index 6c7d4192ee..3a792504dd 100644 --- a/src/network/network_server.cpp +++ b/src/network/network_server.cpp @@ -457,15 +457,15 @@ NetworkRecvStatus ServerNetworkGameSocketHandler::SendAuthRequest() return NETWORK_RECV_STATUS_OKAY; } -/** Notify the client that the authentication has completed. */ -NetworkRecvStatus ServerNetworkGameSocketHandler::SendAuthCompleted() +/** Notify the client that the authentication has completed and tell that for the remainder of this socket encryption is enabled. */ +NetworkRecvStatus ServerNetworkGameSocketHandler::SendEnableEncryption() { - Debug(net, 9, "client[{}] SendAuthCompleted()", this->client_id); + Debug(net, 9, "client[{}] SendEnableEncryption()", this->client_id); /* Invalid packet when status is anything but STATUS_AUTH_GAME. */ if (this->status != STATUS_AUTH_GAME) return this->CloseConnection(NETWORK_RECV_STATUS_MALFORMED_PACKET); - auto p = std::make_unique(this, PACKET_SERVER_AUTH_COMPLETED); + auto p = std::make_unique(this, PACKET_SERVER_ENABLE_ENCRYPTION); this->SendPacket(std::move(p)); return NETWORK_RECV_STATUS_OKAY; } @@ -999,9 +999,11 @@ NetworkRecvStatus ServerNetworkGameSocketHandler::Receive_CLIENT_AUTH_RESPONSE(P return this->SendError(GetErrorForAuthenticationMethod(authentication_method)); } - NetworkRecvStatus status = this->SendAuthCompleted(); + NetworkRecvStatus status = this->SendEnableEncryption(); if (status != NETWORK_RECV_STATUS_OKAY) return status; + this->receive_encryption_handler = this->authentication_handler->CreateClientToServerEncryptionHandler(); + this->send_encryption_handler = this->authentication_handler->CreateServerToClientEncryptionHandler(); this->authentication_handler = nullptr; Debug(net, 9, "client[{}] status = IDENTIFY", this->client_id); diff --git a/src/network/network_server.h b/src/network/network_server.h index b0e21f26f6..47b8777578 100644 --- a/src/network/network_server.h +++ b/src/network/network_server.h @@ -47,7 +47,7 @@ protected: NetworkRecvStatus SendNewGRFCheck(); NetworkRecvStatus SendWelcome(); NetworkRecvStatus SendAuthRequest(); - NetworkRecvStatus SendAuthCompleted(); + NetworkRecvStatus SendEnableEncryption(); NetworkRecvStatus SendNeedCompanyPassword(); public: diff --git a/src/tests/test_network_crypto.cpp b/src/tests/test_network_crypto.cpp index 7258c09150..34cbefaf27 100644 --- a/src/tests/test_network_crypto.cpp +++ b/src/tests/test_network_crypto.cpp @@ -20,15 +20,21 @@ static_assert(NETWORK_SECRET_KEY_LENGTH >= X25519_KEY_SIZE * 2 + 1); class MockNetworkSocketHandler : public NetworkSocketHandler { +public: + MockNetworkSocketHandler(std::unique_ptr &&receive = {}, std::unique_ptr &&send = {}) + { + this->receive_encryption_handler = std::move(receive); + this->send_encryption_handler = std::move(send); + } }; static MockNetworkSocketHandler mock_socket_handler; -static Packet CreatePacketForReading(Packet &source) +static std::tuple CreatePacketForReading(Packet &source, MockNetworkSocketHandler *socket_handler) { source.PrepareToSend(); - Packet dest(&mock_socket_handler, COMPAT_MTU, source.Size()); + Packet dest(socket_handler, COMPAT_MTU, source.Size()); auto transfer_in = [](Packet &source, char *dest_data, size_t length) { auto transfer_out = [](char *dest_data, const char *source_data, size_t length) { @@ -39,9 +45,9 @@ static Packet CreatePacketForReading(Packet &source) }; dest.TransferIn(transfer_in, source); - dest.PrepareToRead(); + bool valid = dest.PrepareToRead(); dest.Recv_uint8(); // Ignore the type - return dest; + return { dest, valid }; } class TestPasswordRequestHandler : public NetworkAuthenticationPasswordRequestHandler { @@ -60,13 +66,16 @@ static void TestAuthentication(NetworkAuthenticationServerHandler &server, Netwo Packet request(&mock_socket_handler, PacketType{}); server.SendRequest(request); - request = CreatePacketForReading(request); + bool valid; + std::tie(request, valid) = CreatePacketForReading(request, &mock_socket_handler); + CHECK(valid); CHECK(client.ReceiveRequest(request) == expected_request_result); Packet response(&mock_socket_handler, PacketType{}); client.SendResponse(response); - response = CreatePacketForReading(response); + std::tie(response, valid) = CreatePacketForReading(response, &mock_socket_handler); + CHECK(valid); CHECK(server.ReceiveResponse(response) == expected_response_result); } @@ -200,3 +209,62 @@ TEST_CASE("Authentication_Combined") TestAuthentication(*server, *client, NetworkAuthenticationServerHandler::AUTHENTICATED, NetworkAuthenticationClientHandler::READY_FOR_RESPONSE); } } + + +static void CheckEncryption(MockNetworkSocketHandler *sending_socket_handler, MockNetworkSocketHandler *receiving_socket_handler) +{ + PacketType sent_packet_type{ 1 }; + uint64_t sent_value = 0x1234567890ABCDEF; + std::set encrypted_packet_types; + + for (int i = 0; i < 10; i++) { + Packet request(sending_socket_handler, sent_packet_type); + request.Send_uint64(sent_value); + + auto [response, valid] = CreatePacketForReading(request, receiving_socket_handler); + CHECK(valid); + CHECK(response.Recv_uint64() == sent_value); + + encrypted_packet_types.insert(request.GetPacketType()); + } + /* + * Check whether it looks like encryption has happened. This is done by checking the value + * of the packet type after encryption. If after a few iterations more than one encrypted + * value has been seen, then we know that some type of encryption/scrambling is happening. + * + * Technically this check could fail erroneously when 16 subsequent encryptions yield the + * same encrypted packet type. However, with encryption that byte should have random value + * value, so the chance of this happening are tiny given enough iterations. + * Roughly in the order of 2**((iterations - 1) * 8), which with 10 iterations is in the + * one-in-sextillion (10**21) order of magnitude. + */ + CHECK(encrypted_packet_types.size() != 1); + +} + +TEST_CASE("Encryption handling") +{ + X25519KeyExchangeOnlyServerHandler server(X25519SecretKey::CreateRandom()); + X25519KeyExchangeOnlyClientHandler client(X25519SecretKey::CreateRandom()); + + TestAuthentication(server, client, NetworkAuthenticationServerHandler::AUTHENTICATED, NetworkAuthenticationClientHandler::READY_FOR_RESPONSE); + + MockNetworkSocketHandler server_socket_handler(server.CreateClientToServerEncryptionHandler(), server.CreateServerToClientEncryptionHandler()); + MockNetworkSocketHandler client_socket_handler(client.CreateServerToClientEncryptionHandler(), client.CreateClientToServerEncryptionHandler()); + + SECTION("Encyption happening client -> server") { + CheckEncryption(&client_socket_handler, &server_socket_handler); + } + + SECTION("Encyption happening server -> client") { + CheckEncryption(&server_socket_handler, &client_socket_handler); + } + + SECTION("Unencrypted packet sent causes invalid read packet") { + Packet request(&mock_socket_handler, PacketType{}); + request.Send_uint64(0); + + auto [response, valid] = CreatePacketForReading(request, &client_socket_handler); + CHECK(!valid); + } +}