diff --git a/redis_timers/handler.py b/redis_timers/handler.py index 7731f7f..be3b945 100644 --- a/redis_timers/handler.py +++ b/redis_timers/handler.py @@ -1,16 +1,15 @@ import dataclasses -import typing import pydantic -from redis_timers import settings +from redis_timers import settings, types @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) class Handler[T: pydantic.BaseModel]: topic: str schema: type[T] - handler: typing.Callable[[T], typing.Coroutine[None, None, None]] + handler: types.HandlerType[T] def build_timer_key(self, timer_id: str) -> str: return f"{self.topic}{settings.TIMERS_SEPARATOR}{timer_id}" diff --git a/redis_timers/router.py b/redis_timers/router.py index c558ef2..7403ad6 100644 --- a/redis_timers/router.py +++ b/redis_timers/router.py @@ -3,6 +3,7 @@ import pydantic +from redis_timers import types from redis_timers.handler import Handler @@ -15,13 +16,8 @@ def handler[T: pydantic.BaseModel]( *, topic: str, schema: type[T], - ) -> typing.Callable[ - [typing.Callable[[T], typing.Coroutine[None, None, None]]], - typing.Callable[[T], typing.Coroutine[None, None, None]], - ]: - def _decorator( - func: typing.Callable[[T], typing.Coroutine[None, None, None]], - ) -> typing.Callable[[T], typing.Coroutine[None, None, None]]: + ) -> typing.Callable[[types.HandlerType[T]], types.HandlerType[T]]: + def _decorator(func: types.HandlerType[T]) -> types.HandlerType[T]: self.handlers.append( Handler( topic=topic, diff --git a/redis_timers/timers.py b/redis_timers/timers.py index bd21c97..7ac9476 100644 --- a/redis_timers/timers.py +++ b/redis_timers/timers.py @@ -29,6 +29,7 @@ @dataclasses.dataclass(kw_only=True, slots=True, frozen=True) class Timers: redis_client: "aioredis.Redis[str]" + context: dict[str, typing.Any] = dataclasses.field(default_factory=dict) handlers_by_topics: dict[str, Handler[typing.Any]] = dataclasses.field(default_factory=dict, init=False) def include_router(self, router: Router) -> None: @@ -72,7 +73,7 @@ async def _handle_one_timer(self, timer_key: str) -> None: logger.exception(f"Failed to parse payload, {timer_key=}, {raw_payload=}") return - await handler.handler(payload) + await handler.handler(payload, self.context) async def handle_ready_timers(self) -> None: ready_timers = await self._fetch_ready_timers() diff --git a/redis_timers/types.py b/redis_timers/types.py new file mode 100644 index 0000000..0754273 --- /dev/null +++ b/redis_timers/types.py @@ -0,0 +1,8 @@ +import typing + +import pydantic + + +SchemaType = typing.TypeVar("SchemaType", bound=pydantic.BaseModel) +ContextType = dict[str, typing.Any] +HandlerType = typing.Callable[[SchemaType, ContextType], typing.Coroutine[None, None, None]] diff --git a/tests/test_router.py b/tests/test_router.py index d84a199..19b03c9 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -17,13 +17,13 @@ def test_register_handler() -> None: router: typing.Final = Router() @router.handler(topic="test_timer", schema=SomeSchema) - async def test_handler(data: SomeSchema) -> None: ... + async def some_handler(data: SomeSchema, _: dict[str, typing.Any]) -> None: ... assert len(router.handlers) == 1 handler: typing.Final = router.handlers[0] assert handler.topic == "test_timer" assert handler.schema == SomeSchema - assert handler.handler == test_handler + assert handler.handler == some_handler def test_register_handler_multiple_handlers() -> None: @@ -31,10 +31,10 @@ def test_register_handler_multiple_handlers() -> None: expected_handlers_count: typing.Final = 2 @router.handler(topic="handler1", schema=SomeSchema) - async def handler1(data: SomeSchema) -> None: ... + async def handler1(data: SomeSchema, _: dict[str, typing.Any]) -> None: ... @router.handler(topic="handler2", schema=AnotherSchema) - async def handler2(data: AnotherSchema) -> None: ... + async def handler2(data: AnotherSchema, _: dict[str, typing.Any]) -> None: ... assert len(router.handlers) == expected_handlers_count assert router.handlers[0].topic == "handler1" diff --git a/tests/test_timers.py b/tests/test_timers.py index 3de361e..7d37d66 100644 --- a/tests/test_timers.py +++ b/tests/test_timers.py @@ -1,5 +1,6 @@ import datetime import os +import typing from collections.abc import AsyncGenerator import pydantic @@ -58,16 +59,18 @@ def timers_instance(redis_client: "aioredis.Redis[str]", handler_results: Handle router1 = Router() @router1.handler(topic="some_topic", schema=SomePayloadModel) - async def test_handler(data: SomePayloadModel) -> None: + async def some_handler(data: SomePayloadModel, context: dict[str, typing.Any]) -> None: handler_results.add_result(data) + assert context["some_key"] == "some_value" router2 = Router() @router2.handler(topic="another_topic", schema=AnotherPayloadModel) - async def another_handler(data: AnotherPayloadModel) -> None: + async def another_handler(data: AnotherPayloadModel, context: dict[str, typing.Any]) -> None: handler_results.add_result(data) + assert context["some_key"] == "some_value" - timers = Timers(redis_client=redis_client) + timers = Timers(redis_client=redis_client, context={"some_key": "some_value"}) timers.include_router(router1) timers.include_routers(router2)