Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/rai_core/rai/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from abc import ABC, abstractmethod
from typing import Optional

from rai.communication.base_connector import BaseConnector
from rai.observability import ObservabilitySink, build_sink_from_env


class BaseAgent(ABC):
def __init__(self):
"""Initializes a new agent instance and sets up logging with the class name."""
self.logger = logging.getLogger(self.__class__.__name__)
def __init__(
self,
name: Optional[str] = None,
observability_sink: Optional[ObservabilitySink] = None,
observability_endpoint: Optional[str] = None,
):
"""Initializes a new agent instance, logger, and optional observability sink."""
self.name = name or self.__class__.__name__
self.logger = logging.getLogger(self.name)
self.observability_sink = observability_sink or build_sink_from_env(
endpoint=observability_endpoint
)

def attach_connectors(self, *connectors: object) -> None:
"""Annotate connectors with agent context without changing their constructors."""
for conn in connectors:
try:
setattr(conn, "agent_name", self.name)
except Exception:
# Best effort; do not raise
continue

def __setattr__(self, name: str, value: object) -> None:
"""Automatically inject agent context into assigned connectors."""
# Use super().__setattr__ first to ensure the attribute is set
super().__setattr__(name, value)

# Then inspect and inject if it looks like a connector
# We avoid importing BaseConnector to prevent circular imports, use duck typing
if isinstance(value, BaseConnector):
value.agent_name = self.name
value.connector_name = value.__class__.__name__
value.observability_sink = self.observability_sink

@abstractmethod
def run(self):
Expand Down
43 changes: 31 additions & 12 deletions src/rai_core/rai/communication/base_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Callable,
Dict,
Generic,
Optional,
Type,
TypeVar,
get_args,
)
from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar, get_args
from uuid import uuid4

from pydantic import BaseModel, ConfigDict, Field

from rai.observability import ObservabilityMeta, ObservabilitySink, build_sink_from_env


class ConnectorException(Exception):
"""Base exception for all connector exceptions."""
Expand Down Expand Up @@ -70,8 +64,24 @@ class ParametrizedCallback(BaseModel, Generic[T]):
id: str = Field(default_factory=lambda: str(uuid4()))


class BaseConnector(Generic[T]):
def __init__(self, callback_max_workers: int = 4):
class BaseConnector(Generic[T], metaclass=ObservabilityMeta):
__observability_methods__ = (
"send_message",
"receive_message",
"service_call",
"call_service",
"create_service",
"create_action",
"start_action",
"terminate_action",
)

def __init__(
self,
callback_max_workers: int = 4,
observability_sink: Optional[ObservabilitySink] = None,
observability_endpoint: Optional[str] = None,
):
self.callback_max_workers = callback_max_workers
self.logger = logging.getLogger(self.__class__.__name__)
self.registered_callbacks: Dict[str, Dict[str, ParametrizedCallback[T]]] = (
Expand All @@ -82,6 +92,15 @@ def __init__(self, callback_max_workers: int = 4):
max_workers=self.callback_max_workers
)

# Storing agent_name in a connector is useful for observability purposes,
# but storing such high level information in a low level class may be problematic.
# Raised by @Juliaj
self.agent_name: str | None = None
self.connector_name = self.__class__.__name__
self.observability_sink = observability_sink or build_sink_from_env(
endpoint=observability_endpoint
)

if not hasattr(self, "__orig_bases__"):
self.__orig_bases__ = {}
raise ConnectorException(
Expand Down
7 changes: 6 additions & 1 deletion src/rai_core/rai/communication/ros2/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(
destroy_subscribers: bool = False,
executor_type: Literal["single_threaded", "multi_threaded"] = "multi_threaded",
use_sim_time: bool = False,
observability_sink=None,
observability_endpoint: str | None = None,
):
"""Initialize the ROS2BaseConnector.

Expand All @@ -121,7 +123,10 @@ def __init__(
ValueError
If an invalid executor type is provided.
"""
super().__init__()
super().__init__(
observability_sink=observability_sink,
observability_endpoint=observability_endpoint,
)

if not rclpy.ok():
rclpy.init()
Expand Down
27 changes: 27 additions & 0 deletions src/rai_core/rai/observability/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .builder import build_sink_from_env
from .meta import EVENT_SCHEMA_VERSION, ObservabilityMeta
from .sink import BufferedSink, LoggingSink, NoOpSink, ObservabilitySink

__all__ = [
"EVENT_SCHEMA_VERSION",
"BufferedSink",
"LoggingSink",
"NoOpSink",
"ObservabilityMeta",
"ObservabilitySink",
"build_sink_from_env",
]
98 changes: 98 additions & 0 deletions src/rai_core/rai/observability/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Callable, Mapping, Optional
from urllib.parse import urlparse

from .sink import (
BufferedSink,
LoggingSink,
NoOpSink,
ObservabilitySink,
StdoutSink,
default_buffer_size,
)

LOGGER = logging.getLogger("ObservabilityBuilder")


def _make_logging_sink(_endpoint: str) -> ObservabilitySink:
return LoggingSink(logger=logging.getLogger("ObservabilityLoggingSink"))


def _make_stdout_sink(_endpoint: str) -> ObservabilitySink:
return StdoutSink()


DEFAULT_FACTORY: Mapping[str, Callable[[str], ObservabilitySink]] = {
"ws": _make_stdout_sink, # visible by default for local debugging
"wss": _make_logging_sink,
"tcp": _make_logging_sink,
"http": _make_logging_sink,
"https": _make_logging_sink,
"file": _make_logging_sink,
}


def build_sink_from_env(
endpoint: Optional[str] = None,
buffer_size: Optional[int] = None,
factory: Mapping[str, Callable[[str], ObservabilitySink]] = DEFAULT_FACTORY,
) -> ObservabilitySink:
"""Build an observability sink from configuration.

If no endpoint is provided or parsing fails, falls back to NoOpSink.
"""
raw_target = endpoint or os.getenv("RAI_OBS_ENDPOINT")
if not raw_target:
return NoOpSink()

# Accept scheme-less values like "ws" to mean "use the ws factory".
parsed = urlparse(raw_target)
if not parsed.scheme and raw_target in factory:
target_scheme = raw_target
target_full = raw_target
else:
target_scheme = parsed.scheme
target_full = raw_target

if not target_scheme:
LOGGER.debug(
"Observability endpoint missing scheme and not recognized: %s; using NoOpSink",
raw_target,
)
return NoOpSink()

factory_fn = factory.get(target_scheme)
if not factory_fn:
LOGGER.debug(
"Observability endpoint scheme not recognized (%s), using NoOpSink",
target_scheme,
)
return NoOpSink()

try:
sink = factory_fn(target_full)
except Exception as exc: # pragma: no cover - defensive
LOGGER.debug("Failed to create sink for %s: %s", target_full, exc)
return NoOpSink()

buf_size = buffer_size if buffer_size is not None else default_buffer_size()
if buf_size and buf_size > 0:
return BufferedSink(
sink, maxlen=buf_size, logger=logging.getLogger("BufferedSink")
)
return sink
121 changes: 121 additions & 0 deletions src/rai_core/rai/observability/meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import time
from typing import Any, Callable, Dict

from .sink import NoOpSink

EVENT_SCHEMA_VERSION = "v1"


def _extract_target(
fn_name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> dict[str, Any]:
"""Extract common fields like target/source from known method signatures."""
fields: dict[str, Any] = {}
if fn_name in {
"send_message",
"service_call",
"call_service",
"start_action",
"terminate_action",
}:
# target is usually the second positional argument or a kwarg named target
if "target" in kwargs:
fields["target"] = kwargs["target"]
elif len(args) >= 2:
fields["target"] = args[1]
if fn_name in {"receive_message"}:
# source is usually the first positional argument or kwarg named source
if "source" in kwargs:
fields["source"] = kwargs["source"]
elif len(args) >= 1:
fields["source"] = args[0]
if fn_name in {"create_service", "create_action"}:
if "service_name" in kwargs:
fields["target"] = kwargs["service_name"]
elif "action_name" in kwargs:
fields["target"] = kwargs["action_name"]
elif len(args) >= 1:
fields["target"] = args[0]
return fields


def _timed_handler(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
started = time.time()
try:
return fn(self, *args, **kwargs)
finally:
sink = getattr(self, "observability_sink", None) or NoOpSink()
connector_name = getattr(self, "connector_name", None)
agent_name = getattr(self, "agent_name", None)
try:
event = {
"schema_version": EVENT_SCHEMA_VERSION,
"event_type": fn.__name__,
"phase": "close",
"latency_ms": (time.time() - started) * 1000.0,
"component": agent_name,
"connector_name": connector_name,
}
if agent_name:
event["agent_name"] = agent_name
event.update(_extract_target(fn.__name__, args, kwargs))
sink.record(event)
except Exception:
# Best-effort: never raise into caller.
pass


HANDLERS: Dict[str, Dict[str, Callable[..., Any]]] = {
EVENT_SCHEMA_VERSION: {
"send_message": _timed_handler,
"receive_message": _timed_handler,
"service_call": _timed_handler,
"call_service": _timed_handler,
"create_service": _timed_handler,
"create_action": _timed_handler,
"start_action": _timed_handler,
"terminate_action": _timed_handler,
}
}

DEFAULT_METHODS = tuple(HANDLERS[EVENT_SCHEMA_VERSION].keys())


class ObservabilityMeta(type):
"""Metaclass that wraps selected methods with observability handlers."""

def __new__(mcls, name, bases, attrs):
cls = super().__new__(mcls, name, bases, attrs)
methods = getattr(cls, "__observability_methods__", DEFAULT_METHODS)
schema_version = getattr(
cls, "__observability_schema_version__", EVENT_SCHEMA_VERSION
)
handler_map = HANDLERS.get(schema_version, {})

for method_name in methods:
fn = getattr(cls, method_name, None)
handler = handler_map.get(method_name)
if not fn or not handler:
continue
Copy link
Collaborator

@Juliaj Juliaj Dec 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When running unit test, this caused one of the tests to fail with timeout. This in turn revealed an issue of double-wrapping inherited methods. This can happen when a base class is created with this metaclass and a subclass is created with the same metaclass. See more details at here


@functools.wraps(fn)
def wrapper(self, *args, __fn=fn, __handler=handler, **kwargs):
return __handler(self, __fn, *args, **kwargs)

setattr(cls, method_name, wrapper)
return cls
Loading
Loading