Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions src/adf_core_python/core/agent/info/world_info.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Optional
from typing import Any
from warnings import deprecated

from rcrscore.entities import EntityID
from rcrscore.entities.area import Area
Expand Down Expand Up @@ -39,7 +40,7 @@ def set_change_set(self, change_set: ChangeSet) -> None:
"""
self._change_set = change_set

def get_entity(self, entity_id: EntityID) -> Optional[Entity]:
def get_entity(self, entity_id: EntityID) -> Entity | None:
"""
Get the entity

Expand All @@ -50,7 +51,7 @@ def get_entity(self, entity_id: EntityID) -> Optional[Entity]:

Returns
-------
Optional[Entity]
Entity | None
Entity
"""
return self._world_model.get_entity(entity_id)
Expand Down Expand Up @@ -179,8 +180,8 @@ def get_distance(self, entity_id1: EntityID, entity_id2: EntityID) -> float:
ValueError
If one or both entities are invalid or the location is invalid
"""
entity1: Optional[Entity] = self.get_entity(entity_id1)
entity2: Optional[Entity] = self.get_entity(entity_id2)
entity1: Entity | None = self.get_entity(entity_id1)
entity2: Entity | None = self.get_entity(entity_id2)
if entity1 is None or entity2 is None:
raise ValueError(
f"One or both entities are invalid: entity_id1={entity_id1}, entity_id2={entity_id2}, entity1={entity1}, entity2={entity2}"
Expand All @@ -204,6 +205,9 @@ def get_distance(self, entity_id1: EntityID, entity_id2: EntityID) -> float:

return distance

@deprecated(
"get_entity_position is deprecated, use get_entity_position_entity_id or get_entity_position_entity instead."
)
def get_entity_position(self, entity_id: EntityID) -> EntityID | None:
"""
Get the entity position
Expand All @@ -215,7 +219,7 @@ def get_entity_position(self, entity_id: EntityID) -> EntityID | None:

Returns
-------
EntityID
EntityID | None
Entity position

Raises
Expand All @@ -234,6 +238,50 @@ def get_entity_position(self, entity_id: EntityID) -> EntityID | None:
return entity.get_position()
raise ValueError(f"Invalid entity type: entity_id={entity_id}, entity={entity}")

def get_entity_position_entity_id(self, entity_id: EntityID) -> EntityID | None:
"""
Get the entity position EntityID

Parameters
----------
entity_id : EntityID
Entity ID

Returns
-------
EntityID | None
Entity position EntityID
"""
entity = self.get_entity(entity_id)
if entity is None:
return None
if isinstance(entity, Area):
return entity.get_entity_id()
if isinstance(entity, Human):
return entity.get_position()
if isinstance(entity, Blockade):
return entity.get_position()
return None

def get_entity_position_entity(self, entity_id: EntityID) -> Entity | None:
"""
Get the entity position Entity

Parameters
----------
entity_id : EntityID
Entity ID

Returns
-------
Entity | None
Entity position Entity
"""
position_entity_id = self.get_entity_position_entity_id(entity_id)
if position_entity_id is None:
return None
return self.get_entity(position_entity_id)

def get_change_set(self) -> ChangeSet:
"""
Get the change set
Expand Down
Loading