From cd8dd2bafad5acaa01694d7c4e2b44c761290225 Mon Sep 17 00:00:00 2001 From: maciejmajek Date: Fri, 19 Dec 2025 15:05:31 +0100 Subject: [PATCH 1/7] feat: allow passing ros2 types in send_message --- src/rai_core/pyproject.toml | 2 +- .../rai/communication/ros2/api/base.py | 19 +++++- .../rai/communication/ros2/api/topic.py | 14 ++++- .../rai/communication/ros2/connectors/base.py | 31 ++++++--- tests/communication/ros2/test_api.py | 63 ++++++++++++++++--- tests/communication/ros2/test_connectors.py | 58 ++++++++++++++--- 6 files changed, 155 insertions(+), 32 deletions(-) diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index 3be0b26c8..3f94640c5 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "rai_core" -version = "2.6.0" +version = "2.7.0" description = "Core functionality for RAI framework" authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] readme = "README.md" diff --git a/src/rai_core/rai/communication/ros2/api/base.py b/src/rai_core/rai/communication/ros2/api/base.py index 7692f5cbd..b138c7d10 100644 --- a/src/rai_core/rai/communication/ros2/api/base.py +++ b/src/rai_core/rai/communication/ros2/api/base.py @@ -36,7 +36,12 @@ from rclpy.topic_endpoint_info import TopicEndpointInfo from rosidl_parser.definition import NamespacedType from rosidl_runtime_py.import_message import import_message_from_namespaced_type -from rosidl_runtime_py.utilities import get_namespaced_type +from rosidl_runtime_py.utilities import ( + get_namespaced_type, + is_action, + is_message, + is_service, +) from rai.communication.ros2.api.conversion import import_message_from_str @@ -134,3 +139,15 @@ def get_topic_type(self, topic: str) -> str: raise ValueError(f"Topic {topic} has multiple types: {types}") return types[0] raise ValueError(f"Topic {topic} not found") + + @staticmethod + def is_ros2_message(msg: Any) -> bool: + return is_message(msg) + + @staticmethod + def is_ros2_service(msg: Any) -> bool: + return is_service(msg) + + @staticmethod + def is_ros2_action(msg: Any) -> bool: + return is_action(msg) diff --git a/src/rai_core/rai/communication/ros2/api/topic.py b/src/rai_core/rai/communication/ros2/api/topic.py index cb55002f6..88ab01a09 100644 --- a/src/rai_core/rai/communication/ros2/api/topic.py +++ b/src/rai_core/rai/communication/ros2/api/topic.py @@ -34,6 +34,7 @@ from rai.communication.ros2.api.base import ( BaseROS2API, + IROS2Message, ) from rai.communication.ros2.api.conversion import import_message_from_str @@ -140,8 +141,8 @@ def get_topic_names_and_types( def publish( self, topic: str, - msg_content: Dict[str, Any], - msg_type: str, + msg_content: IROS2Message | Dict[str, Any], + msg_type: str | None = None, *, auto_qos_matching: bool = True, qos_profile: Optional[QoSProfile] = None, @@ -162,7 +163,14 @@ def publish( topic, auto_qos_matching, qos_profile, for_publisher=True ) - msg = self.build_ros2_msg(msg_type, msg_content) + if self.is_ros2_message(msg_content): + msg = msg_content + elif isinstance(msg_content, dict) and msg_type is not None: + msg = self.build_ros2_msg(msg_type, msg_content) + elif isinstance(msg_content, dict) and msg_type is None: + raise ValueError("msg_type must be provided if msg_content is a dict") + else: + raise ValueError(f"Invalid message content type: {type(msg_content)}") publisher = self._get_or_create_publisher(topic, type(msg), qos_profile) publisher.publish(msg) diff --git a/src/rai_core/rai/communication/ros2/connectors/base.py b/src/rai_core/rai/communication/ros2/connectors/base.py index 2bf5145a5..aeed2aac6 100644 --- a/src/rai_core/rai/communication/ros2/connectors/base.py +++ b/src/rai_core/rai/communication/ros2/connectors/base.py @@ -16,7 +16,17 @@ import time import uuid from functools import partial -from typing import Any, Callable, Dict, Final, List, Literal, Optional, Tuple, TypeVar +from typing import ( + Any, + Callable, + Dict, + Final, + List, + Literal, + Optional, + Tuple, + TypeVar, +) import rclpy import rclpy.executors @@ -31,6 +41,7 @@ from rai.communication import BaseConnector from rai.communication.ros2.api import ( + IROS2Message, ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI, @@ -240,10 +251,10 @@ def get_actions_names_and_types(self) -> List[Tuple[str, List[str]]]: def send_message( self, - message: T, + message: T | IROS2Message, target: str, *, - msg_type: str, + msg_type: str | None = None, auto_qos_matching: bool = True, qos_profile: Optional[QoSProfile] = None, **kwargs: Any, @@ -252,12 +263,12 @@ def send_message( Parameters ---------- - message : T - The message to send. + message : T | IROS2Message + The message to send. Can be a subclass of ROS2Message (payload is a dict) or any ROS2 message. target : str The target topic name. - msg_type : str - The ROS2 message type. + msg_type : str | None, optional + The ROS2 message type. If None, the message type will be inferred from the message content. Must be provided if msg_content is a ROS2Message subclass. auto_qos_matching : bool, optional Whether to automatically match QoS profiles, by default True. qos_profile : Optional[QoSProfile], optional @@ -265,9 +276,13 @@ def send_message( **kwargs : Any Additional keyword arguments. """ + if isinstance(message, ROS2Message): # T class + msg_content = message.payload + else: # An actual ROS 2 message + msg_content = message self._topic_api.publish( topic=target, - msg_content=message.payload, + msg_content=msg_content, msg_type=msg_type, auto_qos_matching=auto_qos_matching, qos_profile=qos_profile, diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py index 64974fb76..9765371ea 100644 --- a/tests/communication/ros2/test_api.py +++ b/tests/communication/ros2/test_api.py @@ -22,6 +22,13 @@ import pytest from action_msgs.msg import GoalStatus from action_msgs.srv import CancelGoal +from geometry_msgs.msg import ( + Point, + Pose, + PoseWithCovariance, + PoseWithCovarianceStamped, + Quaternion, +) from nav2_msgs.action import NavigateToPose from rai.communication.ros2.api import ( ROS2ActionAPI, @@ -35,6 +42,7 @@ ) from rclpy.executors import MultiThreadedExecutor from rclpy.node import Node +from std_msgs.msg import Header, String from std_srvs.srv import SetBool from .helpers import ( @@ -51,12 +59,39 @@ _ = ros_setup # Explicitly use the fixture to prevent pytest warnings +@pytest.mark.parametrize( + "message_content,msg_type,actual_type", + [ + ({"data": "Hello, ROS2!"}, "std_msgs/msg/String", String), + (String(data="Hello, ROS2!"), None, String), + (String(), None, String), + (Pose(), None, Pose), + (PoseWithCovarianceStamped(), None, PoseWithCovarianceStamped), + ( + PoseWithCovarianceStamped( + header=Header(), + pose=PoseWithCovariance( + pose=Pose( + position=Point(x=1.0, y=2.0, z=3.0), + orientation=Quaternion(x=0.1, y=0.2, z=0.3, w=0.4), + ) + ), + ), + None, + PoseWithCovarianceStamped, + ), + ], +) def test_ros2_single_message_publish( - ros_setup: None, request: pytest.FixtureRequest + ros_setup: None, + request: pytest.FixtureRequest, + message_content: Any, + msg_type: str | None, + actual_type: type, ) -> None: topic_name = f"{request.node.originalname}_topic" # type: ignore node_name = f"{request.node.originalname}_node" # type: ignore - message_receiver = MessageSubscriber(topic_name) + message_receiver = MessageSubscriber(topic_name, actual_type) node = Node(node_name) executors, threads = multi_threaded_spinner([message_receiver, node]) @@ -64,12 +99,12 @@ def test_ros2_single_message_publish( topic_api = ROS2TopicAPI(node) topic_api.publish( topic_name, - {"data": "Hello, ROS2!"}, - msg_type="std_msgs/msg/String", + message_content, + msg_type=msg_type, ) - time.sleep(1) + time.sleep(0.1) assert len(message_receiver.received_messages) == 1 - assert message_receiver.received_messages[0].data == "Hello, ROS2!" + assert isinstance(message_receiver.received_messages[0], actual_type) finally: shutdown_executors_and_threads(executors, threads) @@ -116,8 +151,18 @@ def test_ros2_single_message_publish_wrong_msg_content( shutdown_executors_and_threads(executors, threads) +@pytest.mark.parametrize( + "message_content,msg_type", + [ + ({"data": "Hello, ROS2!"}, "std_msgs/msg/String"), + (String(data="Hello, ROS2!"), None), + ], +) def test_ros2_single_message_publish_wrong_qos_setup( - ros_setup: None, request: pytest.FixtureRequest + ros_setup: None, + request: pytest.FixtureRequest, + message_content: Any, + msg_type: str | None, ) -> None: topic_name = f"{request.node.originalname}_topic" # type: ignore node_name = f"{request.node.originalname}_node" # type: ignore @@ -130,8 +175,8 @@ def test_ros2_single_message_publish_wrong_qos_setup( with pytest.raises(ValueError): topic_api.publish( topic_name, - {"data": "Hello, ROS2!"}, - msg_type="std_msgs/msg/String", + message_content, + msg_type=msg_type, auto_qos_matching=False, qos_profile=None, ) diff --git a/tests/communication/ros2/test_connectors.py b/tests/communication/ros2/test_connectors.py index 346a318bd..b37867972 100644 --- a/tests/communication/ros2/test_connectors.py +++ b/tests/communication/ros2/test_connectors.py @@ -19,6 +19,14 @@ from unittest.mock import MagicMock import pytest +from builtin_interfaces.msg import Time +from geometry_msgs.msg import ( + Point, + Pose, + PoseWithCovariance, + PoseWithCovarianceStamped, + Quaternion, +) from nav2_msgs.action import NavigateToPose from PIL import Image from pydub import AudioSegment @@ -33,7 +41,7 @@ MutuallyExclusiveCallbackGroup, ReentrantCallbackGroup, ) -from std_msgs.msg import String +from std_msgs.msg import Header, String from std_srvs.srv import SetBool from .helpers import ( @@ -52,21 +60,51 @@ _ = ros_setup # Explicitly use the fixture to prevent pytest warnings -def test_ros2_connector_send_message(ros_setup: None, request: pytest.FixtureRequest): +@pytest.mark.parametrize( + "message_content,msg_type,actual_type", + [ + (ROS2Message(payload={"data": "Hello, ROS2!"}), "std_msgs/msg/String", String), + (String(data="Hello, ROS2!"), None, String), + (String(), None, String), + (Pose(), None, Pose), + (PoseWithCovarianceStamped(), None, PoseWithCovarianceStamped), + ( + PoseWithCovarianceStamped( + header=Header( + stamp=Time(sec=1, nanosec=100000000), + frame_id="test_frame", + ), + pose=PoseWithCovariance( + pose=Pose( + position=Point(x=1.0, y=2.0, z=3.0), + orientation=Quaternion(x=0.1, y=0.2, z=0.3, w=0.4), + ), + covariance=[0.0] * 36, + ), + ), + None, + PoseWithCovarianceStamped, + ), + ], +) +def test_ros2_connector_send_message( + ros_setup: None, + request: pytest.FixtureRequest, + message_content: ROS2Message, + msg_type: str | None, + actual_type: type, +): topic_name = f"{request.node.originalname}_topic" # type: ignore - message_receiver = MessageSubscriber(topic_name) + message_receiver = MessageSubscriber(topic_name, actual_type) executors, threads = multi_threaded_spinner([message_receiver]) connector = ROS2Connector() try: - message = ROS2Message( - payload={"data": "Hello, world!"}, - metadata={"msg_type": "std_msgs/msg/String"}, - ) connector.send_message( - message=message, target=topic_name, msg_type="std_msgs/msg/String" + message=message_content, target=topic_name, msg_type=msg_type ) - time.sleep(1) # wait for the message to be received - assert message_receiver.received_messages == [String(data="Hello, world!")] + time.sleep(0.1) # wait for the message to be received + assert len(message_receiver.received_messages) == 1 + assert isinstance(message_receiver.received_messages[0], actual_type) finally: connector.shutdown() shutdown_executors_and_threads(executors, threads) From fad9c01b7c5a78a2e1f539a89f6170a69e501402 Mon Sep 17 00:00:00 2001 From: maciejmajek Date: Fri, 19 Dec 2025 15:28:37 +0100 Subject: [PATCH 2/7] test: extend test suite --- tests/communication/ros2/test_api.py | 114 +++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py index 9765371ea..8a9747e08 100644 --- a/tests/communication/ros2/test_api.py +++ b/tests/communication/ros2/test_api.py @@ -35,6 +35,7 @@ ROS2ServiceAPI, ROS2TopicAPI, ) +from rai.communication.ros2.api.base import BaseROS2API from rclpy.callback_groups import ( CallbackGroup, MutuallyExclusiveCallbackGroup, @@ -59,6 +60,63 @@ _ = ros_setup # Explicitly use the fixture to prevent pytest warnings +@pytest.mark.parametrize( + "entity,is_message,is_service,is_action", + [ + ({"data": "Hello, ROS2!"}, False, False, False), + ({}, False, False, False), + ("", False, False, False), + ("data: Hello, ROS2!", False, False, False), + (None, False, False, False), + (String(), True, False, False), + (Pose(), True, False, False), + (PoseWithCovarianceStamped(), True, False, False), + ( + PoseWithCovarianceStamped( + header=Header(), + pose=PoseWithCovariance( + pose=Pose( + position=Point(x=1.0, y=2.0, z=3.0), + orientation=Quaternion(x=0.1, y=0.2, z=0.3, w=0.4), + ) + ), + ), + True, + False, + False, + ), + (SetBool.Request(data=True), True, False, False), + ( + SetBool.Response(success=True, message="Test service called"), + True, + False, + False, + ), + (SetBool, False, True, False), + ( + NavigateToPose.Goal( + pose=Pose( + position=Point(x=1.0, y=2.0, z=3.0), + orientation=Quaternion(x=0.1, y=0.2, z=0.3, w=0.4), + ) + ), + True, + False, + False, + ), + (NavigateToPose.Result(success=True), True, False, False), + (NavigateToPose.Feedback(feedback="Test feedback"), True, False, False), + (NavigateToPose, False, False, True), + ], +) +def test_is_message_type( + ros_setup: None, entity: Any, is_message: bool, is_service: bool, is_action: bool +) -> None: + assert is_message == BaseROS2API.is_ros2_message(entity) + assert is_service == BaseROS2API.is_ros2_service(entity) + assert is_action == BaseROS2API.is_ros2_action(entity) + + @pytest.mark.parametrize( "message_content,msg_type,actual_type", [ @@ -184,6 +242,62 @@ def test_ros2_single_message_publish_wrong_qos_setup( shutdown_executors_and_threads(executors, threads) +def test_ros2_single_message_dict_no_type( + ros_setup: None, request: pytest.FixtureRequest +) -> None: + topic_name = f"{request.node.originalname}_topic" # type: ignore + node_name = f"{request.node.originalname}_node" # type: ignore + message_receiver = MessageSubscriber(topic_name) + node = Node(node_name) + executors, threads = multi_threaded_spinner([message_receiver, node]) + + try: + topic_api = ROS2TopicAPI(node) + with pytest.raises(ValueError): + topic_api.publish( + topic_name, + {"data": "Hello, ROS2!"}, + msg_type=None, + ) + finally: + shutdown_executors_and_threads(executors, threads) + + +@pytest.mark.parametrize( + "message_content,msg_type", + [ + ((), "std_msgs/msg/String"), + ((), None), + (None, "std_msgs/msg/String"), + (None, None), + ("data: Hello, ROS2!", "std_msgs/msg/String"), + ("data: Hello, ROS2!", None), + ], +) +def test_ros2_single_message_invalid_type( + ros_setup: None, + request: pytest.FixtureRequest, + message_content: Any, + msg_type: str | None, +) -> None: + topic_name = f"{request.node.originalname}_topic" # type: ignore + node_name = f"{request.node.originalname}_node" # type: ignore + message_receiver = MessageSubscriber(topic_name) + node = Node(node_name) + executors, threads = multi_threaded_spinner([message_receiver, node]) + + try: + topic_api = ROS2TopicAPI(node) + with pytest.raises(ValueError): + topic_api.publish( + topic_name, + message_content, + msg_type=msg_type, + ) + finally: + shutdown_executors_and_threads(executors, threads) + + def invoke_set_bool_service( service_name: str, service_api: ROS2ServiceAPI, reuse_client: bool = True ): From 6807cf19773eed57304914f33c77e66b426854ac Mon Sep 17 00:00:00 2001 From: maciejmajek Date: Fri, 19 Dec 2025 15:32:28 +0100 Subject: [PATCH 3/7] fix: tests on humble --- tests/communication/ros2/test_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py index 8a9747e08..1d3d1b74b 100644 --- a/tests/communication/ros2/test_api.py +++ b/tests/communication/ros2/test_api.py @@ -25,6 +25,7 @@ from geometry_msgs.msg import ( Point, Pose, + PoseStamped, PoseWithCovariance, PoseWithCovarianceStamped, Quaternion, @@ -95,7 +96,7 @@ (SetBool, False, True, False), ( NavigateToPose.Goal( - pose=Pose( + pose=PoseStamped( position=Point(x=1.0, y=2.0, z=3.0), orientation=Quaternion(x=0.1, y=0.2, z=0.3, w=0.4), ) From bfd7cc5627fda579fc5b28847d3a9c979143e8a2 Mon Sep 17 00:00:00 2001 From: maciejmajek Date: Fri, 19 Dec 2025 15:43:40 +0100 Subject: [PATCH 4/7] fix: tests on humble --- tests/communication/ros2/test_api.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py index 1d3d1b74b..35322c909 100644 --- a/tests/communication/ros2/test_api.py +++ b/tests/communication/ros2/test_api.py @@ -97,8 +97,11 @@ ( NavigateToPose.Goal( pose=PoseStamped( - position=Point(x=1.0, y=2.0, z=3.0), - orientation=Quaternion(x=0.1, y=0.2, z=0.3, w=0.4), + header=Header(), + pose=Pose( + position=Point(x=1.0, y=2.0, z=3.0), + orientation=Quaternion(x=0.1, y=0.2, z=0.3, w=0.4), + ), ) ), True, From e9bd7c87b573ee5e2291eb8439a598286b4f2493 Mon Sep 17 00:00:00 2001 From: maciejmajek Date: Fri, 19 Dec 2025 15:46:38 +0100 Subject: [PATCH 5/7] fix: tests on humble --- tests/communication/ros2/test_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py index 35322c909..b2b7d7d21 100644 --- a/tests/communication/ros2/test_api.py +++ b/tests/communication/ros2/test_api.py @@ -108,8 +108,8 @@ False, False, ), - (NavigateToPose.Result(success=True), True, False, False), - (NavigateToPose.Feedback(feedback="Test feedback"), True, False, False), + (NavigateToPose.Result(), True, False, False), + (NavigateToPose.Feedback(), True, False, False), (NavigateToPose, False, False, True), ], ) From 1e5a46f80389774cc0d9322201df11c6c0b4a225 Mon Sep 17 00:00:00 2001 From: maciejmajek Date: Fri, 19 Dec 2025 15:47:05 +0100 Subject: [PATCH 6/7] docs: update --- docs/API_documentation/connectors/ROS_2_Connectors.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/API_documentation/connectors/ROS_2_Connectors.md b/docs/API_documentation/connectors/ROS_2_Connectors.md index e4c8c7969..22daaa3bc 100644 --- a/docs/API_documentation/connectors/ROS_2_Connectors.md +++ b/docs/API_documentation/connectors/ROS_2_Connectors.md @@ -34,12 +34,19 @@ The `ROS2Connector` is the main interface for publishing, subscribing, and calli ```python from rai.communication.ros2.connectors import ROS2Connector +from std_msgs.msg import String connector = ROS2Connector() -# Send a message to a topic +# Send a raw ROS 2 message (msg_type is inferred) connector.send_message( - message=my_msg, # ROS2Message + message=String(data="Hello"), + target="/my_topic" +) + +# Send a message using a dictionary (msg_type is required) +connector.send_message( + message={"data": "Hello"}, target="/my_topic", msg_type="std_msgs/msg/String" ) From 864f99c2bf6336a2fd428f32fb9c70374e3d96f0 Mon Sep 17 00:00:00 2001 From: maciejmajek Date: Fri, 19 Dec 2025 15:48:30 +0100 Subject: [PATCH 7/7] docs: update --- docs/API_documentation/connectors/ROS_2_Connectors.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/API_documentation/connectors/ROS_2_Connectors.md b/docs/API_documentation/connectors/ROS_2_Connectors.md index 22daaa3bc..6579d3a0a 100644 --- a/docs/API_documentation/connectors/ROS_2_Connectors.md +++ b/docs/API_documentation/connectors/ROS_2_Connectors.md @@ -33,7 +33,7 @@ The `ROS2Connector` is the main interface for publishing, subscribing, and calli ### Example Usage ```python -from rai.communication.ros2.connectors import ROS2Connector +from rai.communication.ros2.connectors import ROS2Connector, ROS2Message from std_msgs.msg import String connector = ROS2Connector() @@ -46,7 +46,7 @@ connector.send_message( # Send a message using a dictionary (msg_type is required) connector.send_message( - message={"data": "Hello"}, + message=ROS2Message(payload={"data": "Hello"}), target="/my_topic", msg_type="std_msgs/msg/String" )