From b37209c3afd79d3d9674d7fb7154415659d9fbd4 Mon Sep 17 00:00:00 2001 From: Jonathan G Rennison Date: Sat, 11 Jun 2022 22:15:19 +0100 Subject: [PATCH] Tracerestrict: Programs maintain a list of signals referencing them --- src/saveload/tracerestrict_sl.cpp | 2 +- src/tracerestrict.cpp | 64 +++++++++++++++++++++++++------ src/tracerestrict.h | 36 ++++++++++++++--- 3 files changed, 84 insertions(+), 18 deletions(-) diff --git a/src/saveload/tracerestrict_sl.cpp b/src/saveload/tracerestrict_sl.cpp index c422e0b5bb..52dae2a4fc 100644 --- a/src/saveload/tracerestrict_sl.cpp +++ b/src/saveload/tracerestrict_sl.cpp @@ -211,7 +211,7 @@ void AfterLoadTraceRestrict() { for (TraceRestrictMapping::iterator iter = _tracerestrictprogram_mapping.begin(); iter != _tracerestrictprogram_mapping.end(); ++iter) { - _tracerestrictprogram_pool.Get(iter->second.program_id)->IncrementRefCount(); + _tracerestrictprogram_pool.Get(iter->second.program_id)->IncrementRefCount(iter->first); } } diff --git a/src/tracerestrict.cpp b/src/tracerestrict.cpp index e7e0301a7a..f1543422b9 100644 --- a/src/tracerestrict.cpp +++ b/src/tracerestrict.cpp @@ -818,12 +818,58 @@ void TraceRestrictProgram::Execute(const Train* v, const TraceRestrictProgramInp assert(condstack.empty()); } +void TraceRestrictProgram::ClearRefIds() +{ + if (this->refcount > 4) free(this->ref_ids.ptr_ref_ids.buffer); +} + +/** + * Increment ref count, only use when creating a mapping + */ +void TraceRestrictProgram::IncrementRefCount(TraceRestrictRefId ref_id) +{ + if (this->refcount >= 4) { + if (this->refcount == 4) { + /* Transition from inline to allocated mode */ + + TraceRestrictRefId *ptr = MallocT(8); + MemCpyT(ptr, this->ref_ids.inline_ref_ids, 4); + this->ref_ids.ptr_ref_ids.buffer = ptr; + this->ref_ids.ptr_ref_ids.elem_capacity = 8; + } else if (this->refcount == this->ref_ids.ptr_ref_ids.elem_capacity) { + // grow buffer + this->ref_ids.ptr_ref_ids.elem_capacity *= 2; + this->ref_ids.ptr_ref_ids.buffer = ReallocT(this->ref_ids.ptr_ref_ids.buffer, this->ref_ids.ptr_ref_ids.elem_capacity); + } + this->ref_ids.ptr_ref_ids.buffer[this->refcount] = ref_id; + } else { + this->ref_ids.inline_ref_ids[this->refcount] = ref_id; + } + this->refcount++; +} + /** * Decrement ref count, only use when removing a mapping */ -void TraceRestrictProgram::DecrementRefCount() { +void TraceRestrictProgram::DecrementRefCount(TraceRestrictRefId ref_id) { assert(this->refcount > 0); + if (this->refcount >= 2) { + TraceRestrictRefId *data = this->GetRefIdsPtr(); + for (uint i = 0; i < this->refcount - 1; i++) { + if (data[i] == ref_id) { + data[i] = data[this->refcount - 1]; + break; + } + } + } this->refcount--; + if (this->refcount == 4) { + /* Transition from allocated to inline mode */ + + TraceRestrictRefId *ptr = this->ref_ids.ptr_ref_ids.buffer; + MemCpyT(this->ref_ids.inline_ref_ids, ptr, 4); + free(ptr); + } if (this->refcount == 0) { delete this; } @@ -1227,10 +1273,10 @@ void TraceRestrictCreateProgramMapping(TraceRestrictRefId ref, TraceRestrictProg if (!insert_result.second) { // value was not inserted, there is an existing mapping // unref the existing mapping before updating it - _tracerestrictprogram_pool.Get(insert_result.first->second.program_id)->DecrementRefCount(); + _tracerestrictprogram_pool.Get(insert_result.first->second.program_id)->DecrementRefCount(ref); insert_result.first->second = prog->index; } - prog->IncrementRefCount(); + prog->IncrementRefCount(ref); TileIndex tile = GetTraceRestrictRefIdTileIndex(ref); Track track = GetTraceRestrictRefIdTrack(ref); @@ -1254,7 +1300,7 @@ bool TraceRestrictRemoveProgramMapping(TraceRestrictRefId ref) // do this before decrementing the refcount bool remove_other_mapping = prog->refcount == 2 && prog->items.empty(); - prog->DecrementRefCount(); + prog->DecrementRefCount(ref); _tracerestrictprogram_mapping.erase(iter); TileIndex tile = GetTraceRestrictRefIdTileIndex(ref); @@ -1264,14 +1310,7 @@ bool TraceRestrictRemoveProgramMapping(TraceRestrictRefId ref) YapfNotifyTrackLayoutChange(tile, track); if (remove_other_mapping) { - TraceRestrictProgramID id = prog->index; - for (TraceRestrictMapping::iterator rm_iter = _tracerestrictprogram_mapping.begin(); - rm_iter != _tracerestrictprogram_mapping.end(); ++rm_iter) { - if (rm_iter->second.program_id == id) { - TraceRestrictRemoveProgramMapping(rm_iter->first); - break; - } - } + TraceRestrictRemoveProgramMapping(const_cast(prog)->GetRefIdsPtr()[0]); } return true; } else { @@ -1743,6 +1782,7 @@ CommandCost CmdProgramSignalTraceRestrictProgMgmt(TileIndex tile, DoCommandFlag // allocation failed return CMD_ERROR; } + prog->items.reserve(prog->items.size() + source_prog->items.size()); // this is in case prog == source_prog prog->items.insert(prog->items.end(), source_prog->items.begin(), source_prog->items.end()); // append prog->Validate(); diff --git a/src/tracerestrict.h b/src/tracerestrict.h index e4ef9cb081..f1aff5da09 100644 --- a/src/tracerestrict.h +++ b/src/tracerestrict.h @@ -479,17 +479,43 @@ struct TraceRestrictProgram : TraceRestrictProgramPool::PoolItem<&_tracerestrict uint32 refcount; TraceRestrictProgramActionsUsedFlags actions_used_flags; +private: + + struct ptr_buffer { + TraceRestrictRefId *buffer; + uint32 elem_capacity; + }; + union refid_list_union { + TraceRestrictRefId inline_ref_ids[4]; + ptr_buffer ptr_ref_ids; + + // Actual construction/destruction done by struct TraceRestrictProgram + refid_list_union() {} + ~refid_list_union() {} + }; + refid_list_union ref_ids; + + void ClearRefIds(); + + inline TraceRestrictRefId *GetRefIdsPtr() { return this->refcount <= 4 ? this->ref_ids.inline_ref_ids : this->ref_ids.ptr_ref_ids.buffer; }; + +public: + TraceRestrictProgram() : refcount(0), actions_used_flags(static_cast(0)) { } + ~TraceRestrictProgram() + { + this->ClearRefIds(); + } + void Execute(const Train *v, const TraceRestrictProgramInput &input, TraceRestrictProgramResult &out) const; - /** - * Increment ref count, only use when creating a mapping - */ - void IncrementRefCount() { refcount++; } + inline const TraceRestrictRefId *GetRefIdsPtr() const { return const_cast(this)->GetRefIdsPtr(); } - void DecrementRefCount(); + void IncrementRefCount(TraceRestrictRefId ref_id); + + void DecrementRefCount(TraceRestrictRefId ref_id); static CommandCost Validate(const std::vector &items, TraceRestrictProgramActionsUsedFlags &actions_used_flags);