diff --git a/src/adf_core_python/core/agent/info/world_info.py b/src/adf_core_python/core/agent/info/world_info.py index cc74d83..3106663 100644 --- a/src/adf_core_python/core/agent/info/world_info.py +++ b/src/adf_core_python/core/agent/info/world_info.py @@ -1,4 +1,5 @@ -from typing import Any, Optional +from typing import Any +from warnings import deprecated from rcrscore.entities import EntityID from rcrscore.entities.area import Area @@ -39,7 +40,7 @@ def set_change_set(self, change_set: ChangeSet) -> None: """ self._change_set = change_set - def get_entity(self, entity_id: EntityID) -> Optional[Entity]: + def get_entity(self, entity_id: EntityID) -> Entity | None: """ Get the entity @@ -50,7 +51,7 @@ def get_entity(self, entity_id: EntityID) -> Optional[Entity]: Returns ------- - Optional[Entity] + Entity | None Entity """ return self._world_model.get_entity(entity_id) @@ -179,8 +180,8 @@ def get_distance(self, entity_id1: EntityID, entity_id2: EntityID) -> float: ValueError If one or both entities are invalid or the location is invalid """ - entity1: Optional[Entity] = self.get_entity(entity_id1) - entity2: Optional[Entity] = self.get_entity(entity_id2) + entity1: Entity | None = self.get_entity(entity_id1) + entity2: Entity | None = self.get_entity(entity_id2) if entity1 is None or entity2 is None: raise ValueError( f"One or both entities are invalid: entity_id1={entity_id1}, entity_id2={entity_id2}, entity1={entity1}, entity2={entity2}" @@ -204,6 +205,9 @@ def get_distance(self, entity_id1: EntityID, entity_id2: EntityID) -> float: return distance + @deprecated( + "get_entity_position is deprecated, use get_entity_position_entity_id or get_entity_position_entity instead." + ) def get_entity_position(self, entity_id: EntityID) -> EntityID | None: """ Get the entity position @@ -215,7 +219,7 @@ def get_entity_position(self, entity_id: EntityID) -> EntityID | None: Returns ------- - EntityID + EntityID | None Entity position Raises @@ -234,6 +238,50 @@ def get_entity_position(self, entity_id: EntityID) -> EntityID | None: return entity.get_position() raise ValueError(f"Invalid entity type: entity_id={entity_id}, entity={entity}") + def get_entity_position_entity_id(self, entity_id: EntityID) -> EntityID | None: + """ + Get the entity position EntityID + + Parameters + ---------- + entity_id : EntityID + Entity ID + + Returns + ------- + EntityID | None + Entity position EntityID + """ + entity = self.get_entity(entity_id) + if entity is None: + return None + if isinstance(entity, Area): + return entity.get_entity_id() + if isinstance(entity, Human): + return entity.get_position() + if isinstance(entity, Blockade): + return entity.get_position() + return None + + def get_entity_position_entity(self, entity_id: EntityID) -> Entity | None: + """ + Get the entity position Entity + + Parameters + ---------- + entity_id : EntityID + Entity ID + + Returns + ------- + Entity | None + Entity position Entity + """ + position_entity_id = self.get_entity_position_entity_id(entity_id) + if position_entity_id is None: + return None + return self.get_entity(position_entity_id) + def get_change_set(self) -> ChangeSet: """ Get the change set