diff --git a/src/core/serialisation.cpp b/src/core/serialisation.cpp index 7fed2cf1a8..3458dc73f5 100644 --- a/src/core/serialisation.cpp +++ b/src/core/serialisation.cpp @@ -9,6 +9,7 @@ #include "../stdafx.h" #include "serialisation.hpp" +#include "../string_func_extra.h" /** * Is it safe to write to the packet, i.e. didn't we run over the buffer? @@ -135,3 +136,8 @@ void BufferSend_binary(std::vector &buffer, size_t limit, const char *data assert(BufferCanWriteToPacket(buffer, limit, size)); buffer.insert(buffer.end(), data, data + size); } + +void BufferRecvStringValidate(std::string &buffer, StringValidationSettings settings) +{ + StrMakeValidInPlace(buffer, settings); +} diff --git a/src/core/serialisation.hpp b/src/core/serialisation.hpp index 1431824689..a3a10f01c9 100644 --- a/src/core/serialisation.hpp +++ b/src/core/serialisation.hpp @@ -11,6 +11,8 @@ #define SERIALISATION_HPP #include "bitmath_func.hpp" +#include "../string_type.h" +#include "../string_func.h" #include #include @@ -75,4 +77,188 @@ struct BufferSerialisationHelper { } }; +void BufferRecvStringValidate(std::string &buffer, StringValidationSettings settings); + +template +struct BufferDeserialisationHelper { +private: + const byte *GetBuffer() + { + return static_cast(this)->GetDeserialisationBuffer(); + } + + size_t GetBufferSize() + { + return static_cast(this)->GetDeserialisationBufferSize(); + } + +public: + bool CanRecvBytes(size_t bytes_to_read, bool raise_error = true) + { + return static_cast(this)->CanDeserialiseBytes(bytes_to_read, raise_error); + } + + /** + * Read a boolean from the packet. + * @return The read data. + */ + bool Recv_bool() + { + return this->Recv_uint8() != 0; + } + + /** + * Read a 8 bits integer from the packet. + * @return The read data. + */ + uint8 Recv_uint8() + { + uint8 n; + + if (!this->CanRecvBytes(sizeof(n), true)) return 0; + + auto &pos = static_cast(this)->GetDeserialisationPosition(); + + n = this->GetBuffer()[pos++]; + return n; + } + + /** + * Read a 16 bits integer from the packet. + * @return The read data. + */ + uint16 Recv_uint16() + { + uint16 n; + + if (!this->CanRecvBytes(sizeof(n), true)) return 0; + + auto &pos = static_cast(this)->GetDeserialisationPosition(); + + n = (uint16)this->GetBuffer()[pos++]; + n += (uint16)this->GetBuffer()[pos++] << 8; + return n; + } + + /** + * Read a 32 bits integer from the packet. + * @return The read data. + */ + uint32 Recv_uint32() + { + uint32 n; + + if (!this->CanRecvBytes(sizeof(n), true)) return 0; + + auto &pos = static_cast(this)->GetDeserialisationPosition(); + + n = (uint32)this->GetBuffer()[pos++]; + n += (uint32)this->GetBuffer()[pos++] << 8; + n += (uint32)this->GetBuffer()[pos++] << 16; + n += (uint32)this->GetBuffer()[pos++] << 24; + return n; + } + + /** + * Read a 64 bits integer from the packet. + * @return The read data. + */ + uint64 Recv_uint64() + { + uint64 n; + + if (!this->CanRecvBytes(sizeof(n), true)) return 0; + + auto &pos = static_cast(this)->GetDeserialisationPosition(); + + n = (uint64)this->GetBuffer()[pos++]; + n += (uint64)this->GetBuffer()[pos++] << 8; + n += (uint64)this->GetBuffer()[pos++] << 16; + n += (uint64)this->GetBuffer()[pos++] << 24; + n += (uint64)this->GetBuffer()[pos++] << 32; + n += (uint64)this->GetBuffer()[pos++] << 40; + n += (uint64)this->GetBuffer()[pos++] << 48; + n += (uint64)this->GetBuffer()[pos++] << 56; + return n; + } + + /** + * Reads characters (bytes) from the packet until it finds a '\0', or reaches a + * maximum of \c length characters. + * When the '\0' has not been reached in the first \c length read characters, + * more characters are read from the packet until '\0' has been reached. However, + * these characters will not end up in the returned string. + * The length of the returned string will be at most \c length - 1 characters. + * @param length The maximum length of the string including '\0'. + * @param settings The string validation settings. + * @return The validated string. + */ + std::string Recv_string(size_t length, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK) + { + assert(length > 1); + + /* Both loops with Recv_uint8 terminate when reading past the end of the + * packet as Recv_uint8 then closes the connection and returns 0. */ + std::string str; + char character; + while (--length > 0 && (character = this->Recv_uint8()) != '\0') str.push_back(character); + + if (length == 0) { + /* The string in the packet was longer. Read until the termination. */ + while (this->Recv_uint8() != '\0') {} + } + + BufferRecvStringValidate(str, settings); + return str; + } + + /** + * Reads a string till it finds a '\0' in the stream. + * @param buffer The buffer to put the data into. + * @param settings The string validation settings. + */ + void Recv_string(std::string &buffer, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK) + { + /* Don't allow reading from a closed socket */ + if (!this->CanRecvBytes(0, false)) return; + + auto &pos = static_cast(this)->GetDeserialisationPosition(); + + size_t length = ttd_strnlen((const char *)(this->GetBuffer() + pos), this->GetBufferSize() - pos - 1); + buffer.assign((const char *)(this->GetBuffer() + pos), length); + pos += (decltype(pos))length + 1; + BufferRecvStringValidate(buffer, settings); + } + + /** + * Reads binary data. + * @param buffer The buffer to put the data into. + * @param size The size of the data. + */ + void Recv_binary(char *buffer, size_t size) + { + if (!this->CanRecvBytes(size, true)) return; + + auto &pos = static_cast(this)->GetDeserialisationPosition(); + + memcpy(buffer, &this->GetBuffer()[pos], size); + pos += (decltype(pos)) size; + } + + /** + * Reads binary data. + * @param buffer The buffer to put the data into. + * @param size The size of the data. + */ + void Recv_binary(std::string &buffer, size_t size) + { + if (!this->CanRecvBytes(size, true)) return; + + auto &pos = static_cast(this)->GetDeserialisationPosition(); + + buffer.assign((const char *) &this->GetBuffer()[pos], size); + pos += (decltype(pos)) size; + } +}; + #endif /* SERIALISATION_HPP */ diff --git a/src/network/core/packet.cpp b/src/network/core/packet.cpp index 64ecfee2c4..fd252ab2d5 100644 --- a/src/network/core/packet.cpp +++ b/src/network/core/packet.cpp @@ -182,111 +182,6 @@ PacketType Packet::GetPacketType() const return static_cast(buffer[sizeof(PacketSize)]); } -/** - * Read a boolean from the packet. - * @return The read data. - */ -bool Packet::Recv_bool() -{ - return this->Recv_uint8() != 0; -} - -/** - * Read a 8 bits integer from the packet. - * @return The read data. - */ -uint8 Packet::Recv_uint8() -{ - uint8 n; - - if (!this->CanReadFromPacket(sizeof(n), true)) return 0; - - n = this->buffer[this->pos++]; - return n; -} - -/** - * Read a 16 bits integer from the packet. - * @return The read data. - */ -uint16 Packet::Recv_uint16() -{ - uint16 n; - - if (!this->CanReadFromPacket(sizeof(n), true)) return 0; - - n = (uint16)this->buffer[this->pos++]; - n += (uint16)this->buffer[this->pos++] << 8; - return n; -} - -/** - * Read a 32 bits integer from the packet. - * @return The read data. - */ -uint32 Packet::Recv_uint32() -{ - uint32 n; - - if (!this->CanReadFromPacket(sizeof(n), true)) return 0; - - n = (uint32)this->buffer[this->pos++]; - n += (uint32)this->buffer[this->pos++] << 8; - n += (uint32)this->buffer[this->pos++] << 16; - n += (uint32)this->buffer[this->pos++] << 24; - return n; -} - -/** - * Read a 64 bits integer from the packet. - * @return The read data. - */ -uint64 Packet::Recv_uint64() -{ - uint64 n; - - if (!this->CanReadFromPacket(sizeof(n), true)) return 0; - - n = (uint64)this->buffer[this->pos++]; - n += (uint64)this->buffer[this->pos++] << 8; - n += (uint64)this->buffer[this->pos++] << 16; - n += (uint64)this->buffer[this->pos++] << 24; - n += (uint64)this->buffer[this->pos++] << 32; - n += (uint64)this->buffer[this->pos++] << 40; - n += (uint64)this->buffer[this->pos++] << 48; - n += (uint64)this->buffer[this->pos++] << 56; - return n; -} - -/** - * Reads characters (bytes) from the packet until it finds a '\0', or reaches a - * maximum of \c length characters. - * When the '\0' has not been reached in the first \c length read characters, - * more characters are read from the packet until '\0' has been reached. However, - * these characters will not end up in the returned string. - * The length of the returned string will be at most \c length - 1 characters. - * @param length The maximum length of the string including '\0'. - * @param settings The string validation settings. - * @return The validated string. - */ -std::string Packet::Recv_string(size_t length, StringValidationSettings settings) -{ - assert(length > 1); - - /* Both loops with Recv_uint8 terminate when reading past the end of the - * packet as Recv_uint8 then closes the connection and returns 0. */ - std::string str; - char character; - while (--length > 0 && (character = this->Recv_uint8()) != '\0') str.push_back(character); - - if (length == 0) { - /* The string in the packet was longer. Read until the termination. */ - while (this->Recv_uint8() != '\0') {} - } - - return StrMakeValid(str, settings); -} - /** * Get the amount of bytes that are still available for the Transfer functions. * @return The number of bytes that still have to be transfered. @@ -296,44 +191,3 @@ size_t Packet::RemainingBytesToTransfer() const return this->Size() - this->pos; } -/** - * Reads a string till it finds a '\0' in the stream. - * @param buffer The buffer to put the data into. - * @param settings The string validation settings. - */ -void Packet::Recv_string(std::string &buffer, StringValidationSettings settings) -{ - /* Don't allow reading from a closed socket */ - if (cs->HasClientQuit()) return; - - size_t length = ttd_strnlen((const char *)(this->buffer.data() + this->pos), this->Size() - this->pos - 1); - buffer.assign((const char *)(this->buffer.data() + this->pos), length); - this->pos += (PacketSize)length + 1; - StrMakeValidInPlace(buffer, settings); -} - -/** - * Reads binary data. - * @param buffer The buffer to put the data into. - * @param size The size of the data. - */ -void Packet::Recv_binary(char *buffer, size_t size) -{ - if (!this->CanReadFromPacket(size, true)) return; - - memcpy(buffer, &this->buffer[this->pos], size); - this->pos += (PacketSize) size; -} - -/** - * Reads binary data. - * @param buffer The buffer to put the data into. - * @param size The size of the data. - */ -void Packet::Recv_binary(std::string &buffer, size_t size) -{ - if (!this->CanReadFromPacket(size, true)) return; - - buffer.assign((const char *) &this->buffer[this->pos], size); - this->pos += (PacketSize) size; -} diff --git a/src/network/core/packet.h b/src/network/core/packet.h index bae98d65fb..8f1bbc2cb1 100644 --- a/src/network/core/packet.h +++ b/src/network/core/packet.h @@ -43,7 +43,7 @@ typedef uint8 PacketType; ///< Identifier for the packet * - years that are leap years in the 'days since X' to 'date' calculations: * (year % 4 == 0) and ((year % 100 != 0) or (year % 400 == 0)) */ -struct Packet : public BufferSerialisationHelper { +struct Packet : public BufferSerialisationHelper, public BufferDeserialisationHelper { private: /** The current read/write position in the packet */ PacketSize pos; @@ -67,6 +67,11 @@ public: std::vector &GetSerialisationBuffer() { return this->buffer; } size_t GetSerialisationLimit() const { return this->limit; } + const byte *GetDeserialisationBuffer() const { return this->buffer.data(); } + size_t GetDeserialisationBufferSize() const { return this->buffer.size(); } + PacketSize &GetDeserialisationPosition() { return this->pos; } + bool CanDeserialiseBytes(size_t bytes_to_read, bool raise_error) { return this->CanReadFromPacket(bytes_to_read, raise_error); } + bool CanWriteToPacket(size_t bytes_to_write); /* Reading/receiving of packets */ @@ -77,16 +82,7 @@ public: void PrepareToRead(); PacketType GetPacketType() const; - bool CanReadFromPacket(size_t bytes_to_read, bool close_connection = false); - bool Recv_bool (); - uint8 Recv_uint8 (); - uint16 Recv_uint16(); - uint32 Recv_uint32(); - uint64 Recv_uint64(); - std::string Recv_string(size_t length, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK); - void Recv_string(std::string &buffer, StringValidationSettings settings = SVS_REPLACE_WITH_QUESTION_MARK); - void Recv_binary(char *buffer, size_t size); - void Recv_binary(std::string &buffer, size_t size); + bool CanReadFromPacket(size_t bytes_to_read, bool close_connection = false); size_t RemainingBytesToTransfer() const;