diff --git a/src/network/core/packet.cpp b/src/network/core/packet.cpp index 65f2f89a60..f350c431da 100644 --- a/src/network/core/packet.cpp +++ b/src/network/core/packet.cpp @@ -54,12 +54,16 @@ Packet::Packet(NetworkSocketHandler *cs, PacketType type, size_t limit) : pos(0) void Packet::ResetState(PacketType type) { this->buffer.clear(); + this->tx_packet_type = type; /* Allocate space for the the size so we can write that in just before sending the packet. */ size_t size = EncodedLengthOfPacketSize(); if (cs != nullptr && cs->send_encryption_handler != nullptr) { /* Allocate some space for the message authentication code of the encryption. */ size += cs->send_encryption_handler->MACSize(); + this->encyption_pending = true; + } else { + this->encyption_pending = false; } assert(this->CanWriteToPacket(size)); this->buffer.resize(size, 0); @@ -70,7 +74,7 @@ void Packet::ResetState(PacketType type) /** * Writes the packet size from the raw packet from packet->size */ -void Packet::PrepareToSend() +void Packet::PrepareForSendQueue() { /* Prevent this to be called twice and for packets that have been received. */ assert(this->buffer[0] == 0 && this->buffer[1] == 0); @@ -78,17 +82,19 @@ void Packet::PrepareToSend() this->buffer[0] = GB(this->Size(), 0, 8); this->buffer[1] = GB(this->Size(), 8, 8); - if (cs != nullptr && cs->send_encryption_handler != nullptr) { - size_t offset = EncodedLengthOfPacketSize(); - size_t mac_size = cs->send_encryption_handler->MACSize(); - size_t message_offset = offset + mac_size; - cs->send_encryption_handler->Encrypt(std::span(&this->buffer[offset], mac_size), std::span(&this->buffer[message_offset], this->buffer.size() - message_offset)); - } - - this->pos = 0; // We start reading from here + this->pos = 0; // We start reading from here this->buffer.shrink_to_fit(); } +void Packet::PreSendEncryption() +{ + this->encyption_pending = false; + size_t offset = EncodedLengthOfPacketSize(); + size_t mac_size = cs->send_encryption_handler->MACSize(); + size_t message_offset = offset + mac_size; + cs->send_encryption_handler->Encrypt(std::span(&this->buffer[offset], mac_size), std::span(&this->buffer[message_offset], this->buffer.size() - message_offset)); +} + /** * Is it safe to write to the packet, i.e. didn't we run over the buffer? * @param bytes_to_write The amount of bytes we want to try to write. diff --git a/src/network/core/packet.h b/src/network/core/packet.h index 1a8211aaf0..19eb0feaaa 100644 --- a/src/network/core/packet.h +++ b/src/network/core/packet.h @@ -51,6 +51,10 @@ struct Packet : public BufferSerialisationHelper, public BufferDeseriali private: /** The current read/write position in the packet */ PacketSize pos; + /** Whether encryption is required for this packet */ + bool encyption_pending = false; + /** Packet type, for transmitted packets */ + PacketType tx_packet_type; /** The buffer of this packet. */ std::vector buffer; /** The limit for the packet size. */ @@ -59,6 +63,8 @@ private: /** Socket we're associated with. */ NetworkSocketHandler *cs; + void PreSendEncryption(); + public: struct ReadTag{}; Packet(ReadTag tag, NetworkSocketHandler *cs, size_t limit, size_t initial_read_size = EncodedLengthOfPacketSize()); @@ -66,8 +72,21 @@ public: void ResetState(PacketType type); + void PrepareForSendQueue(); + + inline void CheckPendingPreSendEncryption() + { + if (this->encyption_pending) { + this->PreSendEncryption(); + } + } + /* Sending/writing of packets */ - void PrepareToSend(); + inline void PrepareToSend() + { + this->PrepareForSendQueue(); + this->CheckPendingPreSendEncryption(); + } std::vector &GetSerialisationBuffer() { return this->buffer; } size_t GetSerialisationLimit() const { return this->limit; } @@ -88,6 +107,7 @@ public: size_t Size() const; [[nodiscard]] bool PrepareToRead(); PacketType GetPacketType() const; + PacketType GetTransmitPacketType() const { return this->tx_packet_type; } bool CanReadFromPacket(size_t bytes_to_read, bool close_connection = false); diff --git a/src/network/core/tcp.cpp b/src/network/core/tcp.cpp index 7793e2bb1a..e748eb8165 100644 --- a/src/network/core/tcp.cpp +++ b/src/network/core/tcp.cpp @@ -69,7 +69,7 @@ void NetworkTCPSocketHandler::SendPacket(std::unique_ptr packet) { assert(packet != nullptr); - packet->PrepareToSend(); + packet->PrepareForSendQueue(); this->packet_queue.push_back(std::move(packet)); } @@ -84,11 +84,11 @@ void NetworkTCPSocketHandler::SendPrependPacket(std::unique_ptr packet, { assert(packet != nullptr); - packet->PrepareToSend(); + packet->PrepareForSendQueue(); if (queue_after_packet_type >= 0) { for (auto iter = this->packet_queue.begin(); iter != this->packet_queue.end(); ++iter) { - if ((*iter)->GetPacketType() == queue_after_packet_type) { + if ((*iter)->GetTransmitPacketType() == queue_after_packet_type) { ++iter; this->packet_queue.insert(iter, std::move(packet)); return; @@ -131,6 +131,7 @@ SendPacketsState NetworkTCPSocketHandler::SendPackets(bool closing_down) while (!this->packet_queue.empty()) { Packet &p = *this->packet_queue.front(); + p.CheckPendingPreSendEncryption(); ssize_t res = p.TransferOut(send, this->sock, 0); if (res == -1) { NetworkError err = NetworkError::GetLast(); diff --git a/src/network/core/tcp_game.cpp b/src/network/core/tcp_game.cpp index 1c47f12dcf..9573e78ec2 100644 --- a/src/network/core/tcp_game.cpp +++ b/src/network/core/tcp_game.cpp @@ -295,7 +295,7 @@ std::string NetworkGameSocketHandler::GetDebugInfo() const { return ""; } void NetworkGameSocketHandler::LogSentPacket(const Packet &pkt) { - PacketGameType type = (PacketGameType)pkt.GetPacketType(); + PacketGameType type = (PacketGameType)pkt.GetTransmitPacketType(); DEBUG(net, 5, "[tcp/game] sent packet type %d (%s) to client %d, %s", type, GetPacketGameTypeName(type), this->client_id, this->GetDebugInfo().c_str()); } diff --git a/src/network/network_server.cpp b/src/network/network_server.cpp index c4a84de80c..1dca8f30d5 100644 --- a/src/network/network_server.cpp +++ b/src/network/network_server.cpp @@ -136,7 +136,7 @@ struct PacketWriter : SaveFilter { } bool last_packet = false; for (auto &p : this->packets) { - if (p->GetPacketType() == PACKET_SERVER_MAP_DONE) last_packet = true; + if (p->GetTransmitPacketType() == PACKET_SERVER_MAP_DONE) last_packet = true; this->cs->SendPacket(std::move(p)); }