Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,555 changes: 901 additions & 1,654 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ tabulate = "^0.9.0"
pytest = "^8.2.0"
pytest-timeout = "^2.3.1"
pytest-cov = "^7.0.0"

aiohttp = "^3.12.13"

[tool.poetry.group.docs]
optional = true
Expand Down
248 changes: 248 additions & 0 deletions src/rai_core/rai/communication/http/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
import asyncio
import threading
from enum import IntFlag
from typing import Callable, Optional, Any
import json
import logging

import aiohttp
from aiohttp import ClientSession, ClientTimeout, web


class HTTPAPIError(Exception): ...


class HTTPConnectorMode(IntFlag):
client = 1 # 0b01
server = 2 # 0b10
client_server = 3 # 0b11 (client | server)


class HTTPAPI:
def __init__(
self,
mode: HTTPConnectorMode = HTTPConnectorMode.client,
host="localhost",
port=8080,
):
self.host = host
self.port = port
self.mode = mode

self.routes: dict[str, list[str]] = {}

self.app = web.Application()
self.runner = web.AppRunner(self.app)
self.loop = asyncio.new_event_loop()
self.client_session = None
self._thread = threading.Thread(target=self._start_loop, daemon=True)
self._started_event = threading.Event()
self.unresolved_futures = []

self.websockets: dict[str, set[web.WebSocketResponse]] = {}
self.ws_clients: dict[str, aiohttp.ClientWebSocketResponse] = {}

def _start_loop(self):
asyncio.set_event_loop(self.loop)
self.loop.run_until_complete(self._start_server())
self._started_event.set()
self.loop.run_forever()

async def _start_server(self):
if self.mode & HTTPConnectorMode.client:
self.client_session = ClientSession()
if self.mode & HTTPConnectorMode.server:
await self.runner.setup()
site = web.TCPSite(self.runner, self.host, self.port)
await site.start()
print(f"Serving on http://{self.host}:{self.port}")

def run(self):
self._thread.start()
self._started_event.wait()

def stop(self):
def shutdown():
async def _shutdown():
if (
self.mode & HTTPConnectorMode.client
and self.client_session is not None
):
await self.client_session.close()
if self.mode & HTTPConnectorMode.server:
await self.runner.cleanup()
self.loop.stop()

asyncio.run_coroutine_threadsafe(_shutdown(), self.loop)

shutdown()

def add_route(
self,
method: str,
path: str,
handler_lambda: Callable,
):
if not (self.mode & HTTPConnectorMode.server):
return

async def handler(request):
return await handler_lambda(request)

def register():
self.app.router._frozen = False
self.app.router.add_route(method.upper(), path, handler)

self.loop.call_soon_threadsafe(register)
if self.routes.get(path) is not None:
self.routes[path] = [method]
else:
self.routes[path].append(method)

def add_websocket(self, path: str, handler_lambda):
"""
In server mode:
`path` is the HTTP path (e.g. "/ws").
`handler_lambda(ws, request)` is called for each connection.

In client mode:
`path` is the full WebSocket URL (e.g. "ws://example.com/ws").
`handler_lambda(ws, msg)` is called for each incoming message.
"""
# SERVER SIDE
if self.mode & HTTPConnectorMode.server:
if path not in self.websockets:
self.websockets[path] = set()

async def ws_handler(request):
ws = web.WebSocketResponse()
await ws.prepare(request)

# register this connection
self.websockets[path].add(ws)

try:
# user handler can read/write freely, e.g.:
# async for msg in ws: ...
await handler_lambda(ws, request)
finally:
# ensure it is removed on close
self.websockets[path].discard(ws)
await ws.close()

return ws

def register_server():
self.app.router.add_get(path, ws_handler)

self.loop.call_soon_threadsafe(register_server)

# CLIENT SIDE
if self.mode & HTTPConnectorMode.client:
async def connect_client_ws():
assert self.client_session is not None, "ClientSession not initialized"
ws = await self.client_session.ws_connect(path)
self.ws_clients[path] = ws

try:
async for msg in ws:
# let user handler inspect/read messages and optionally write
await handler_lambda(ws, msg)
finally:
# clean up on close
if self.ws_clients.get(path) is ws:
del self.ws_clients[path]
await ws.close()

def start_client():
asyncio.create_task(connect_client_ws())

self.loop.call_soon_threadsafe(start_client)

def publish_websocket(
self,
path: str,
payload: Optional[str | dict],
):
"""
Send `payload` over all WebSocket connections associated with `path`.

- For server mode: broadcasts to all connected clients on that route.
- For client mode: sends to the single client WebSocket created for that URL.
"""
if payload is None:
msg = ""
elif isinstance(payload, dict):
msg = json.dumps(payload)
else:
msg = str(payload)

async def _publish():
# collect all websockets (server + client) associated with this key
server_conns = list(self.websockets.get(path, []))
client_ws = self.ws_clients.get(path)
all_conns = server_conns + ([client_ws] if client_ws is not None else [])

dead_server = []
dead_client = False

for ws in all_conns:
try:
await ws.send_str(msg)
except Exception:
# mark broken ones to be removed
if ws in server_conns:
dead_server.append(ws)
elif ws is client_ws:
dead_client = True

# cleanup broken server connections
for ws in dead_server:
self.websockets.get(path, set()).discard(ws)

# cleanup broken client connection
if dead_client and self.ws_clients.get(path) is client_ws:
del self.ws_clients[path]

asyncio.run_coroutine_threadsafe(_publish(), self.loop)

def send_request(
self,
method: str,
url: str,
timeout: Optional[float],
*,
payload: Optional[str | dict],
headers: Optional[dict],
**kwargs,
) -> tuple[str, int]:
if not (self.mode & HTTPConnectorMode.client):
raise HTTPAPIError("Tried sending request with client mode disabled!")
timeout_cfg = ClientTimeout(timeout) if timeout is not None else None
if payload is not None and "json" not in kwargs and "data" not in kwargs:
kwargs["json"] = payload # aiohttp will set appropriate headers
if headers is not None:
kwargs["headers"] = {"Content-Type": "application/json"}

if timeout_cfg:
kwargs["timeout"] = timeout_cfg

coro = self._request(method, url, **kwargs)
future = asyncio.run_coroutine_threadsafe(coro, self.loop)
self.unresolved_futures.append(future)
if timeout is None:
return "", 200
return future.result()

async def _request(self, method: str, url: str, **kwargs):
assert self.client_session is not None
async with self.client_session.request(method.upper(), url, **kwargs) as resp:
return await resp.text(), resp.status

def shutdown(self):
for future in self.unresolved_futures:
try:
future.result(timeout=0)
except Exception as e:
logging.warning(f"Background request failed or timed out: {e}")
self.unresolved_futures.clear()
Empty file.
141 changes: 141 additions & 0 deletions src/rai_core/rai/communication/http/connectors/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import time
import logging
from typing import Any, Callable, Optional, TypeVar, Dict

from rai.communication.base_connector import BaseConnector, BaseMessage
from rai.communication.http.messages import HTTPMessage
from rai.communication.http.api import HTTPAPI, HTTPAPIError, HTTPConnectorMode

T = TypeVar("T", bound=HTTPMessage)


class HTTPBaseConnector(BaseConnector[T]):
def __init__(
self,
mode: HTTPConnectorMode = HTTPConnectorMode.client,
host: str = "localhost",
port: int = 8080,
):
super().__init__()

self._api = HTTPAPI(mode, host, port)
self._api.run()
self._services = []
self.last_msg: Dict[str, T] = {}

def send_message(self, message: T, target: str, **kwargs: Optional[Any]) -> None:
if message.protocol == "http":
self._api.send_request(
message.method,
target,
None,
payload=message.payload,
headers=message.headers,
)
else:
# self._api.

def receive_message(
self, source: str, timeout_sec: float, **kwargs: Optional[Any]
) -> T:
msg = None
if self._api.routes.get(source, None) is not None:
# a GET method has already been added...
else:
def local_callback(payload: Any) -> None:
msg = payload
self._api.add_route(
"GET",
source,
self.general_callback
)

start_time = time.time()
# wait for the message to be received
while time.time() - start_time < timeout_sec:
if source in self.last_msg:
return self.last_msg[source]
time.sleep(0.1)
else:
raise TimeoutError(
f"Message from {source} not received in {timeout_sec} seconds"
)
raise NotImplementedError("This method should be implemented by the subclass.")

def _safe_callback_wrapper(self, callback: Callable[[T], None], message: T) -> None:
try:
callback(message)
except Exception as e:
self.logger.error(f"Error in callback: {str(e)}")

def general_callback(self, source: str, message: Any) -> None:
processed_message = self.general_callback_preprocessor(message)
for parametrized_callback in self.registered_callbacks.get(source, {}).values():
payload = message if parametrized_callback.raw else processed_message
self.callback_executor.submit(
self._safe_callback_wrapper, parametrized_callback.callback, payload
)

def general_callback_preprocessor(self, message: Any) -> T:
raise NotImplementedError("This method should be implemented by the subclass.")

def service_call(
self, message: T, target: str, timeout_sec: float, **kwargs: Optional[Any]
) -> BaseMessage:
payload, status = self._api.send_request(
message.method,
target,
timeout_sec,
payload=message.payload,
headers=message.headers,
)
ret = BaseMessage(payload=payload, metadata={"status": status})
return ret

def create_service(
self,
service_name: str,
on_request: Callable,
on_done: Optional[Callable] = None,
*,
method: str,
**kwargs: Optional[Any],
) -> str:
id_str = f"{method.upper()}_{service_name}"
if on_done is not None:
logging.warning(
f"not None on_done argument passed to create_service of {self.__class__}; will have no effect!"
)
if id_str in self._services:
raise HTTPAPIError(
f"Service {service_name} already has a {method.upper()} handler"
)
self._api.add_route(method, service_name, on_request)
return id_str

def create_action(
self,
action_name: str,
generate_feedback_callback: Callable,
**kwargs: Optional[Any],
) -> str:
raise NotImplementedError("This method should be implemented by the subclass.")

def start_action(
self,
action_data: Optional[T],
target: str,
on_feedback: Callable,
on_done: Callable,
timeout_sec: float,
**kwargs: Optional[Any],
) -> str:
raise NotImplementedError("This method should be implemented by the subclass.")

def terminate_action(self, action_handle: str, **kwargs: Optional[Any]) -> Any:
raise NotImplementedError("This method should be implemented by the subclass.")

def shutdown(self):
"""Shuts down the connector and releases all resources."""
self._api.shutdown()
self.callback_executor.shutdown(wait=True)
Loading
Loading