Skip to content

Commit b7f593c

Browse files
authored
Add unsubscribe method so client can unsubscribe from topics (#90)
* Add unsubscribe method so client can unsubscribe from topics. * Add support for post connection subscriptions * Add unsubscribe support * Add test case concerning subscribe/unsubscribe
1 parent 05d97fb commit b7f593c

File tree

4 files changed

+122
-11
lines changed

4 files changed

+122
-11
lines changed

fastapi_websocket_pubsub/pub_sub_client.py

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ async def _primary_on_connect(self, channel: RpcChannel):
251251

252252
def subscribe(self, topic: Topic, callback: Coroutine):
253253
"""
254-
Subscribe for events (prior to starting the client)
254+
Subscribe for events (before and after starting the client)
255255
@see fastapi_websocket_pubsub/rpc_event_methods.py :: RpcEventServerMethods.subscribe
256256
257257
Args:
@@ -260,18 +260,68 @@ def subscribe(self, topic: Topic, callback: Coroutine):
260260
'hello' or a complex path 'a/b/c/d' .
261261
Note: You can use ALL_TOPICS (event_notifier.ALL_TOPICS) to subscribe to all topics
262262
callback (Coroutine): the function to call upon relevant event publishing
263+
264+
Returns:
265+
Coroutine: awaitable task to subscribe to topic if connected.
263266
"""
264-
# TODO: add support for post connection subscriptions
265-
if not self.is_ready():
266-
self._topics.add(topic)
267-
# init to empty list if no entry
268-
callbacks = self._callbacks[topic] = self._callbacks.get(topic, [])
269-
# add callback to callbacks list of the topic
270-
callbacks.append(callback)
267+
topic_is_new = topic not in self._topics
268+
self._topics.add(topic)
269+
# init to empty list if no entry
270+
callbacks = self._callbacks[topic] = self._callbacks.get(topic, [])
271+
# add callback to callbacks list of the topic
272+
callbacks.append(callback)
273+
if topic_is_new and self.is_ready():
274+
return self._rpc_channel.other.subscribe(topics=[topic])
271275
else:
272-
raise PubSubClientInvalidStateException(
273-
"Client already connected and subscribed"
274-
)
276+
# If we can't return an RPC call future then we need
277+
# to supply something else to not fail when the
278+
# calling code awaits the result of this function.
279+
future = asyncio.Future()
280+
future.set_result(None)
281+
return future
282+
283+
def unsubscribe(self, topic: Topic):
284+
"""
285+
Unsubscribe for events
286+
287+
Args:
288+
topic (Topic): the identifier of the event topic to be unsubscribed.
289+
Note: You can use ALL_TOPICS (event_notifier.ALL_TOPICS) to unsubscribe all topics
290+
291+
Returns:
292+
Coroutine: awaitable task to subscribe to topic if connected.
293+
"""
294+
# Create none-future which can be safely awaited
295+
# but which also will not give warnings
296+
# if it isn't awaited. This is returned
297+
# on code paths which do not make RPC calls.
298+
none_future = asyncio.Future()
299+
none_future.set_result(None)
300+
301+
# Topics to potentially make RPC calls about
302+
topics = list(self._topics) if topic is ALL_TOPICS else [topic]
303+
304+
# Handle ALL_TOPICS or specific topics
305+
if topic is ALL_TOPICS and not self._topics:
306+
logger.warning(f"Cannot unsubscribe 'ALL_TOPICS'. No topics are subscribed.")
307+
return none_future
308+
elif topic is not ALL_TOPICS and topic not in self._topics:
309+
logger.warning(f"Cannot unsubscribe topic '{topic}' which is not subscribed.")
310+
return none_future
311+
elif topic is ALL_TOPICS and self._topics:
312+
logger.debug(f"Unsubscribing all topics: {self._topics}")
313+
# remove all topics and callbacks
314+
self._topics.clear()
315+
self._callbacks.clear()
316+
elif topic is not ALL_TOPICS and topic in self._topics:
317+
logger.debug(f"Unsubscribing topic '{topic}'")
318+
self._topics.remove(topic)
319+
self._callbacks.pop(topic, None)
320+
321+
if self.is_ready():
322+
return self._rpc_channel.other.unsubscribe(topics=topics)
323+
else:
324+
return none_future
275325

276326
async def publish(
277327
self, topics: TopicList, data=None, sync=True, notifier_id=None

fastapi_websocket_pubsub/pub_sub_server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ async def subscribe(
9191
) -> List[Subscription]:
9292
return await self.notifier.subscribe(self._subscriber_id, topics, callback)
9393

94+
async def unsubscribe(
95+
self, topics: Union[TopicList, ALL_TOPICS]) -> List[Subscription]:
96+
return await self.notifier.unsubscribe(self._subscriber_id, topics)
97+
9498
async def publish(self, topics: Union[TopicList, Topic], data=None):
9599
"""
96100
Publish events to subscribres of given topics currently connected to the endpoint

fastapi_websocket_pubsub/rpc_event_methods.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ async def callback(subscription: Subscription, data):
4949
"Failed to subscribe to RPC events notifier", topics)
5050
return False
5151

52+
async def unsubscribe(self, topics: TopicList = []) -> bool:
53+
"""
54+
provided by the server so that the client can unsubscribe topics.
55+
"""
56+
for topic in topics.copy():
57+
if topic not in self.event_notifier._topics:
58+
self.logger.warning(f"Cannot unsubscribe topic '{topic}' which is not subscribed.")
59+
topics.remove(topic)
60+
# We'll use the remote channel id as our subscriber id
61+
sub_id = await self._get_channel_id_()
62+
await self.event_notifier.unsubscribe(sub_id, topics)
63+
return True
64+
5265
async def publish(self, topics: TopicList = [], data=None, sync=True, notifier_id=None) -> bool:
5366
"""
5467
Publish an event through the server to subscribers

tests/basic_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,47 @@ async def on_event(data, topic):
133133
assert published.result
134134
# wait for finish trigger
135135
await asyncio.wait_for(finish.wait(), 5)
136+
137+
138+
@pytest.mark.asyncio
139+
async def test_pub_sub_unsub(server):
140+
"""
141+
Check client can unsubscribe topic and subscribe again.
142+
"""
143+
# finish trigger
144+
finish = asyncio.Event()
145+
async with PubSubClient() as client:
146+
147+
async def on_event(data, topic):
148+
assert data == DATA
149+
finish.set()
150+
151+
# subscribe for the event
152+
client.subscribe(EVENT_TOPIC, on_event)
153+
# start listentining
154+
client.start_client(uri)
155+
# wait for the client to be ready to receive events
156+
await client.wait_until_ready()
157+
# trigger the server via an HTTP route
158+
requests.get(trigger_url)
159+
# wait for finish trigger
160+
await asyncio.wait_for(finish.wait(), 5)
161+
assert finish.is_set()
162+
163+
# unsubscribe and see that we don't get a message
164+
finish.clear()
165+
await client.unsubscribe(EVENT_TOPIC)
166+
requests.get(trigger_url)
167+
# wait for finish trigger which isn't coming
168+
with pytest.raises(asyncio.TimeoutError) as excinfo:
169+
await asyncio.wait_for(finish.wait(), 5)
170+
assert not finish.is_set()
171+
172+
# subscribe again and observe that we get the trigger
173+
finish.clear()
174+
await client.subscribe(EVENT_TOPIC, on_event)
175+
# trigger the server via an HTTP route
176+
requests.get(trigger_url)
177+
# wait for finish trigger
178+
await asyncio.wait_for(finish.wait(), 5)
179+
assert finish.is_set()

0 commit comments

Comments
 (0)