@@ -251,7 +251,7 @@ async def _primary_on_connect(self, channel: RpcChannel):
251251
252252 def subscribe (self , topic : Topic , callback : Coroutine ):
253253 """
254- Subscribe for events (prior to starting the client)
254+ Subscribe for events (before and after starting the client)
255255 @see fastapi_websocket_pubsub/rpc_event_methods.py :: RpcEventServerMethods.subscribe
256256
257257 Args:
@@ -260,18 +260,68 @@ def subscribe(self, topic: Topic, callback: Coroutine):
260260 'hello' or a complex path 'a/b/c/d' .
261261 Note: You can use ALL_TOPICS (event_notifier.ALL_TOPICS) to subscribe to all topics
262262 callback (Coroutine): the function to call upon relevant event publishing
263+
264+ Returns:
265+ Coroutine: awaitable task to subscribe to topic if connected.
263266 """
264- # TODO: add support for post connection subscriptions
265- if not self .is_ready ():
266- self ._topics .add (topic )
267- # init to empty list if no entry
268- callbacks = self ._callbacks [topic ] = self ._callbacks .get (topic , [])
269- # add callback to callbacks list of the topic
270- callbacks .append (callback )
267+ topic_is_new = topic not in self ._topics
268+ self ._topics .add (topic )
269+ # init to empty list if no entry
270+ callbacks = self ._callbacks [topic ] = self ._callbacks .get (topic , [])
271+ # add callback to callbacks list of the topic
272+ callbacks .append (callback )
273+ if topic_is_new and self .is_ready ():
274+ return self ._rpc_channel .other .subscribe (topics = [topic ])
271275 else :
272- raise PubSubClientInvalidStateException (
273- "Client already connected and subscribed"
274- )
276+ # If we can't return an RPC call future then we need
277+ # to supply something else to not fail when the
278+ # calling code awaits the result of this function.
279+ future = asyncio .Future ()
280+ future .set_result (None )
281+ return future
282+
283+ def unsubscribe (self , topic : Topic ):
284+ """
285+ Unsubscribe for events
286+
287+ Args:
288+ topic (Topic): the identifier of the event topic to be unsubscribed.
289+ Note: You can use ALL_TOPICS (event_notifier.ALL_TOPICS) to unsubscribe all topics
290+
291+ Returns:
292+ Coroutine: awaitable task to subscribe to topic if connected.
293+ """
294+ # Create none-future which can be safely awaited
295+ # but which also will not give warnings
296+ # if it isn't awaited. This is returned
297+ # on code paths which do not make RPC calls.
298+ none_future = asyncio .Future ()
299+ none_future .set_result (None )
300+
301+ # Topics to potentially make RPC calls about
302+ topics = list (self ._topics ) if topic is ALL_TOPICS else [topic ]
303+
304+ # Handle ALL_TOPICS or specific topics
305+ if topic is ALL_TOPICS and not self ._topics :
306+ logger .warning (f"Cannot unsubscribe 'ALL_TOPICS'. No topics are subscribed." )
307+ return none_future
308+ elif topic is not ALL_TOPICS and topic not in self ._topics :
309+ logger .warning (f"Cannot unsubscribe topic '{ topic } ' which is not subscribed." )
310+ return none_future
311+ elif topic is ALL_TOPICS and self ._topics :
312+ logger .debug (f"Unsubscribing all topics: { self ._topics } " )
313+ # remove all topics and callbacks
314+ self ._topics .clear ()
315+ self ._callbacks .clear ()
316+ elif topic is not ALL_TOPICS and topic in self ._topics :
317+ logger .debug (f"Unsubscribing topic '{ topic } '" )
318+ self ._topics .remove (topic )
319+ self ._callbacks .pop (topic , None )
320+
321+ if self .is_ready ():
322+ return self ._rpc_channel .other .unsubscribe (topics = topics )
323+ else :
324+ return none_future
275325
276326 async def publish (
277327 self , topics : TopicList , data = None , sync = True , notifier_id = None
0 commit comments