From 9f1f77457abe4910feaa1566df3f6e84acfc5f67 Mon Sep 17 00:00:00 2001 From: maciejmajek Date: Tue, 16 Dec 2025 21:35:38 +0100 Subject: [PATCH] feat: initial obs implementation --- src/rai_core/rai/agents/base.py | 41 +++++- .../rai/communication/base_connector.py | 43 ++++-- .../rai/communication/ros2/connectors/base.py | 7 +- src/rai_core/rai/observability/__init__.py | 27 ++++ src/rai_core/rai/observability/builder.py | 98 +++++++++++++ src/rai_core/rai/observability/meta.py | 121 +++++++++++++++ src/rai_core/rai/observability/sink.py | 138 ++++++++++++++++++ test_obs.py | 98 +++++++++++++ 8 files changed, 557 insertions(+), 16 deletions(-) create mode 100644 src/rai_core/rai/observability/__init__.py create mode 100644 src/rai_core/rai/observability/builder.py create mode 100644 src/rai_core/rai/observability/meta.py create mode 100644 src/rai_core/rai/observability/sink.py create mode 100644 test_obs.py diff --git a/src/rai_core/rai/agents/base.py b/src/rai_core/rai/agents/base.py index f4325dca2..e993f4c6d 100644 --- a/src/rai_core/rai/agents/base.py +++ b/src/rai_core/rai/agents/base.py @@ -12,14 +12,49 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ import logging from abc import ABC, abstractmethod +from typing import Optional + +from rai.communication.base_connector import BaseConnector +from rai.observability import ObservabilitySink, build_sink_from_env class BaseAgent(ABC): - def __init__(self): - """Initializes a new agent instance and sets up logging with the class name.""" - self.logger = logging.getLogger(self.__class__.__name__) + def __init__( + self, + name: Optional[str] = None, + observability_sink: Optional[ObservabilitySink] = None, + observability_endpoint: Optional[str] = None, + ): + """Initializes a new agent instance, logger, and optional observability sink.""" + self.name = name or self.__class__.__name__ + self.logger = logging.getLogger(self.name) + self.observability_sink = observability_sink or build_sink_from_env( + endpoint=observability_endpoint + ) + + def attach_connectors(self, *connectors: object) -> None: + """Annotate connectors with agent context without changing their constructors.""" + for conn in connectors: + try: + setattr(conn, "agent_name", self.name) + except Exception: + # Best effort; do not raise + continue + + def __setattr__(self, name: str, value: object) -> None: + """Automatically inject agent context into assigned connectors.""" + # Use super().__setattr__ first to ensure the attribute is set + super().__setattr__(name, value) + + # Then inspect and inject if it looks like a connector + # We avoid importing BaseConnector to prevent circular imports, use duck typing + if isinstance(value, BaseConnector): + value.agent_name = self.name + value.connector_name = value.__class__.__name__ + value.observability_sink = self.observability_sink @abstractmethod def run(self): diff --git a/src/rai_core/rai/communication/base_connector.py b/src/rai_core/rai/communication/base_connector.py index 51fd2f044..0094981fb 100644 --- a/src/rai_core/rai/communication/base_connector.py +++ b/src/rai_core/rai/communication/base_connector.py @@ -12,24 +12,18 @@ # See the License 
for the specific language governing permissions and # limitations under the License. + import logging import time from collections import defaultdict from concurrent.futures import ThreadPoolExecutor -from typing import ( - Any, - Callable, - Dict, - Generic, - Optional, - Type, - TypeVar, - get_args, -) +from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar, get_args from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field +from rai.observability import ObservabilityMeta, ObservabilitySink, build_sink_from_env + class ConnectorException(Exception): """Base exception for all connector exceptions.""" @@ -70,8 +64,24 @@ class ParametrizedCallback(BaseModel, Generic[T]): id: str = Field(default_factory=lambda: str(uuid4())) -class BaseConnector(Generic[T]): - def __init__(self, callback_max_workers: int = 4): +class BaseConnector(Generic[T], metaclass=ObservabilityMeta): + __observability_methods__ = ( + "send_message", + "receive_message", + "service_call", + "call_service", + "create_service", + "create_action", + "start_action", + "terminate_action", + ) + + def __init__( + self, + callback_max_workers: int = 4, + observability_sink: Optional[ObservabilitySink] = None, + observability_endpoint: Optional[str] = None, + ): self.callback_max_workers = callback_max_workers self.logger = logging.getLogger(self.__class__.__name__) self.registered_callbacks: Dict[str, Dict[str, ParametrizedCallback[T]]] = ( @@ -82,6 +92,15 @@ def __init__(self, callback_max_workers: int = 4): max_workers=self.callback_max_workers ) + # Storing agent_name in a connector is useful for observability purposes, + # but storing such high level information in a low level class may be problematic. 
+ # Raised by @Juliaj + self.agent_name: str | None = None + self.connector_name = self.__class__.__name__ + self.observability_sink = observability_sink or build_sink_from_env( + endpoint=observability_endpoint + ) + if not hasattr(self, "__orig_bases__"): self.__orig_bases__ = {} raise ConnectorException( diff --git a/src/rai_core/rai/communication/ros2/connectors/base.py b/src/rai_core/rai/communication/ros2/connectors/base.py index 2bf5145a5..7ff6c60fc 100644 --- a/src/rai_core/rai/communication/ros2/connectors/base.py +++ b/src/rai_core/rai/communication/ros2/connectors/base.py @@ -104,6 +104,8 @@ def __init__( destroy_subscribers: bool = False, executor_type: Literal["single_threaded", "multi_threaded"] = "multi_threaded", use_sim_time: bool = False, + observability_sink=None, + observability_endpoint: str | None = None, ): """Initialize the ROS2BaseConnector. @@ -121,7 +123,10 @@ def __init__( ValueError If an invalid executor type is provided. """ - super().__init__() + super().__init__( + observability_sink=observability_sink, + observability_endpoint=observability_endpoint, + ) if not rclpy.ok(): rclpy.init() diff --git a/src/rai_core/rai/observability/__init__.py b/src/rai_core/rai/observability/__init__.py new file mode 100644 index 000000000..dfc997f72 --- /dev/null +++ b/src/rai_core/rai/observability/__init__.py @@ -0,0 +1,27 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .builder import build_sink_from_env +from .meta import EVENT_SCHEMA_VERSION, ObservabilityMeta +from .sink import BufferedSink, LoggingSink, NoOpSink, ObservabilitySink + +__all__ = [ + "EVENT_SCHEMA_VERSION", + "BufferedSink", + "LoggingSink", + "NoOpSink", + "ObservabilityMeta", + "ObservabilitySink", + "build_sink_from_env", +] diff --git a/src/rai_core/rai/observability/builder.py b/src/rai_core/rai/observability/builder.py new file mode 100644 index 000000000..aea1ef430 --- /dev/null +++ b/src/rai_core/rai/observability/builder.py @@ -0,0 +1,98 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import logging
import os
from typing import Callable, Mapping, Optional
from urllib.parse import urlparse

from .sink import (
    BufferedSink,
    LoggingSink,
    NoOpSink,
    ObservabilitySink,
    StdoutSink,
    default_buffer_size,
)

LOGGER = logging.getLogger("ObservabilityBuilder")


def _make_logging_sink(_endpoint: str) -> ObservabilitySink:
    """Return a sink that logs events; the endpoint value itself is not used."""
    return LoggingSink(logger=logging.getLogger("ObservabilityLoggingSink"))


def _make_stdout_sink(_endpoint: str) -> ObservabilitySink:
    """Return a sink that prints events; the endpoint value itself is not used."""
    return StdoutSink()


# Maps an endpoint scheme to the factory that builds the sink for it.
DEFAULT_FACTORY: Mapping[str, Callable[[str], ObservabilitySink]] = {
    "ws": _make_stdout_sink,  # visible by default for local debugging
    "wss": _make_logging_sink,
    "tcp": _make_logging_sink,
    "http": _make_logging_sink,
    "https": _make_logging_sink,
    "file": _make_logging_sink,
}


def build_sink_from_env(
    endpoint: Optional[str] = None,
    buffer_size: Optional[int] = None,
    factory: Mapping[str, Callable[[str], ObservabilitySink]] = DEFAULT_FACTORY,
) -> ObservabilitySink:
    """Build an observability sink from configuration.

    The endpoint is taken from ``endpoint`` or, when that is empty, from the
    ``RAI_OBS_ENDPOINT`` environment variable. Its URL scheme selects an
    entry of ``factory``; a bare value such as ``"ws"`` is accepted as an
    alias for that scheme.

    Parameters
    ----------
    endpoint : Optional[str]
        Explicit endpoint; overrides the environment variable.
    buffer_size : Optional[int]
        Buffer length for the wrapping ``BufferedSink``. ``None`` means
        "use ``default_buffer_size()``"; zero or a negative value returns
        the sink unwrapped.
    factory : Mapping[str, Callable[[str], ObservabilitySink]]
        Scheme-to-factory mapping.

    Returns
    -------
    ObservabilitySink
        The configured sink, or ``NoOpSink`` when no endpoint is set, the
        endpoint cannot be parsed, the scheme is unknown, or the factory
        raises. This function does not raise for bad configuration.
    """
    # Strip so that a stray space or newline in the environment variable
    # does not defeat the scheme lookup.
    raw_target = (endpoint or os.getenv("RAI_OBS_ENDPOINT") or "").strip()
    if not raw_target:
        return NoOpSink()

    try:
        parsed_scheme = urlparse(raw_target).scheme
    except ValueError as exc:
        # urlparse rejects some malformed netlocs (e.g. unmatched brackets);
        # honour the documented fallback instead of raising into constructors.
        LOGGER.debug(
            "Observability endpoint could not be parsed (%s): %s; using NoOpSink",
            raw_target,
            exc,
        )
        return NoOpSink()

    if parsed_scheme:
        target_scheme = parsed_scheme
    else:
        # Accept scheme-less values like "ws" to mean "use the ws factory".
        # The exact spelling is tried first so custom factories keep working.
        target_scheme = next(
            (
                candidate
                for candidate in (raw_target, raw_target.lower())
                if candidate in factory
            ),
            "",
        )

    if not target_scheme:
        LOGGER.debug(
            "Observability endpoint missing scheme and not recognized: %s; using NoOpSink",
            raw_target,
        )
        return NoOpSink()

    factory_fn = factory.get(target_scheme)
    if not factory_fn:
        LOGGER.debug(
            "Observability endpoint scheme not recognized (%s), using NoOpSink",
            target_scheme,
        )
        return NoOpSink()

    try:
        sink = factory_fn(raw_target)
    except Exception as exc:  # pragma: no cover - defensive
        LOGGER.debug("Failed to create sink for %s: %s", raw_target, exc)
        return NoOpSink()

    buf_size = buffer_size if buffer_size is not None else default_buffer_size()
    if buf_size and buf_size > 0:
        return BufferedSink(
            sink, maxlen=buf_size, logger=logging.getLogger("BufferedSink")
        )
    return sink
import functools
import logging
import time
from typing import Any, Callable, Dict

EVENT_SCHEMA_VERSION = "v1"

_LOGGER = logging.getLogger("ObservabilityMeta")

# Attribute set on generated wrappers so that a subclass does not wrap an
# inherited, already-instrumented method a second time.
_WRAPPED_MARKER = "__observability_wrapped__"

# Methods whose target is the second positional argument or the ``target`` kwarg.
_TARGET_METHODS = frozenset(
    {
        "send_message",
        "service_call",
        "call_service",
        "start_action",
        "terminate_action",
    }
)
# Methods whose source is the first positional argument or the ``source`` kwarg.
_SOURCE_METHODS = frozenset({"receive_message"})
# Methods that create an endpoint; its name is reported as the target.
_CREATE_METHODS = frozenset({"create_service", "create_action"})


def _extract_target(
    fn_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> dict[str, Any]:
    """Extract ``target``/``source`` fields from a call to a known method.

    ``args`` are the positional arguments without ``self``. Keyword arguments
    take precedence over positional ones. Returns an empty dict when the
    method is not recognised or the value was not supplied.
    """
    fields: dict[str, Any] = {}
    if fn_name in _TARGET_METHODS:
        # target is usually the second positional argument or a kwarg named target
        if "target" in kwargs:
            fields["target"] = kwargs["target"]
        elif len(args) >= 2:
            fields["target"] = args[1]
    if fn_name in _SOURCE_METHODS:
        # source is usually the first positional argument or kwarg named source
        if "source" in kwargs:
            fields["source"] = kwargs["source"]
        elif len(args) >= 1:
            fields["source"] = args[0]
    if fn_name in _CREATE_METHODS:
        if "service_name" in kwargs:
            fields["target"] = kwargs["service_name"]
        elif "action_name" in kwargs:
            fields["target"] = kwargs["action_name"]
        elif len(args) >= 1:
            fields["target"] = args[0]
    return fields


def _emit_close_event(
    instance: Any,
    fn: Callable[..., Any],
    latency_ms: float,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> None:
    """Build the "close" event for one call and hand it to the instance's sink.

    Best-effort: any failure while reading attributes, building the event or
    recording it is logged at debug level and never raised to the caller.
    Nothing is built when the instance has no ``observability_sink``.
    """
    try:
        sink = getattr(instance, "observability_sink", None)
        if sink is None:
            return
        agent_name = getattr(instance, "agent_name", None)
        event_type = fn.__name__
        event = {
            "schema_version": EVENT_SCHEMA_VERSION,
            "event_type": event_type,
            "phase": "close",
            "latency_ms": latency_ms,
            "component": agent_name,
            "connector_name": getattr(instance, "connector_name", None),
        }
        if agent_name:
            event["agent_name"] = agent_name
        event.update(_extract_target(event_type, args, kwargs))
        sink.record(event)
    except Exception:
        # Observability must never break the instrumented call.
        _LOGGER.debug("Failed to record observability event", exc_info=True)


def _timed_handler(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call ``fn(self, *args, **kwargs)`` and record its latency.

    The event is emitted whether ``fn`` returns or raises; the return value
    or exception of ``fn`` is passed through unchanged.
    """
    # perf_counter is monotonic, so a wall-clock adjustment during the call
    # cannot produce a negative or inflated latency.
    started = time.perf_counter()
    try:
        return fn(self, *args, **kwargs)
    finally:
        latency_ms = (time.perf_counter() - started) * 1000.0
        _emit_close_event(self, fn, latency_ms, args, kwargs)


HANDLERS: Dict[str, Dict[str, Callable[..., Any]]] = {
    EVENT_SCHEMA_VERSION: {
        "send_message": _timed_handler,
        "receive_message": _timed_handler,
        "service_call": _timed_handler,
        "call_service": _timed_handler,
        "create_service": _timed_handler,
        "create_action": _timed_handler,
        "start_action": _timed_handler,
        "terminate_action": _timed_handler,
    }
}

DEFAULT_METHODS = tuple(HANDLERS[EVENT_SCHEMA_VERSION].keys())


def _instrument(
    fn: Callable[..., Any], handler: Callable[..., Any]
) -> Callable[..., Any]:
    """Return a wrapper that routes calls to ``fn`` through ``handler``.

    ``fn`` and ``handler`` are captured in the closure rather than exposed as
    keyword parameters, so caller-supplied kwargs cannot replace them.
    """

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        return handler(self, fn, *args, **kwargs)

    setattr(wrapper, _WRAPPED_MARKER, True)
    return wrapper


class ObservabilityMeta(type):
    """Metaclass that wraps selected methods with observability handlers.

    The method names come from the class attribute
    ``__observability_methods__`` (default: ``DEFAULT_METHODS``) and the
    handlers from ``HANDLERS[__observability_schema_version__]`` (default
    version: ``EVENT_SCHEMA_VERSION``). Names without a handler, or that do
    not resolve to a callable on the class, are skipped.
    """

    def __new__(mcls, name, bases, attrs, **kwargs):
        cls = super().__new__(mcls, name, bases, attrs, **kwargs)
        methods = getattr(cls, "__observability_methods__", DEFAULT_METHODS)
        schema_version = getattr(
            cls, "__observability_schema_version__", EVENT_SCHEMA_VERSION
        )
        handler_map = HANDLERS.get(schema_version, {})

        for method_name in methods:
            fn = getattr(cls, method_name, None)
            handler = handler_map.get(method_name)
            if fn is None or handler is None or not callable(fn):
                continue
            # An inherited method was already instrumented when its defining
            # class was created; wrapping it again would emit one event per
            # inheritance level for a single call.
            if getattr(fn, _WRAPPED_MARKER, False):
                continue
            setattr(cls, method_name, _instrument(fn, handler))
        return cls
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import threading +from collections import deque +from typing import Any, Protocol + + +class ObservabilitySink(Protocol): + """Minimal interface for observability sinks.""" + + def record(self, event: dict[str, Any]) -> None: ... + + def start(self) -> None: # pragma: no cover - default no-op + return None + + def flush(self) -> None: # pragma: no cover - default no-op + return None + + def stop(self) -> None: # pragma: no cover - default no-op + return None + + +class NoOpSink: + """Sink that does nothing.""" + + def record(self, event: dict[str, Any]) -> None: + return None + + def start(self) -> None: + return None + + def flush(self) -> None: + return None + + def stop(self) -> None: + return None + + +class LoggingSink: + """Simple sink that logs events (placeholder for real transports).""" + + def __init__(self, logger: logging.Logger | None = None) -> None: + self.logger = logger or logging.getLogger("ObservabilityLoggingSink") + + def record(self, event: dict[str, Any]) -> None: + # Log at debug to minimize noise; callable is non-blocking/low cost. 
+ self.logger.debug("observability event: %s", event) + + +class StdoutSink: + """Sink that prints events to stdout; handy for local debugging.""" + + def __init__(self, prefix: str = "observability") -> None: + self.prefix = prefix + + def record(self, event: dict[str, Any]) -> None: + print(f"{self.prefix}: {event}") + + +class BufferedSink: + """Wraps another sink, buffers on failure, and drops oldest when full.""" + + def __init__( + self, + inner: ObservabilitySink, + maxlen: int = 256, + logger: logging.Logger | None = None, + ) -> None: + self.inner = inner + self.maxlen = maxlen + self.logger = logger or logging.getLogger("BufferedSink") + self._lock = threading.Lock() + self._buf: deque[dict[str, Any]] = deque() + + def record(self, event: dict[str, Any]) -> None: + with self._lock: + # Try to flush buffered events first. + self._flush_locked() + self._record_one_locked(event) + + def _record_one_locked(self, event: dict[str, Any]) -> None: + try: + self.inner.record(event) + except Exception: + if len(self._buf) >= self.maxlen: + self._buf.popleft() + self.logger.debug("BufferedSink dropping oldest event (buffer full)") + self._buf.append(event) + + def _flush_locked(self) -> None: + pending = list(self._buf) + self._buf.clear() + for ev in pending: + try: + self.inner.record(ev) + except Exception: + # Re-queue remaining events and stop to avoid tight retry loops. 
+ self._buf.extend(pending[pending.index(ev) :]) + break + + def flush(self) -> None: + with self._lock: + self._flush_locked() + + def start(self) -> None: + if hasattr(self.inner, "start"): + try: + self.inner.start() + except Exception: + pass + + def stop(self) -> None: + if hasattr(self.inner, "stop"): + try: + self.inner.stop() + except Exception: + pass + + +def default_buffer_size() -> int: + try: + return int(os.getenv("RAI_OBS_BUFFER_SIZE", "256")) + except Exception: + return 256 diff --git a/test_obs.py b/test_obs.py new file mode 100644 index 000000000..4df1337e5 --- /dev/null +++ b/test_obs.py @@ -0,0 +1,98 @@ +import logging +import os +import time + +from rai.agents.base import BaseAgent +from rai.communication.ros2 import ROS2Connector, ROS2Context +from rai.communication.ros2.messages import ROS2Message +from std_srvs.srv import Trigger_Response + +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(levelname)s %(name)s %(message)s", +) + + +endpoint = os.getenv("RAI_OBS_ENDPOINT") + + +class DemoAgent(BaseAgent): + def __init__(self): + super().__init__(name="DemoAgent", observability_endpoint=endpoint) + self.nav2_connector = ROS2Connector(node_name="Nav2Agent") + self.moveit2_connector = ROS2Connector(node_name="Moveit2Agent") + self.user_task_connector = ROS2Connector(node_name="UserTaskAgent") + self.rai_gdino_connector = ROS2Connector(node_name="RAIGroundingDino") + self.rai_gsam_connector = ROS2Connector(node_name="RAIGroundedSam") + + def setup(self): + _ = self.nav2_connector.create_service( + service_name="navigate_to", + on_request=lambda x, y: Trigger_Response( + success=True, message="Navigation completed" + ), + service_type="std_srvs/srv/Trigger", + ) + _ = self.moveit2_connector.create_service( + service_name="execute_trajectory", + on_request=lambda x, y: Trigger_Response( + success=True, message="Trajectory execution completed" + ), + service_type="std_srvs/srv/Trigger", + ) + _ = 
self.rai_gdino_connector.create_service( + service_name="grounding_dino_classify", + on_request=lambda x, y: Trigger_Response( + success=True, message="Detection completed" + ), + service_type="std_srvs/srv/Trigger", + ) + _ = self.rai_gsam_connector.create_service( + service_name="grounded_sam_segment", + on_request=lambda x, y: Trigger_Response( + success=True, message="Segmentation completed" + ), + service_type="std_srvs/srv/Trigger", + ) + + def run(self): + pass + + def stop(self): + pass + + +with ROS2Context(): + agent = DemoAgent() + agent.setup() + + time.sleep(1.0) + for i in range(10): + _ = agent.user_task_connector.receive_message( + source="/user_task", timeout_sec=2.0 + ) + _ = agent.user_task_connector.service_call( + ROS2Message(payload={}), + target="navigate_to", + msg_type="std_srvs/srv/Trigger", + ) + _ = agent.user_task_connector.service_call( + ROS2Message(payload={}), + target="grounding_dino_classify", + msg_type="std_srvs/srv/Trigger", + ) + _ = agent.user_task_connector.service_call( + ROS2Message(payload={}), + target="grounded_sam_segment", + msg_type="std_srvs/srv/Trigger", + ) + _ = agent.user_task_connector.service_call( + ROS2Message(payload={}), + target="execute_trajectory", + msg_type="std_srvs/srv/Trigger", + ) + _ = agent.user_task_connector.send_message( + ROS2Message(payload={}), + target="/user_output", + msg_type="std_msgs/msg/String", + )