diff --git a/src/saveload/company_sl.cpp b/src/saveload/company_sl.cpp index c9a2da7d7e..92dc188e04 100644 --- a/src/saveload/company_sl.cpp +++ b/src/saveload/company_sl.cpp @@ -591,7 +591,7 @@ static void Load_PLYP() ReadBuffer::GetCurrent()->CopyBytes(buffer.data(), buffer.size()); if (crypto_unlock(buffer.data(), _network_company_password_storage_key, nonce, mac, buffer.data(), buffer.size()) == 0) { - SlLoadFromBuffer(buffer.data(), buffer.size(), [](void *) { + SlLoadFromBuffer(buffer.data(), buffer.size(), []() { _network_company_server_id.resize(SlReadUint32()); ReadBuffer::GetCurrent()->CopyBytes((uint8 *)_network_company_server_id.data(), _network_company_server_id.size()); @@ -605,7 +605,7 @@ static void Load_PLYP() } ReadBuffer::GetCurrent()->SkipBytes(SlReadByte()); // Skip padding - }, nullptr); + }); DEBUG(sl, 2, "Decrypted company passwords"); } else { DEBUG(sl, 2, "Failed to decrypt company passwords"); diff --git a/src/saveload/saveload.cpp b/src/saveload/saveload.cpp index 1ea9094960..83b5ddd960 100644 --- a/src/saveload/saveload.cpp +++ b/src/saveload/saveload.cpp @@ -2071,33 +2071,34 @@ std::vector SlSaveToVector(AutolengthProc *proc, void *arg) return std::vector(result.first, result.first + result.second); } -/** - * Run proc, loading exactly length bytes from the contents of buffer - * @param proc The callback procedure that is called - * @param arg The variable that will be used for the callback procedure - */ -void SlLoadFromBuffer(const byte *buffer, size_t length, AutolengthProc *proc, void *arg) +SlLoadFromBufferState SlLoadFromBufferSetup(const byte *buffer, size_t length) { assert(_sl.action == SLA_LOAD || _sl.action == SLA_LOAD_CHECK); - size_t old_obj_len = _sl.obj_len; + SlLoadFromBufferState state; + + state.old_obj_len = _sl.obj_len; _sl.obj_len = length; ReadBuffer *reader = ReadBuffer::GetCurrent(); - byte *old_bufp = reader->bufp; - byte *old_bufe = reader->bufe; + state.old_bufp = reader->bufp; + state.old_bufe = reader->bufe; reader->bufp = const_cast(buffer); reader->bufe = const_cast(buffer) + length; - proc(arg); + return state; +} +void SlLoadFromBufferRestore(const SlLoadFromBufferState &state, const byte *buffer, size_t length) +{ + ReadBuffer *reader = ReadBuffer::GetCurrent(); if (reader->bufp != reader->bufe || reader->bufe != buffer + length) { SlErrorCorrupt("SlLoadFromBuffer: Wrong number of bytes read"); } - _sl.obj_len = old_obj_len; - reader->bufp = old_bufp; - reader->bufe = old_bufe; + _sl.obj_len = state.old_obj_len; + reader->bufp = state.old_bufp; + reader->bufe = state.old_bufe; } /* diff --git a/src/saveload/saveload.h b/src/saveload/saveload.h index 95ac3c8697..3480619932 100644 --- a/src/saveload/saveload.h +++ b/src/saveload/saveload.h @@ -629,12 +629,32 @@ int SlIterateArray(); void SlAutolength(AutolengthProc *proc, void *arg); std::vector SlSaveToVector(AutolengthProc *proc, void *arg); -void SlLoadFromBuffer(const byte *buffer, size_t length, AutolengthProc *proc, void *arg); size_t SlGetFieldLength(); void SlSetLength(size_t length); size_t SlCalcObjMemberLength(const void *object, const SaveLoad &sld); size_t SlCalcObjLength(const void *object, const SaveLoadTable &slt); +struct SlLoadFromBufferState { + size_t old_obj_len; + byte *old_bufp; + byte *old_bufe; +}; + +/** + * Run proc, loading exactly length bytes from the contents of buffer + * @param proc The callback procedure that is called + */ +template +void SlLoadFromBuffer(const byte *buffer, size_t length, F proc) +{ + extern SlLoadFromBufferState SlLoadFromBufferSetup(const byte *buffer, size_t length); + extern void SlLoadFromBufferRestore(const SlLoadFromBufferState &state, const byte *buffer, size_t length); + + SlLoadFromBufferState state = SlLoadFromBufferSetup(buffer, length); + proc(); + SlLoadFromBufferRestore(state, buffer, length); +} + void SlGlobList(const SaveLoadTable &slt); void SlArray(void *array, size_t length, VarType conv); void SlObject(void *object, const SaveLoadTable &slt);