diff --git a/src/network/core/packet.cpp b/src/network/core/packet.cpp index 6be5fef11f..76039fa79e 100644 --- a/src/network/core/packet.cpp +++ b/src/network/core/packet.cpp @@ -19,45 +19,41 @@ #include "../../safeguards.h" /** - * Create a packet that is used to read from a network socket - * @param cs the socket handler associated with the socket we are reading from + * Create a packet that is used to read from a network socket. + * @param cs The socket handler associated with the socket we are reading from. + * @param initial_read_size The initial amount of data to transfer from the socket into the + * packet. This defaults to just the required bytes to determine the + * packet's size. That default is the wanted for streams such as TCP + * as you do not want to read data of the next packet yet. For UDP + * you need to read the whole packet at once otherwise you might + * loose some the data of the packet, so there you pass the maximum + * size for the packet you expect from the network. */ -Packet::Packet(NetworkSocketHandler *cs) +Packet::Packet(NetworkSocketHandler *cs, size_t initial_read_size) : pos(0) { assert(cs != nullptr); - this->cs = cs; - this->pos = 0; // We start reading from here - this->size = 0; - this->buffer = MallocT(SHRT_MAX); + this->cs = cs; + this->buffer.resize(initial_read_size); } /** * Creates a packet to send * @param type of the packet to send */ -Packet::Packet(PacketType type) +Packet::Packet(PacketType type) : pos(0), cs(nullptr) { - this->buffer = MallocT(SHRT_MAX); this->ResetState(type); } -/** - * Free the buffer of this packet. - */ -Packet::~Packet() -{ - free(this->buffer); -} - void Packet::ResetState(PacketType type) { - this->cs = nullptr; + this->cs = nullptr; + this->buffer.clear(); - /* Skip the size so we can write that in before sending the packet */ - this->pos = 0; - this->size = sizeof(PacketSize); - this->buffer[this->size++] = type; + /* Allocate space for the the size so we can write that in just before sending the packet. */ + this->Send_uint16(0); + this->Send_uint8(type); } /** @@ -67,10 +63,21 @@ void Packet::PrepareToSend() { assert(this->cs == nullptr); - this->buffer[0] = GB(this->size, 0, 8); - this->buffer[1] = GB(this->size, 8, 8); + this->buffer[0] = GB(this->Size(), 0, 8); + this->buffer[1] = GB(this->Size(), 8, 8); this->pos = 0; // We start reading from here + this->buffer.shrink_to_fit(); +} + +/** + * Is it safe to write to the packet, i.e. didn't we run over the buffer? + * @param bytes_to_write The amount of bytes we want to try to write. + * @return True iff the given amount of bytes can be written to the packet. + */ +bool Packet::CanWriteToPacket(size_t bytes_to_write) +{ + return this->Size() + bytes_to_write < SHRT_MAX; } /* @@ -100,8 +107,8 @@ void Packet::Send_bool(bool data) */ void Packet::Send_uint8(uint8 data) { - assert(this->size < SHRT_MAX - sizeof(data)); - this->buffer[this->size++] = data; + assert(this->CanWriteToPacket(sizeof(data))); + this->buffer.emplace_back(data); } /** @@ -110,9 +117,11 @@ void Packet::Send_uint8(uint8 data) */ void Packet::Send_uint16(uint16 data) { - assert(this->size < SHRT_MAX - sizeof(data)); - this->buffer[this->size++] = GB(data, 0, 8); - this->buffer[this->size++] = GB(data, 8, 8); + assert(this->CanWriteToPacket(sizeof(data))); + this->buffer.insert(this->buffer.end(), { + (uint8)GB(data, 0, 8), + (uint8)GB(data, 8, 8), + }); } /** @@ -121,11 +130,13 @@ void Packet::Send_uint16(uint16 data) */ void Packet::Send_uint32(uint32 data) { - assert(this->size < SHRT_MAX - sizeof(data)); - this->buffer[this->size++] = GB(data, 0, 8); - this->buffer[this->size++] = GB(data, 8, 8); - this->buffer[this->size++] = GB(data, 16, 8); - this->buffer[this->size++] = GB(data, 24, 8); + assert(this->CanWriteToPacket(sizeof(data))); + this->buffer.insert(this->buffer.end(), { + (uint8)GB(data, 0, 8), + (uint8)GB(data, 8, 8), + (uint8)GB(data, 16, 8), + (uint8)GB(data, 24, 8), + }); } /** @@ -134,15 +145,17 @@ void Packet::Send_uint32(uint32 data) */ void Packet::Send_uint64(uint64 data) { - assert(this->size < SHRT_MAX - sizeof(data)); - this->buffer[this->size++] = GB(data, 0, 8); - this->buffer[this->size++] = GB(data, 8, 8); - this->buffer[this->size++] = GB(data, 16, 8); - this->buffer[this->size++] = GB(data, 24, 8); - this->buffer[this->size++] = GB(data, 32, 8); - this->buffer[this->size++] = GB(data, 40, 8); - this->buffer[this->size++] = GB(data, 48, 8); - this->buffer[this->size++] = GB(data, 56, 8); + assert(this->CanWriteToPacket(sizeof(data))); + this->buffer.insert(this->buffer.end(), { + (uint8)GB(data, 0, 8), + (uint8)GB(data, 8, 8), + (uint8)GB(data, 16, 8), + (uint8)GB(data, 24, 8), + (uint8)GB(data, 32, 8), + (uint8)GB(data, 40, 8), + (uint8)GB(data, 48, 8), + (uint8)GB(data, 56, 8), + }); } /** @@ -153,9 +166,24 @@ void Packet::Send_uint64(uint64 data) void Packet::Send_string(const char *data) { assert(data != nullptr); - /* The <= *is* valid due to the fact that we are comparing sizes and not the index. */ - assert(this->size + strlen(data) + 1 <= SHRT_MAX); - while ((this->buffer[this->size++] = *data++) != '\0') {} + /* Length of the string + 1 for the '\0' termination. */ + assert(this->CanWriteToPacket(strlen(data) + 1)); + while (this->buffer.emplace_back(*data++) != '\0') {} +} + +/** + * Send as many of the bytes as possible in the packet. This can mean + * that it is possible that not all bytes are sent. To cope with this + * the function returns the amount of bytes that were actually sent. + * @param begin The begin of the buffer to send. + * @param end The end of the buffer to send. + * @return The number of bytes that were added to this packet. + */ +size_t Packet::Send_bytes(const byte *begin, const byte *end) +{ + size_t amount = std::min(end - begin, SHRT_MAX - this->Size()); + this->buffer.insert(this->buffer.end(), begin, begin + amount); + return amount; } /** @@ -165,9 +193,8 @@ void Packet::Send_string(const char *data) void Packet::Send_binary(const char *data, const size_t size) { assert(data != nullptr); - assert(this->size + size <= SHRT_MAX); - memcpy(&this->buffer[this->size], data, size); - this->size += (PacketSize) size; + assert(this->CanWriteToPacket(size)); + this->buffer.insert(this->buffer.end(), data, data + size); } @@ -179,19 +206,21 @@ void Packet::Send_binary(const char *data, const size_t size) /** - * Is it safe to read from the packet, i.e. didn't we run over the buffer ? - * @param bytes_to_read The amount of bytes we want to try to read. - * @param non_fatal True if a false return value is considered non-fatal, and will not raise an error. + * Is it safe to read from the packet, i.e. didn't we run over the buffer? + * In case \c close_connection is true, the connection will be closed when one would + * overrun the buffer. When it is false, the connection remains untouched. + * @param bytes_to_read The amount of bytes we want to try to read. + * @param close_connection Whether to close the connection if one cannot read that amount. * @return True if that is safe, otherwise false. */ -bool Packet::CanReadFromPacket(uint bytes_to_read, bool non_fatal) +bool Packet::CanReadFromPacket(size_t bytes_to_read, bool close_connection) { /* Don't allow reading from a quit client/client who send bad data */ if (this->cs->HasClientQuit()) return false; /* Check if variable is within packet-size */ - if (this->pos + bytes_to_read > this->size) { - if (!non_fatal) this->cs->NetworkSocketHandler::CloseConnection(); + if (this->pos + bytes_to_read > this->Size()) { + if (close_connection) this->cs->NetworkSocketHandler::CloseConnection(); return false; } @@ -199,13 +228,50 @@ bool Packet::CanReadFromPacket(uint bytes_to_read, bool non_fatal) } /** - * Reads the packet size from the raw packet and stores it in the packet->size + * Check whether the packet, given the position of the "write" pointer, has read + * enough of the packet to contain its size. + * @return True iff there is enough data in the packet to contain the packet's size. */ -void Packet::ReadRawPacketSize() +bool Packet::HasPacketSizeData() const +{ + return this->pos >= sizeof(PacketSize); +} + +/** + * Get the number of bytes in the packet. + * When sending a packet this is the size of the data up to that moment. + * When receiving a packet (before PrepareToRead) this is the allocated size for the data to be read. + * When reading a packet (after PrepareToRead) this is the full size of the packet. + * @return The packet's size. + */ +size_t Packet::Size() const +{ + return this->buffer.size(); +} + +size_t Packet::ReadRawPacketSize() const +{ + return (size_t)this->buffer[0] + ((size_t)this->buffer[1] << 8); +} + +/** + * Reads the packet size from the raw packet and stores it in the packet->size + * @return True iff the packet size seems plausible. + */ +bool Packet::ParsePacketSize() { assert(this->cs != nullptr); - this->size = (PacketSize)this->buffer[0]; - this->size += (PacketSize)this->buffer[1] << 8; + size_t size = (size_t)this->buffer[0]; + size += (size_t)this->buffer[1] << 8; + + /* If the size of the packet is less than the bytes required for the size and type of + * the packet, or more than the allowed limit, then something is wrong with the packet. + * In those cases the packet can generally be regarded as containing garbage data. */ + if (size < sizeof(PacketSize) + sizeof(PacketType)) return false; + + this->buffer.resize(size); + this->pos = sizeof(PacketSize); + return true; } /** @@ -213,12 +279,20 @@ void Packet::ReadRawPacketSize() */ void Packet::PrepareToRead() { - this->ReadRawPacketSize(); - /* Put the position on the right place */ this->pos = sizeof(PacketSize); } +/** + * Get the \c PacketType from this packet. + * @return The packet type. + */ +PacketType Packet::GetPacketType() const +{ + assert(this->Size() >= sizeof(PacketSize) + sizeof(PacketType)); + return static_cast(buffer[sizeof(PacketSize)]); +} + /** * Read a boolean from the packet. * @return The read data. @@ -236,7 +310,7 @@ uint8 Packet::Recv_uint8() { uint8 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = this->buffer[this->pos++]; return n; @@ -250,7 +324,7 @@ uint16 Packet::Recv_uint16() { uint16 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = (uint16)this->buffer[this->pos++]; n += (uint16)this->buffer[this->pos++] << 8; @@ -265,7 +339,7 @@ uint32 Packet::Recv_uint32() { uint32 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = (uint32)this->buffer[this->pos++]; n += (uint32)this->buffer[this->pos++] << 8; @@ -282,7 +356,7 @@ uint64 Packet::Recv_uint64() { uint64 n; - if (!this->CanReadFromPacket(sizeof(n))) return 0; + if (!this->CanReadFromPacket(sizeof(n), true)) return 0; n = (uint64)this->buffer[this->pos++]; n += (uint64)this->buffer[this->pos++] << 8; @@ -311,13 +385,13 @@ void Packet::Recv_string(char *buffer, size_t size, StringValidationSettings set if (cs->HasClientQuit()) return; pos = this->pos; - while (--size > 0 && pos < this->size && (*buffer++ = this->buffer[pos++]) != '\0') {} + while (--size > 0 && pos < this->Size() && (*buffer++ = this->buffer[pos++]) != '\0') {} - if (size == 0 || pos == this->size) { + if (size == 0 || pos == this->Size()) { *buffer = '\0'; /* If size was sooner to zero then the string in the stream * skip till the \0, so than packet can be read out correctly for the rest */ - while (pos < this->size && this->buffer[pos] != '\0') pos++; + while (pos < this->Size() && this->buffer[pos] != '\0') pos++; pos++; } this->pos = pos; @@ -325,6 +399,15 @@ void Packet::Recv_string(char *buffer, size_t size, StringValidationSettings set str_validate(bufp, last, settings); } +/** + * Get the amount of bytes that are still available for the Transfer functions. + * @return The number of bytes that still have to be transfered. + */ +size_t Packet::RemainingBytesToTransfer() const +{ + return this->Size() - this->pos; +} + /** * Reads a string till it finds a '\0' in the stream. * @param buffer The buffer to put the data into. @@ -335,8 +418,8 @@ void Packet::Recv_string(std::string &buffer, StringValidationSettings settings) /* Don't allow reading from a closed socket */ if (cs->HasClientQuit()) return; - size_t length = ttd_strnlen((const char *)(this->buffer + this->pos), this->size - this->pos - 1); - buffer.assign((const char *)(this->buffer + this->pos), length); + size_t length = ttd_strnlen((const char *)(this->buffer.data() + this->pos), this->Size() - this->pos - 1); + buffer.assign((const char *)(this->buffer.data() + this->pos), length); this->pos += length + 1; str_validate_inplace(buffer, settings); } @@ -348,7 +431,7 @@ void Packet::Recv_string(std::string &buffer, StringValidationSettings settings) */ void Packet::Recv_binary(char *buffer, size_t size) { - if (!this->CanReadFromPacket(size)) return; + if (!this->CanReadFromPacket(size, true)) return; memcpy(buffer, &this->buffer[this->pos], size); this->pos += (PacketSize) size; @@ -361,7 +444,7 @@ void Packet::Recv_binary(char *buffer, size_t size) */ void Packet::Recv_binary(std::string &buffer, size_t size) { - if (!this->CanReadFromPacket(size)) return; + if (!this->CanReadFromPacket(size, true)) return; buffer.assign((const char *) &this->buffer[this->pos], size); this->pos += (PacketSize) size; diff --git a/src/network/core/packet.h b/src/network/core/packet.h index 4999cf274a..58637fc043 100644 --- a/src/network/core/packet.h +++ b/src/network/core/packet.h @@ -12,10 +12,12 @@ #ifndef NETWORK_CORE_PACKET_H #define NETWORK_CORE_PACKET_H +#include "os_abstraction.h" #include "config.h" #include "core.h" #include "../../string_type.h" #include +#include typedef uint16 PacketSize; ///< Size of the whole packet. typedef uint8 PacketType; ///< Identifier for the packet @@ -39,44 +41,43 @@ typedef uint8 PacketType; ///< Identifier for the packet * (year % 4 == 0) and ((year % 100 != 0) or (year % 400 == 0)) */ struct Packet { - /** - * The size of the whole packet for received packets. For packets - * that will be sent, the value is filled in just before the - * actual transmission. - */ - PacketSize size; +private: /** The current read/write position in the packet */ PacketSize pos; - /** The buffer of this packet, of basically variable length up to SHRT_MAX. */ - byte *buffer; + /** The buffer of this packet. */ + std::vector buffer; -private: /** Socket we're associated with. */ NetworkSocketHandler *cs; public: - Packet(NetworkSocketHandler *cs); + Packet(NetworkSocketHandler *cs, size_t initial_read_size = sizeof(PacketSize)); Packet(PacketType type); - ~Packet(); void ResetState(PacketType type); /* Sending/writing of packets */ void PrepareToSend(); - void Send_bool (bool data); - void Send_uint8 (uint8 data); - void Send_uint16(uint16 data); - void Send_uint32(uint32 data); - void Send_uint64(uint64 data); - void Send_string(const char *data); - void Send_binary(const char *data, const size_t size); + bool CanWriteToPacket(size_t bytes_to_write); + void Send_bool (bool data); + void Send_uint8 (uint8 data); + void Send_uint16(uint16 data); + void Send_uint32(uint32 data); + void Send_uint64(uint64 data); + void Send_string(const char *data); + size_t Send_bytes (const byte *begin, const byte *end); + void Send_binary(const char *data, const size_t size); /* Reading/receiving of packets */ - void ReadRawPacketSize(); + size_t ReadRawPacketSize() const; + bool HasPacketSizeData() const; + bool ParsePacketSize(); + size_t Size() const; void PrepareToRead(); + PacketType GetPacketType() const; - bool CanReadFromPacket (uint bytes_to_read, bool non_fatal = false); + bool CanReadFromPacket(size_t bytes_to_read, bool close_connection = false); bool Recv_bool (); uint8 Recv_uint8 (); uint16 Recv_uint16(); @@ -86,6 +87,108 @@ public: void Recv_string(std::string &buffer, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK); void Recv_binary(char *buffer, size_t size); void Recv_binary(std::string &buffer, size_t size); + + size_t RemainingBytesToTransfer() const; + + const byte *GetBufferData() const { return this->buffer.data(); } + PacketSize GetRawPos() const { return this->pos; } + void ReserveBuffer(size_t size) { this->buffer.reserve(size); } + + /** + * Transfer data from the packet to the given function. It starts reading at the + * position the last transfer stopped. + * See Packet::TransferIn for more information about transferring data to functions. + * @param transfer_function The function to pass the buffer as second parameter and the + * amount to write as third parameter. It returns the amount that + * was written or -1 upon errors. + * @param limit The maximum amount of bytes to transfer. + * @param destination The first parameter of the transfer function. + * @param args The fourth and further parameters to the transfer function, if any. + * @return The return value of the transfer_function. + */ + template < + typename A = size_t, ///< The type for the amount to be passed, so it can be cast to the right type. + typename F, ///< The type of the function. + typename D, ///< The type of the destination. + typename ... Args> ///< The types of the remaining arguments to the function. + ssize_t TransferOutWithLimit(F transfer_function, size_t limit, D destination, Args&& ... args) + { + size_t amount = std::min(this->RemainingBytesToTransfer(), limit); + if (amount == 0) return 0; + + assert(this->pos < this->buffer.size()); + assert(this->pos + amount <= this->buffer.size()); + /* Making buffer a char means casting a lot in the Recv/Send functions. */ + const char *output_buffer = reinterpret_cast(this->buffer.data() + this->pos); + ssize_t bytes = transfer_function(destination, output_buffer, static_cast(amount), std::forward(args)...); + if (bytes > 0) this->pos += bytes; + return bytes; + } + + /** + * Transfer data from the packet to the given function. It starts reading at the + * position the last transfer stopped. + * See Packet::TransferIn for more information about transferring data to functions. + * @param transfer_function The function to pass the buffer as second parameter and the + * amount to write as third parameter. It returns the amount that + * was written or -1 upon errors. + * @param destination The first parameter of the transfer function. + * @param args The fourth and further parameters to the transfer function, if any. + * @tparam A The type for the amount to be passed, so it can be cast to the right type. + * @tparam F The type of the transfer_function. + * @tparam D The type of the destination. + * @tparam Args The types of the remaining arguments to the function. + * @return The return value of the transfer_function. + */ + template + ssize_t TransferOut(F transfer_function, D destination, Args&& ... args) + { + return TransferOutWithLimit(transfer_function, std::numeric_limits::max(), destination, std::forward(args)...); + } + + /** + * Transfer data from the given function into the packet. It starts writing at the + * position the last transfer stopped. + * + * Examples of functions that can be used to transfer data into a packet are TCP's + * recv and UDP's recvfrom functions. They will directly write their data into the + * packet without an intermediate buffer. + * Examples of functions that can be used to transfer data from a packet are TCP's + * send and UDP's sendto functions. They will directly read the data from the packet's + * buffer without an intermediate buffer. + * These are functions are special in a sense as even though the packet can send or + * receive an amount of data, those functions can say they only processed a smaller + * amount, so special handling is required to keep the position pointers correct. + * Most of these transfer functions are in the form function(source, buffer, amount, ...), + * so the template of this function will assume that as the base parameter order. + * + * This will attempt to write all the remaining bytes into the packet. It updates the + * position based on how many bytes were actually written by the called transfer_function. + * @param transfer_function The function to pass the buffer as second parameter and the + * amount to read as third parameter. It returns the amount that + * was read or -1 upon errors. + * @param source The first parameter of the transfer function. + * @param args The fourth and further parameters to the transfer function, if any. + * @tparam A The type for the amount to be passed, so it can be cast to the right type. + * @tparam F The type of the transfer_function. + * @tparam S The type of the source. + * @tparam Args The types of the remaining arguments to the function. + * @return The return value of the transfer_function. + */ + template + ssize_t TransferIn(F transfer_function, S source, Args&& ... args) + { + size_t amount = this->RemainingBytesToTransfer(); + if (amount == 0) return 0; + + assert(this->pos < this->buffer.size()); + assert(this->pos + amount <= this->buffer.size()); + /* Making buffer a char means casting a lot in the Recv/Send functions. */ + char *input_buffer = reinterpret_cast(this->buffer.data() + this->pos); + ssize_t bytes = transfer_function(source, input_buffer, static_cast(amount), std::forward(args)...); + if (bytes > 0) this->pos += bytes; + return bytes; + } }; #endif /* NETWORK_CORE_PACKET_H */ diff --git a/src/network/core/tcp.cpp b/src/network/core/tcp.cpp index 4198ef7217..379900c287 100644 --- a/src/network/core/tcp.cpp +++ b/src/network/core/tcp.cpp @@ -58,11 +58,6 @@ void NetworkTCPSocketHandler::SendPacket(std::unique_ptr packet) packet->PrepareToSend(); - /* Reallocate the packet as in 99+% of the times we send at most 25 bytes and - * keeping the other 1400+ bytes wastes memory, especially when someone tries - * to do a denial of service attack! */ - if (packet->size < ((SHRT_MAX * 2) / 3)) packet->buffer = ReallocT(packet->buffer, packet->size); - this->packet_queue.push_back(std::move(packet)); } @@ -86,7 +81,7 @@ SendPacketsState NetworkTCPSocketHandler::SendPackets(bool closing_down) while (!this->packet_queue.empty()) { Packet *p = this->packet_queue.front().get(); - res = send(this->sock, (const char*)p->buffer + p->pos, p->size - p->pos, 0); + res = p->TransferOut(send, this->sock, 0); if (res == -1) { int err = GET_LAST_ERROR(); if (err != EWOULDBLOCK) { @@ -105,10 +100,8 @@ SendPacketsState NetworkTCPSocketHandler::SendPackets(bool closing_down) return SPS_CLOSED; } - p->pos += res; - /* Is this packet sent? */ - if (p->pos == p->size) { + if (p->RemainingBytesToTransfer() == 0) { /* Go to the next packet */ this->packet_queue.pop_front(); } else { @@ -136,10 +129,9 @@ std::unique_ptr NetworkTCPSocketHandler::ReceivePacket() Packet *p = this->packet_recv.get(); /* Read packet size */ - if (p->pos < sizeof(PacketSize)) { - while (p->pos < sizeof(PacketSize)) { - /* Read the size of the packet */ - res = recv(this->sock, (char*)p->buffer + p->pos, sizeof(PacketSize) - p->pos, 0); + if (!p->HasPacketSizeData()) { + while (p->RemainingBytesToTransfer() != 0) { + res = p->TransferIn(recv, this->sock, 0); if (res == -1) { int err = GET_LAST_ERROR(); if (err != EWOULDBLOCK) { @@ -156,16 +148,18 @@ std::unique_ptr NetworkTCPSocketHandler::ReceivePacket() this->CloseConnection(); return nullptr; } - p->pos += res; } - /* Read the packet size from the received packet */ - p->ReadRawPacketSize(); + /* Parse the size in the received packet and if not valid, close the connection. */ + if (!p->ParsePacketSize()) { + this->CloseConnection(); + return nullptr; + } } /* Read rest of packet */ - while (p->pos < p->size) { - res = recv(this->sock, (char*)p->buffer + p->pos, p->size - p->pos, 0); + while (p->RemainingBytesToTransfer() != 0) { + res = p->TransferIn(recv, this->sock, 0); if (res == -1) { int err = GET_LAST_ERROR(); if (err != EWOULDBLOCK) { @@ -182,8 +176,6 @@ std::unique_ptr NetworkTCPSocketHandler::ReceivePacket() this->CloseConnection(); return nullptr; } - - p->pos += res; } diff --git a/src/network/core/tcp_listen.h b/src/network/core/tcp_listen.h index 1f073aa735..53a3d57cc9 100644 --- a/src/network/core/tcp_listen.h +++ b/src/network/core/tcp_listen.h @@ -63,7 +63,7 @@ public: DEBUG(net, 1, "[%s] Banned ip tried to join (%s), refused", Tsocket::GetName(), entry.c_str()); - if (send(s, (const char*)p.buffer, p.size, 0) < 0) { + if (p.TransferOut(send, s, 0) < 0) { DEBUG(net, 0, "send failed with error %d", GET_LAST_ERROR()); } closesocket(s); @@ -80,7 +80,7 @@ public: Packet p(Tfull_packet); p.PrepareToSend(); - if (send(s, (const char*)p.buffer, p.size, 0) < 0) { + if (p.TransferOut(send, s, 0) < 0) { DEBUG(net, 0, "send failed with error %d", GET_LAST_ERROR()); } closesocket(s); diff --git a/src/network/core/udp.cpp b/src/network/core/udp.cpp index 77010dbb6c..875de1ee3c 100644 --- a/src/network/core/udp.cpp +++ b/src/network/core/udp.cpp @@ -88,24 +88,25 @@ void NetworkUDPSocketHandler::SendPacket(Packet *p, NetworkAddress *recv, bool a const uint MTU = short_mtu ? SEND_MTU_SHORT : SEND_MTU; - if (p->size > MTU) { + if (p->Size() > MTU) { p->PrepareToSend(); uint64 token = this->fragment_token++; const uint PAYLOAD_MTU = MTU - (1 + 2 + 8 + 1 + 1 + 2); - const uint8 frag_count = (p->size + PAYLOAD_MTU - 1) / PAYLOAD_MTU; + const size_t packet_size = p->Size(); + const uint8 frag_count = (packet_size + PAYLOAD_MTU - 1) / PAYLOAD_MTU; Packet frag(PACKET_UDP_EX_MULTI); uint8 current_frag = 0; uint16 offset = 0; - while (offset < p->size) { - uint16 payload_size = std::min(PAYLOAD_MTU, p->size - offset); + while (offset < packet_size) { + uint16 payload_size = std::min(PAYLOAD_MTU, packet_size - offset); frag.Send_uint64(token); frag.Send_uint8 (current_frag); frag.Send_uint8 (frag_count); frag.Send_uint16 (payload_size); - frag.Send_binary((const char *) p->buffer + offset, payload_size); + frag.Send_binary((const char *) p->GetBufferData() + offset, payload_size); current_frag++; offset += payload_size; this->SendPacket(&frag, recv, all, broadcast, short_mtu); @@ -134,8 +135,8 @@ void NetworkUDPSocketHandler::SendPacket(Packet *p, NetworkAddress *recv, bool a } /* Send the buffer */ - int res = sendto(s.second, (const char*)p->buffer, p->size, 0, (const struct sockaddr *)send.GetAddress(), send.GetAddressLength()); - DEBUG(net, 7, "[udp] sendto(%s)", NetworkAddressDumper().GetAddressAsString(&send)); + ssize_t res = p->TransferOut(sendto, s.second, 0, (const struct sockaddr *)send.GetAddress(), send.GetAddressLength()); + DEBUG(net, 7, "[udp] sendto(%s)", NetworkAddressDumper().GetAddressAsString(&send)); /* Check for any errors, but ignore it otherwise */ if (res == -1) DEBUG(net, 1, "[udp] sendto(%s) failed with: %i", NetworkAddressDumper().GetAddressAsString(&send), GET_LAST_ERROR()); @@ -154,12 +155,12 @@ void NetworkUDPSocketHandler::ReceivePackets() struct sockaddr_storage client_addr; memset(&client_addr, 0, sizeof(client_addr)); - Packet p(this); + Packet p(this, SEND_MTU); socklen_t client_len = sizeof(client_addr); /* Try to receive anything */ SetNonBlocking(s.second); // Some OSes seem to lose the non-blocking status of the socket - int nbytes = recvfrom(s.second, (char*)p.buffer, SEND_MTU, 0, (struct sockaddr *)&client_addr, &client_len); + ssize_t nbytes = p.TransferIn(recvfrom, s.second, 0, (struct sockaddr *)&client_addr, &client_len); /* Did we get the bytes for the base header of the packet? */ if (nbytes <= 0) break; // No data, i.e. no packet @@ -169,14 +170,14 @@ void NetworkUDPSocketHandler::ReceivePackets() #endif NetworkAddress address(client_addr, client_len); - p.PrepareToRead(); /* If the size does not match the packet must be corrupted. * Otherwise it will be marked as corrupted later on. */ - if (nbytes != p.size) { - DEBUG(net, 1, "received a packet with mismatching size from %s, (%u, %u)", NetworkAddressDumper().GetAddressAsString(&address), nbytes, p.size); + if (!p.ParsePacketSize() || (size_t)nbytes != p.Size()) { + DEBUG(net, 1, "received a packet with mismatching size from %s, (%u, %u)", NetworkAddressDumper().GetAddressAsString(&address), (uint)nbytes, (uint)p.Size()); continue; } + p.PrepareToRead(); /* Handle the packet */ this->HandleUDPPacket(&p, &address); @@ -493,7 +494,7 @@ void NetworkUDPSocketHandler::Receive_EX_MULTI(Packet *p, NetworkAddress *client time_t cur_time = time(nullptr); auto add_to_fragment = [&](FragmentSet &fs) { - fs.fragments[index].assign((const char *) p->buffer + p->pos, payload_size); + fs.fragments[index].assign((const char *) p->GetBufferData() + p->GetRawPos(), payload_size); uint total_payload = 0; for (auto &frag : fs.fragments) { @@ -505,17 +506,19 @@ void NetworkUDPSocketHandler::Receive_EX_MULTI(Packet *p, NetworkAddress *client DEBUG(net, 6, "[udp] merged multi-part packet from %s: " OTTD_PRINTFHEX64 ", %u bytes", NetworkAddressDumper().GetAddressAsString(client_addr), token, total_payload); - Packet merged(this); + Packet merged(this, 0); + merged.ReserveBuffer(total_payload); for (auto &frag : fs.fragments) { merged.Send_binary(frag.data(), frag.size()); } + merged.ParsePacketSize(); merged.PrepareToRead(); /* If the size does not match the packet must be corrupted. * Otherwise it will be marked as corrupted later on. */ - if (total_payload != merged.size) { + if (total_payload != merged.ReadRawPacketSize()) { DEBUG(net, 1, "received an extended packet with mismatching size from %s, (%u, %u)", - NetworkAddressDumper().GetAddressAsString(client_addr), total_payload, merged.size); + NetworkAddressDumper().GetAddressAsString(client_addr), (uint)total_payload, (uint)merged.ReadRawPacketSize()); } else { this->HandleUDPPacket(&merged, client_addr); } diff --git a/src/network/network_admin.cpp b/src/network/network_admin.cpp index a788a22836..1646adb2ea 100644 --- a/src/network/network_admin.cpp +++ b/src/network/network_admin.cpp @@ -613,7 +613,7 @@ NetworkRecvStatus ServerNetworkAdminSocketHandler::SendCmdNames() /* Should SEND_MTU be exceeded, start a new packet * (magic 5: 1 bool "more data" and one uint16 "command id", one * byte for string '\0' termination and 1 bool "no more data" */ - if (p->size + strlen(cmdname) + 5 >= SEND_MTU) { + if (p->CanWriteToPacket(strlen(cmdname) + 5)) { p->Send_bool(false); this->SendPacket(p); diff --git a/src/network/network_client.cpp b/src/network/network_client.cpp index 1a1cfbb43a..1f6cee6eda 100644 --- a/src/network/network_client.cpp +++ b/src/network/network_client.cpp @@ -40,7 +40,6 @@ /* This file handles all the client-commands */ - /** Read some packets, and when do use that data as initial load filter. */ struct PacketReader : LoadFilter { static const size_t CHUNK = 32 * 1024; ///< 32 KiB chunks of memory. @@ -64,35 +63,38 @@ struct PacketReader : LoadFilter { } } + /** + * Simple wrapper around fwrite to be able to pass it to Packet's TransferOut. + * @param destination The reader to add the data to. + * @param source The buffer to read data from. + * @param amount The number of bytes to copy. + * @return The number of bytes that were copied. + */ + static inline ssize_t TransferOutMemCopy(PacketReader *destination, const char *source, size_t amount) + { + memcpy(destination->buf, source, amount); + destination->buf += amount; + destination->written_bytes += amount; + return amount; + } + /** * Add a packet to this buffer. * @param p The packet to add. */ - void AddPacket(const Packet *p) + void AddPacket(Packet *p) { assert(this->read_bytes == 0); - - size_t in_packet = p->size - p->pos; - size_t to_write = std::min(this->bufe - this->buf, in_packet); - const byte *pbuf = p->buffer + p->pos; - - this->written_bytes += in_packet; - if (to_write != 0) { - memcpy(this->buf, pbuf, to_write); - this->buf += to_write; - } + p->TransferOutWithLimit(TransferOutMemCopy, this->bufe - this->buf, this); /* Did everything fit in the current chunk, then we're done. */ - if (to_write == in_packet) return; + if (p->RemainingBytesToTransfer() == 0) return; /* Allocate a new chunk and add the remaining data. */ - pbuf += to_write; - to_write = in_packet - to_write; this->blocks.push_back(this->buf = CallocT(CHUNK)); this->bufe = this->buf + CHUNK; - memcpy(this->buf, pbuf, to_write); - this->buf += to_write; + p->TransferOutWithLimit(TransferOutMemCopy, this->bufe - this->buf, this); } size_t Read(byte *rbuf, size_t size) override @@ -559,7 +561,7 @@ NetworkRecvStatus ClientNetworkGameSocketHandler::SendDesyncLog(const std::strin { for (size_t offset = 0; offset < log.size();) { Packet *p = new Packet(PACKET_CLIENT_DESYNC_LOG); - size_t size = std::min(log.size() - offset, SHRT_MAX - 2 - p->size); + size_t size = std::min(log.size() - offset, SHRT_MAX - 2 - p->Size()); p->Send_uint16(size); p->Send_binary(log.data() + offset, size); my_client->SendPacket(p); @@ -1056,7 +1058,7 @@ NetworkRecvStatus ClientNetworkGameSocketHandler::Receive_SERVER_FRAME(Packet *p } #endif /* Receive the token. */ - if (p->pos != p->size) this->token = p->Recv_uint8(); + if (p->CanReadFromPacket(sizeof(uint8))) this->token = p->Recv_uint8(); DEBUG(net, 5, "Received FRAME %d", _frame_counter_server); diff --git a/src/network/network_command.cpp b/src/network/network_command.cpp index 12a4da5f28..1307077082 100644 --- a/src/network/network_command.cpp +++ b/src/network/network_command.cpp @@ -328,7 +328,7 @@ const char *NetworkGameSocketHandler::ReceiveCommand(Packet *p, CommandPacket *c if (cp->binary_length == 0) { p->Recv_string(cp->text, (!_network_server && GetCommandFlags(cp->cmd) & CMD_STR_CTRL) != 0 ? SVS_ALLOW_CONTROL_CODE | SVS_REPLACE_WITH_QUESTION_MARK : SVS_REPLACE_WITH_QUESTION_MARK); } else { - if ((p->pos + (PacketSize) cp->binary_length + /* callback index */ 1) > p->size) return "invalid binary data length"; + if (!p->CanReadFromPacket(cp->binary_length + /* callback index */ 1)) return "invalid binary data length"; if (cp->binary_length > MAX_CMD_TEXT_LENGTH) return "over-size binary data length"; p->Recv_binary(cp->text, cp->binary_length); } diff --git a/src/network/network_content.cpp b/src/network/network_content.cpp index 8ac6bdc270..fc52caad28 100644 --- a/src/network/network_content.cpp +++ b/src/network/network_content.cpp @@ -466,6 +466,18 @@ static bool GunzipFile(const ContentInfo *ci) #endif /* defined(WITH_ZLIB) */ } +/** + * Simple wrapper around fwrite to be able to pass it to Packet's TransferOut. + * @param file The file to write data to. + * @param buffer The buffer to write to the file. + * @param amount The number of bytes to write. + * @return The number of bytes that were written. + */ +static inline ssize_t TransferOutFWrite(FILE *file, const char *buffer, size_t amount) +{ + return fwrite(buffer, 1, amount, file); +} + bool ClientNetworkContentSocketHandler::Receive_SERVER_CONTENT(Packet *p) { if (this->curFile == nullptr) { @@ -483,8 +495,8 @@ bool ClientNetworkContentSocketHandler::Receive_SERVER_CONTENT(Packet *p) } } else { /* We have a file opened, thus are downloading internal content */ - size_t toRead = (size_t)(p->size - p->pos); - if (fwrite(p->buffer + p->pos, 1, toRead, this->curFile) != toRead) { + size_t toRead = p->RemainingBytesToTransfer(); + if (toRead != 0 && (size_t)p->TransferOut(TransferOutFWrite, this->curFile) != toRead) { DeleteWindowById(WC_NETWORK_STATUS_WINDOW, WN_NETWORK_STATUS_WINDOW_CONTENT_DOWNLOAD); ShowErrorMessage(STR_CONTENT_ERROR_COULD_NOT_DOWNLOAD, STR_CONTENT_ERROR_COULD_NOT_DOWNLOAD_FILE_NOT_WRITABLE, WL_ERROR); this->Close(); diff --git a/src/network/network_server.cpp b/src/network/network_server.cpp index 6492e3d423..c180f49025 100644 --- a/src/network/network_server.cpp +++ b/src/network/network_server.cpp @@ -167,12 +167,10 @@ struct PacketWriter : SaveFilter { byte *bufe = buf + size; while (buf != bufe) { - size_t to_write = std::min(SHRT_MAX - this->current->size, bufe - buf); - memcpy(this->current->buffer + this->current->size, buf, to_write); - this->current->size += (PacketSize)to_write; - buf += to_write; + size_t written = this->current->Send_bytes(buf, bufe); + buf += written; - if (this->current->size == SHRT_MAX) { + if (!this->current->CanWriteToPacket(1)) { this->AppendQueue(); if (buf != bufe) this->current.reset(new Packet(PACKET_SERVER_MAP_DATA)); } @@ -248,7 +246,7 @@ std::unique_ptr ServerNetworkGameSocketHandler::ReceivePacket() /* We can receive a packet, so try that and if needed account for * the amount of received data. */ std::unique_ptr p = this->NetworkTCPSocketHandler::ReceivePacket(); - if (p != nullptr) this->receive_limit -= p->size; + if (p != nullptr) this->receive_limit -= p->Size(); return p; } @@ -485,7 +483,7 @@ NetworkRecvStatus ServerNetworkGameSocketHandler::SendDesyncLog(const std::strin { for (size_t offset = 0; offset < log.size();) { Packet *p = new Packet(PACKET_SERVER_DESYNC_LOG); - size_t size = std::min(log.size() - offset, SHRT_MAX - 2 - p->size); + size_t size = std::min(log.size() - offset, SHRT_MAX - 2 - p->Size()); p->Send_uint16(size); p->Send_binary(log.data() + offset, size); this->SendPacket(p); @@ -669,7 +667,7 @@ NetworkRecvStatus ServerNetworkGameSocketHandler::SendMap() has_packets = false; break; } - last_packet = p->buffer[2] == PACKET_SERVER_MAP_DONE; + last_packet = p->GetPacketType() == PACKET_SERVER_MAP_DONE; this->SendPacket(std::move(p)); @@ -1283,7 +1281,7 @@ NetworkRecvStatus ServerNetworkGameSocketHandler::Receive_CLIENT_DESYNC_LOG(Pack this->desync_log.resize(this->desync_log.size() + size); p->Recv_binary(this->desync_log.data() + this->desync_log.size() - size, size); DEBUG(net, 2, "Received %u bytes of client desync log", size); - this->receive_limit += p->size; + this->receive_limit += p->Size(); return NETWORK_RECV_STATUS_OKAY; } @@ -1950,7 +1948,7 @@ void NetworkServer_Tick(bool send_frame) for (NetworkClientSocket *cs : NetworkClientSocket::Iterate()) { /* We allow a number of bytes per frame, but only to the burst amount * to be available for packet receiving at any particular time. */ - cs->receive_limit = std::min(cs->receive_limit + _settings_client.network.bytes_per_frame, + cs->receive_limit = std::min(cs->receive_limit + _settings_client.network.bytes_per_frame, _settings_client.network.bytes_per_frame_burst); /* Check if the speed of the client is what we can expect from a client */ diff --git a/src/network/network_server.h b/src/network/network_server.h index b05a060e7c..21bb2b1695 100644 --- a/src/network/network_server.h +++ b/src/network/network_server.h @@ -71,7 +71,7 @@ public: uint32 last_token_frame; ///< The last frame we received the right token ClientStatus status; ///< Status of this client CommandQueue outgoing_queue; ///< The command-queue awaiting delivery - int receive_limit; ///< Amount of bytes that we can receive at this moment + size_t receive_limit; ///< Amount of bytes that we can receive at this moment uint32 server_hash_bits; ///< Server password hash entropy bits uint32 rcon_hash_bits; ///< Rcon password hash entropy bits uint32 settings_hash_bits; ///< Settings password hash entropy bits diff --git a/src/network/network_udp.cpp b/src/network/network_udp.cpp index 794a7201f3..9e9e52a9fd 100644 --- a/src/network/network_udp.cpp +++ b/src/network/network_udp.cpp @@ -214,7 +214,7 @@ void ServerNetworkUDPSocketHandler::Receive_CLIENT_FIND_SERVER(Packet *p, Networ strecpy(ngi.server_revision, _openttd_revision, lastof(ngi.server_revision)); strecpy(ngi.short_server_revision, _openttd_revision, lastof(ngi.short_server_revision)); - if (p->CanReadFromPacket(8, true) && p->Recv_uint32() == FIND_SERVER_EXTENDED_TOKEN) { + if (p->CanReadFromPacket(8) && p->Recv_uint32() == FIND_SERVER_EXTENDED_TOKEN) { this->Reply_CLIENT_FIND_SERVER_extended(p, client_addr, &ngi); return; } @@ -261,23 +261,23 @@ void ServerNetworkUDPSocketHandler::Receive_CLIENT_DETAIL_INFO(Packet *p, Networ static const uint MIN_CI_SIZE = 54; uint max_cname_length = NETWORK_COMPANY_NAME_LENGTH; - if (Company::GetNumItems() * (MIN_CI_SIZE + NETWORK_COMPANY_NAME_LENGTH) >= (uint)SEND_MTU - packet.size) { + if (!packet.CanWriteToPacket(Company::GetNumItems() * (MIN_CI_SIZE + NETWORK_COMPANY_NAME_LENGTH))) { /* Assume we can at least put the company information in the packets. */ - assert(Company::GetNumItems() * MIN_CI_SIZE < (uint)SEND_MTU - packet.size); + assert(packet.CanWriteToPacket(Company::GetNumItems() * MIN_CI_SIZE)); /* At this moment the company names might not fit in the * packet. Check whether that is really the case. */ for (;;) { - int free = SEND_MTU - packet.size; + size_t required = 0; for (const Company *company : Company::Iterate()) { char company_name[NETWORK_COMPANY_NAME_LENGTH]; SetDParam(0, company->index); GetString(company_name, STR_COMPANY_NAME, company_name + max_cname_length - 1); - free -= MIN_CI_SIZE; - free -= (int)strlen(company_name); + required += MIN_CI_SIZE; + required += strlen(company_name); } - if (free >= 0) break; + if (packet.CanWriteToPacket(required)) break; /* Try again, with slightly shorter strings. */ assert(max_cname_length > 0);