Fix ownership of loaded object data

This commit is contained in:
ζeh Matt 2021-09-03 13:43:06 +03:00
parent b2d0b54d13
commit 1f4f0c015c
No known key found for this signature in database
GPG Key ID: 18CE582C71A225B0
4 changed files with 142 additions and 157 deletions

View File

@ -36,7 +36,8 @@ class ObjectManager final : public IObjectManager
{ {
private: private:
IObjectRepository& _objectRepository; IObjectRepository& _objectRepository;
std::vector<std::unique_ptr<Object>> _loadedObjects;
std::vector<Object*> _loadedObjects;
std::array<std::vector<ObjectEntryIndex>, RIDE_TYPE_COUNT> _rideTypeToObjectMap; std::array<std::vector<ObjectEntryIndex>, RIDE_TYPE_COUNT> _rideTypeToObjectMap;
// Used to return a safe empty vector back from GetAllRideEntries, can be removed when std::span is available // Used to return a safe empty vector back from GetAllRideEntries, can be removed when std::span is available
@ -63,7 +64,7 @@ public:
{ {
return nullptr; return nullptr;
} }
return _loadedObjects[index].get(); return _loadedObjects[index];
} }
Object* GetLoadedObject(ObjectType objectType, size_t index) override Object* GetLoadedObject(ObjectType objectType, size_t index) override
@ -82,13 +83,11 @@ public:
Object* GetLoadedObject(const ObjectEntryDescriptor& entry) override Object* GetLoadedObject(const ObjectEntryDescriptor& entry) override
{ {
Object* loadedObject = nullptr;
const ObjectRepositoryItem* ori = _objectRepository.FindObject(entry); const ObjectRepositoryItem* ori = _objectRepository.FindObject(entry);
if (ori != nullptr) if (ori == nullptr)
{ return nullptr;
loadedObject = ori->LoadedObject;
} return ori->LoadedObject.get();
return loadedObject;
} }
ObjectEntryIndex GetLoadedObjectEntryIndex(const Object* object) override ObjectEntryIndex GetLoadedObjectEntryIndex(const Object* object) override
@ -142,7 +141,7 @@ public:
const ObjectRepositoryItem* ori = _objectRepository.FindObject(&entry); const ObjectRepositoryItem* ori = _objectRepository.FindObject(&entry);
if (ori != nullptr) if (ori != nullptr)
{ {
Object* loadedObject = ori->LoadedObject; Object* loadedObject = ori->LoadedObject.get();
if (loadedObject != nullptr) if (loadedObject != nullptr)
{ {
UnloadObject(loadedObject); UnloadObject(loadedObject);
@ -162,7 +161,7 @@ public:
{ {
for (auto& object : _loadedObjects) for (auto& object : _loadedObjects)
{ {
UnloadObject(object.get()); UnloadObject(object);
} }
UpdateSceneryGroupIndexes(); UpdateSceneryGroupIndexes();
ResetTypeToRideEntryIndexMap(); ResetTypeToRideEntryIndexMap();
@ -333,12 +332,13 @@ private:
Object* RepositoryItemToObject(const ObjectRepositoryItem* ori, std::optional<int32_t> slot = {}) Object* RepositoryItemToObject(const ObjectRepositoryItem* ori, std::optional<int32_t> slot = {})
{ {
Object* loadedObject = nullptr; if (ori == nullptr)
if (ori != nullptr) return nullptr;
{
loadedObject = ori->LoadedObject; Object* loadedObject = ori->LoadedObject.get();
if (loadedObject == nullptr) if (loadedObject != nullptr)
{ return loadedObject;
ObjectType objectType = ori->ObjectEntry.GetType(); ObjectType objectType = ori->ObjectEntry.GetType();
if (slot) if (slot)
{ {
@ -361,14 +361,13 @@ private:
{ {
_loadedObjects.resize(*slot + 1); _loadedObjects.resize(*slot + 1);
} }
loadedObject = object.get(); loadedObject = object;
_loadedObjects[*slot] = std::move(object); _loadedObjects[*slot] = object;
UpdateSceneryGroupIndexes(); UpdateSceneryGroupIndexes();
ResetTypeToRideEntryIndexMap(); ResetTypeToRideEntryIndexMap();
} }
} }
}
}
return loadedObject; return loadedObject;
} }
@ -396,8 +395,7 @@ private:
Guard::ArgumentNotNull(object, GUARD_LINE); Guard::ArgumentNotNull(object, GUARD_LINE);
auto result = std::numeric_limits<size_t>().max(); auto result = std::numeric_limits<size_t>().max();
auto it = std::find_if( auto it = std::find_if(_loadedObjects.begin(), _loadedObjects.end(), [object](auto& obj) { return obj == object; });
_loadedObjects.begin(), _loadedObjects.end(), [object](auto& obj) { return obj.get() == object; });
if (it != _loadedObjects.end()) if (it != _loadedObjects.end())
{ {
result = std::distance(_loadedObjects.begin(), it); result = std::distance(_loadedObjects.begin(), it);
@ -405,7 +403,7 @@ private:
return result; return result;
} }
void SetNewLoadedObjectList(std::vector<std::unique_ptr<Object>>&& newLoadedObjects) void SetNewLoadedObjectList(std::vector<Object*>&& newLoadedObjects)
{ {
if (newLoadedObjects.empty()) if (newLoadedObjects.empty())
{ {
@ -420,8 +418,9 @@ private:
void UnloadObject(Object* object) void UnloadObject(Object* object)
{ {
if (object != nullptr) if (object == nullptr)
{ return;
object->Unload(); object->Unload();
// TODO try to prevent doing a repository search // TODO try to prevent doing a repository search
@ -433,17 +432,10 @@ private:
// Because it's possible to have the same loaded object for multiple // Because it's possible to have the same loaded object for multiple
// slots, we have to make sure find and set all of them to nullptr // slots, we have to make sure find and set all of them to nullptr
for (auto& obj : _loadedObjects) std::replace(_loadedObjects.begin(), _loadedObjects.end(), object, static_cast<Object*>(nullptr));
{
if (obj.get() == object)
{
obj = nullptr;
}
}
}
} }
void UnloadObjectsExcept(const std::vector<std::unique_ptr<Object>>& newLoadedObjects) void UnloadObjectsExcept(const std::vector<Object*>& newLoadedObjects)
{ {
// Build a hash set for quick checking // Build a hash set for quick checking
auto exceptSet = std::unordered_set<Object*>(); auto exceptSet = std::unordered_set<Object*>();
@ -451,7 +443,7 @@ private:
{ {
if (object != nullptr) if (object != nullptr)
{ {
exceptSet.insert(object.get()); exceptSet.insert(object);
} }
} }
@ -463,9 +455,9 @@ private:
if (object != nullptr) if (object != nullptr)
{ {
totalObjectsLoaded++; totalObjectsLoaded++;
if (exceptSet.find(object.get()) == exceptSet.end()) if (exceptSet.find(object) == exceptSet.end())
{ {
UnloadObject(object.get()); UnloadObject(object);
numObjectsUnloaded++; numObjectsUnloaded++;
} }
} }
@ -478,43 +470,45 @@ private:
{ {
for (auto& loadedObject : _loadedObjects) for (auto& loadedObject : _loadedObjects)
{ {
if (loadedObject != nullptr) // The list can contain unused slots, skip them.
{ if (loadedObject == nullptr)
continue;
switch (loadedObject->GetObjectType()) switch (loadedObject->GetObjectType())
{ {
case ObjectType::SmallScenery: case ObjectType::SmallScenery:
{ {
auto* sceneryEntry = static_cast<SmallSceneryEntry*>(loadedObject->GetLegacyData()); auto* sceneryEntry = static_cast<SmallSceneryEntry*>(loadedObject->GetLegacyData());
sceneryEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject.get()); sceneryEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject);
break; break;
} }
case ObjectType::LargeScenery: case ObjectType::LargeScenery:
{ {
auto* sceneryEntry = static_cast<LargeSceneryEntry*>(loadedObject->GetLegacyData()); auto* sceneryEntry = static_cast<LargeSceneryEntry*>(loadedObject->GetLegacyData());
sceneryEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject.get()); sceneryEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject);
break; break;
} }
case ObjectType::Walls: case ObjectType::Walls:
{ {
auto* wallEntry = static_cast<WallSceneryEntry*>(loadedObject->GetLegacyData()); auto* wallEntry = static_cast<WallSceneryEntry*>(loadedObject->GetLegacyData());
wallEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject.get()); wallEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject);
break; break;
} }
case ObjectType::Banners: case ObjectType::Banners:
{ {
auto* bannerEntry = static_cast<BannerSceneryEntry*>(loadedObject->GetLegacyData()); auto* bannerEntry = static_cast<BannerSceneryEntry*>(loadedObject->GetLegacyData());
bannerEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject.get()); bannerEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject);
break; break;
} }
case ObjectType::PathBits: case ObjectType::PathBits:
{ {
auto* pathBitEntry = static_cast<PathBitEntry*>(loadedObject->GetLegacyData()); auto* pathBitEntry = static_cast<PathBitEntry*>(loadedObject->GetLegacyData());
pathBitEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject.get()); pathBitEntry->scenery_tab_id = GetPrimarySceneryGroupEntryIndex(loadedObject);
break; break;
} }
case ObjectType::SceneryGroup: case ObjectType::SceneryGroup:
{ {
auto sgObject = dynamic_cast<SceneryGroupObject*>(loadedObject.get()); auto sgObject = dynamic_cast<SceneryGroupObject*>(loadedObject);
sgObject->UpdateEntryIndexes(); sgObject->UpdateEntryIndexes();
break; break;
} }
@ -523,7 +517,6 @@ private:
break; break;
} }
} }
}
// HACK Scenery window will lose its tabs after changing the scenery group indexing // HACK Scenery window will lose its tabs after changing the scenery group indexing
// for now just close it, but it will be better to later tell it to invalidate the tabs // for now just close it, but it will be better to later tell it to invalidate the tabs
@ -583,7 +576,7 @@ private:
} }
else else
{ {
auto loadedObject = ori->LoadedObject; auto* loadedObject = ori->LoadedObject.get();
if (loadedObject == nullptr) if (loadedObject == nullptr)
{ {
auto object = _objectRepository.LoadObject(ori); auto object = _objectRepository.LoadObject(ori);
@ -651,61 +644,51 @@ private:
} }
} }
std::vector<std::unique_ptr<Object>> LoadObjects( std::vector<Object*> LoadObjects(std::vector<const ObjectRepositoryItem*>& requiredObjects, size_t* outNewObjectsLoaded)
std::vector<const ObjectRepositoryItem*>& requiredObjects, size_t* outNewObjectsLoaded)
{ {
std::vector<std::unique_ptr<Object>> objects; std::vector<Object*> objects;
std::vector<Object*> loadedObjects; std::vector<Object*> newLoadedObjects;
std::vector<rct_object_entry> badObjects; std::vector<rct_object_entry> badObjects;
objects.resize(OBJECT_ENTRY_COUNT); objects.resize(OBJECT_ENTRY_COUNT);
loadedObjects.reserve(OBJECT_ENTRY_COUNT); newLoadedObjects.reserve(OBJECT_ENTRY_COUNT);
// Read objects // Read objects
std::mutex commonMutex; std::mutex commonMutex;
ParallelFor(requiredObjects, [this, &commonMutex, requiredObjects, &objects, &badObjects, &loadedObjects](size_t i) { ParallelFor(requiredObjects, [this, &commonMutex, requiredObjects, &objects, &badObjects, &newLoadedObjects](size_t i) {
auto requiredObject = requiredObjects[i]; auto requiredObject = requiredObjects[i];
std::unique_ptr<Object> object; Object* object = nullptr;
if (requiredObject != nullptr) if (requiredObject != nullptr)
{ {
auto loadedObject = requiredObject->LoadedObject; auto loadedObject = requiredObject->LoadedObject.get();
if (loadedObject == nullptr) if (loadedObject == nullptr)
{ {
// Object requires to be loaded, if the object successfully loads it will register it // Object requires to be loaded, if the object successfully loads it will register it
// as a loaded object otherwise placed into the badObjects list. // as a loaded object otherwise placed into the badObjects list.
object = _objectRepository.LoadObject(requiredObject); auto newObject = _objectRepository.LoadObject(requiredObject);
std::lock_guard<std::mutex> guard(commonMutex); std::lock_guard<std::mutex> guard(commonMutex);
if (object == nullptr) if (newObject == nullptr)
{ {
badObjects.push_back(requiredObject->ObjectEntry); badObjects.push_back(requiredObject->ObjectEntry);
ReportObjectLoadProblem(&requiredObject->ObjectEntry); ReportObjectLoadProblem(&requiredObject->ObjectEntry);
} }
else else
{ {
loadedObjects.push_back(object.get()); object = newObject.get();
newLoadedObjects.push_back(object);
// Connect the ori to the registered object // Connect the ori to the registered object
_objectRepository.RegisterLoadedObject(requiredObject, object.get()); _objectRepository.RegisterLoadedObject(requiredObject, std::move(newObject));
} }
} }
else else
{ {
// The object is already loaded, given that the new list will be used as the next loaded object list, object = loadedObject;
// we can move the element out safely. This is required as the resulting list must contain all loaded
// objects and not just the newly loaded ones.
std::lock_guard<std::mutex> guard(commonMutex);
auto it = std::find_if(_loadedObjects.begin(), _loadedObjects.end(), [loadedObject](const auto& obj) {
return obj.get() == loadedObject;
});
if (it != _loadedObjects.end())
{
object = std::move(*it);
} }
} }
} objects[i] = object;
objects[i] = std::move(object);
}); });
// Load objects // Load objects
for (auto obj : loadedObjects) for (auto obj : newLoadedObjects)
{ {
obj->Load(); obj->Load();
} }
@ -713,7 +696,7 @@ private:
if (!badObjects.empty()) if (!badObjects.empty())
{ {
// Unload all the new objects we loaded // Unload all the new objects we loaded
for (auto object : loadedObjects) for (auto object : newLoadedObjects)
{ {
UnloadObject(object); UnloadObject(object);
} }
@ -722,28 +705,30 @@ private:
if (outNewObjectsLoaded != nullptr) if (outNewObjectsLoaded != nullptr)
{ {
*outNewObjectsLoaded = loadedObjects.size(); *outNewObjectsLoaded = newLoadedObjects.size();
} }
return objects; return objects;
} }
std::unique_ptr<Object> GetOrLoadObject(const ObjectRepositoryItem* ori) Object* GetOrLoadObject(const ObjectRepositoryItem* ori)
{
std::unique_ptr<Object> object;
auto loadedObject = ori->LoadedObject;
if (loadedObject == nullptr)
{ {
auto* loadedObject = ori->LoadedObject.get();
if (loadedObject != nullptr)
return loadedObject;
// Try to load object // Try to load object
object = _objectRepository.LoadObject(ori); auto object = _objectRepository.LoadObject(ori);
if (object != nullptr) if (object != nullptr)
{ {
loadedObject = object.get();
object->Load(); object->Load();
// Connect the ori to the registered object // Connect the ori to the registered object
_objectRepository.RegisterLoadedObject(ori, object.get()); _objectRepository.RegisterLoadedObject(ori, std::move(object));
} }
}
return object; return loadedObject;
} }
void ResetTypeToRideEntryIndexMap() void ResetTypeToRideEntryIndexMap()

View File

@ -266,18 +266,18 @@ public:
} }
} }
void RegisterLoadedObject(const ObjectRepositoryItem* ori, Object* object) override void RegisterLoadedObject(const ObjectRepositoryItem* ori, std::unique_ptr<Object>&& object) override
{ {
ObjectRepositoryItem* item = &_items[ori->Id]; ObjectRepositoryItem* item = &_items[ori->Id];
Guard::Assert(item->LoadedObject == nullptr, GUARD_LINE); Guard::Assert(item->LoadedObject == nullptr, GUARD_LINE);
item->LoadedObject = object; item->LoadedObject = std::move(object);
} }
void UnregisterLoadedObject(const ObjectRepositoryItem* ori, Object* object) override void UnregisterLoadedObject(const ObjectRepositoryItem* ori, Object* object) override
{ {
ObjectRepositoryItem* item = &_items[ori->Id]; ObjectRepositoryItem* item = &_items[ori->Id];
if (item->LoadedObject == object) if (item->LoadedObject.get() == object)
{ {
item->LoadedObject = nullptr; item->LoadedObject = nullptr;
} }

View File

@ -43,7 +43,7 @@ struct ObjectRepositoryItem
std::string Name; std::string Name;
std::vector<std::string> Authors; std::vector<std::string> Authors;
std::vector<ObjectSourceGame> Sources; std::vector<ObjectSourceGame> Sources;
Object* LoadedObject{}; std::shared_ptr<Object> LoadedObject{};
struct struct
{ {
uint8_t RideFlags; uint8_t RideFlags;
@ -82,7 +82,7 @@ struct IObjectRepository
[[nodiscard]] virtual const ObjectRepositoryItem* FindObject(const ObjectEntryDescriptor& oed) const abstract; [[nodiscard]] virtual const ObjectRepositoryItem* FindObject(const ObjectEntryDescriptor& oed) const abstract;
[[nodiscard]] virtual std::unique_ptr<Object> LoadObject(const ObjectRepositoryItem* ori) abstract; [[nodiscard]] virtual std::unique_ptr<Object> LoadObject(const ObjectRepositoryItem* ori) abstract;
virtual void RegisterLoadedObject(const ObjectRepositoryItem* ori, Object* object) abstract; virtual void RegisterLoadedObject(const ObjectRepositoryItem* ori, std::unique_ptr<Object>&& object) abstract;
virtual void UnregisterLoadedObject(const ObjectRepositoryItem* ori, Object* object) abstract; virtual void UnregisterLoadedObject(const ObjectRepositoryItem* ori, Object* object) abstract;
virtual void AddObject(const rct_object_entry* objectEntry, const void* data, size_t dataSize) abstract; virtual void AddObject(const rct_object_entry* objectEntry, const void* data, size_t dataSize) abstract;

View File

@ -81,7 +81,7 @@ void SceneryGroupObject::UpdateEntryIndexes()
if (ori->LoadedObject == nullptr) if (ori->LoadedObject == nullptr)
continue; continue;
auto entryIndex = objectManager.GetLoadedObjectEntryIndex(ori->LoadedObject); auto entryIndex = objectManager.GetLoadedObjectEntryIndex(ori->LoadedObject.get());
Guard::Assert(entryIndex != OBJECT_ENTRY_INDEX_NULL, GUARD_LINE); Guard::Assert(entryIndex != OBJECT_ENTRY_INDEX_NULL, GUARD_LINE);
auto sceneryType = ori->ObjectEntry.GetSceneryType(); auto sceneryType = ori->ObjectEntry.GetSceneryType();