1515import threading
1616import time
1717import uuid
18- from typing import Any , Callable , Dict , List , Literal , Optional , Tuple
18+ from collections import OrderedDict
19+ from typing import Any , Callable , Dict , List , Literal , Optional , Tuple , Union , cast
1920
21+ import numpy as np
2022import rclpy
2123import rclpy .executors
2224import rclpy .node
2325import rclpy .time
26+ import rosidl_runtime_py .convert
27+ from cv_bridge import CvBridge
28+ from PIL import Image
29+ from pydub import AudioSegment
2430from rclpy .duration import Duration
2531from rclpy .executors import MultiThreadedExecutor
2632from rclpy .node import Node
2733from rclpy .qos import QoSProfile
34+ from sensor_msgs .msg import Image as ROS2Image
2835from tf2_ros import Buffer , LookupException , TransformListener , TransformStamped
2936
37+ import rai_interfaces .msg
3038from rai .communication import (
3139 ARIConnector ,
3240 ARIMessage ,
4149 ROS2TopicAPI ,
4250 TopicConfig ,
4351)
52+ from rai_interfaces .msg import HRIMessage as ROS2HRIMessage_
53+ from rai_interfaces .msg ._audio_message import (
54+ AudioMessage as ROS2HRIMessage__Audio ,
55+ )
4456
4557
4658class ROS2ARIMessage (ARIMessage ):
@@ -200,26 +212,95 @@ class ROS2HRIMessage(HRIMessage):
200212 def __init__ (self , payload : HRIPayload , message_author : Literal ["ai" , "human" ]):
201213 super ().__init__ (payload , message_author )
202214
215+ @classmethod
216+ def from_ros2 (
217+ cls , msg : rai_interfaces .msg .HRIMessage , message_author : Literal ["ai" , "human" ]
218+ ):
219+ cv_bridge = CvBridge ()
220+ images = [
221+ cv_bridge .imgmsg_to_cv2 (img_msg , "rgb8" )
222+ for img_msg in cast (List [ROS2Image ], msg .images )
223+ ]
224+ pil_images = [Image .fromarray (img ) for img in images ]
225+ audio_segments = [
226+ AudioSegment (
227+ data = audio_msg .audio ,
228+ frame_rate = audio_msg .sample_rate ,
229+ sample_width = 2 , # bytes, int16
230+ channels = audio_msg .channels ,
231+ )
232+ for audio_msg in msg .audios
233+ ]
234+ return ROS2HRIMessage (
235+ payload = HRIPayload (text = msg .text , images = pil_images , audios = audio_segments ),
236+ message_author = message_author ,
237+ )
238+
239+ def to_ros2_dict (self ) -> OrderedDict [str , Any ]:
240+ cv_bridge = CvBridge ()
241+ assert isinstance (self .payload , HRIPayload )
242+ img_msgs = [
243+ cv_bridge .cv2_to_imgmsg (np .array (img ), "rgb8" )
244+ for img in self .payload .images
245+ ]
246+ audio_msgs = [
247+ ROS2HRIMessage__Audio (
248+ audio = audio .raw_data ,
249+ sample_rate = audio .frame_rate ,
250+ channels = audio .channels ,
251+ )
252+ for audio in self .payload .audios
253+ ]
254+
255+ return cast (
256+ OrderedDict [str , Any ],
257+ rosidl_runtime_py .convert .message_to_ordereddict (
258+ ROS2HRIMessage_ (
259+ text = self .payload .text ,
260+ images = img_msgs ,
261+ audios = audio_msgs ,
262+ )
263+ ),
264+ )
265+
203266
204267class ROS2HRIConnector (HRIConnector [ROS2HRIMessage ]):
205268 def __init__ (
206269 self ,
207270 node_name : str = f"rai_ros2_hri_connector_{ str (uuid .uuid4 ())[- 12 :]} " ,
208- targets : List [Tuple [str , TopicConfig ]] = [],
209- sources : List [Tuple [str , TopicConfig ]] = [],
271+ targets : List [Union [ str , Tuple [str , TopicConfig ] ]] = [],
272+ sources : List [Union [ str , Tuple [str , TopicConfig ] ]] = [],
210273 ):
211- configured_targets = [target [0 ] for target in targets ]
212- configured_sources = [source [0 ] for source in sources ]
274+ configured_targets = [
275+ target [0 ] if isinstance (target , tuple ) else target for target in targets
276+ ]
277+ configured_sources = [
278+ source [0 ] if isinstance (source , tuple ) else source for source in sources
279+ ]
213280
214- self ._configure_publishers (targets )
215- self ._configure_subscribers (sources )
281+ _targets = [
282+ target
283+ if isinstance (target , tuple )
284+ else (target , TopicConfig (is_subscriber = False ))
285+ for target in targets
286+ ]
287+ _sources = [
288+ source
289+ if isinstance (source , tuple )
290+ else (source , TopicConfig (is_subscriber = True ))
291+ for source in sources
292+ ]
216293
217- super ().__init__ (configured_targets , configured_sources )
218294 self ._node = Node (node_name )
219295 self ._topic_api = ConfigurableROS2TopicAPI (self ._node )
220296 self ._service_api = ROS2ServiceAPI (self ._node )
221297 self ._actions_api = ROS2ActionAPI (self ._node )
222298
299+ self ._configure_publishers (_targets )
300+ self ._configure_subscribers (_sources )
301+
302+ super ().__init__ (configured_targets , configured_sources )
303+
223304 self ._executor = MultiThreadedExecutor ()
224305 self ._executor .add_node (self ._node )
225306 self ._thread = threading .Thread (target = self ._executor .spin )
@@ -236,7 +317,7 @@ def _configure_subscribers(self, sources: List[Tuple[str, TopicConfig]]):
236317 def send_message (self , message : ROS2HRIMessage , target : str , ** kwargs ):
237318 self ._topic_api .publish_configured (
238319 topic = target ,
239- msg_content = message .payload ,
320+ msg_content = message .to_ros2_dict () ,
240321 )
241322
242323 def receive_message (
@@ -249,16 +330,12 @@ def receive_message(
249330 auto_topic_type : bool = True ,
250331 ** kwargs : Any ,
251332 ) -> ROS2HRIMessage :
252- if msg_type != "std_msgs/msg/String" :
253- raise ValueError ("ROS2HRIConnector only supports receiving sting messages" )
254333 msg = self ._topic_api .receive (
255334 topic = source ,
256335 timeout_sec = timeout_sec ,
257- msg_type = msg_type ,
258336 auto_topic_type = auto_topic_type ,
259337 )
260- payload = HRIPayload (msg .data )
261- return ROS2HRIMessage (payload = payload , message_author = message_author )
338+ return ROS2HRIMessage .from_ros2 (msg , message_author )
262339
263340 def service_call (
264341 self , message : ROS2HRIMessage , target : str , timeout_sec : float , ** kwargs : Any
@@ -284,3 +361,10 @@ def terminate_action(self, action_handle: str, **kwargs: Any):
284361 raise NotImplementedError (
285362 f"{ self .__class__ .__name__ } doesn't support action calls"
286363 )
364+
365+ def shutdown (self ):
366+ self ._executor .shutdown ()
367+ self ._thread .join ()
368+ self ._actions_api .shutdown ()
369+ self ._topic_api .shutdown ()
370+ self ._node .destroy_node ()
0 commit comments