@@ -150,14 +150,29 @@ extension RedisPubSubHandler {
150150
151151 private func handleUnsubscribeMessage(
152152 withSubscriptionKey subscriptionKey: String ,
153- reportedSubscriptionCount subscriptionCount: Int
153+ reportedSubscriptionCount subscriptionCount: Int ,
154+ unsubscribeFromAllKey: String
154155 ) {
155- defer { self . pendingUnsubscribes. removeValue ( forKey: subscriptionKey) ? . succeed ( subscriptionCount) }
156-
157156 guard let subscription = self . subscriptions. removeValue ( forKey: subscriptionKey) else { return }
158157
159158 subscription. onUnsubscribe ? ( subscriptionKey, subscriptionCount)
160159 subscription. type. gauge. decrement ( )
160+
161+ switch self . pendingUnsubscribes. removeValue ( forKey: subscriptionKey) {
162+ // we found a specific pattern/channel was being removed, so just fulfill the notification
163+ case let . some( promise) :
164+ promise. succeed ( subscriptionCount)
165+
166+ // if one wasn't found, this means a [p]unsubscribe all was issued
167+ case . none:
168+ // and we want to wait for the subscription count to be 0 before we resolve it's notification
169+ // this count may be from what Redis reports, or the count of subscriptions for this particular type
170+ guard
171+ subscriptionCount == 0 || self . subscriptions. count ( where: { $0. type == subscription. type } ) == 0
172+ else { return }
173+ // always report back the count according to Redis, it is the source of truth
174+ self . pendingUnsubscribes. removeValue ( forKey: unsubscribeFromAllKey) ? . succeed ( subscriptionCount)
175+ }
161176 }
162177
163178 private func handleMessage(
@@ -249,6 +264,12 @@ extension RedisPubSubHandler {
249264
250265 // we send the UNSUBSCRIBE message to Redis,
251266 // and in the response we handle the actual removal of the receiver closure
267+
268+ // if there are no channels / patterns specified,
269+ // then this is a special case of unsubscribing from all patterns / channels
270+ guard !target. values. isEmpty else {
271+ return self . unsubscribeAll ( for: target)
272+ }
252273
253274 return self . sendSubscriptionChange (
254275 subscriptionChangeKeyword: target. unsubscribeKeyword,
@@ -302,9 +323,21 @@ extension RedisPubSubHandler {
302323 return latestSubscriptionCount
303324 }
304325
305- return self . context. writeAndFlush ( self . wrapOutboundOut ( . array( command) ) )
326+ return self . context
327+ . writeAndFlush ( self . wrapOutboundOut ( . array( command) ) )
306328 . flatMap { return subscriptionCountFuture }
307329 }
330+
331+ private func unsubscribeAll( for target: RedisSubscriptionTarget ) -> EventLoopFuture < Int > {
332+ let command = [ RESPValue ( bulk: target. unsubscribeKeyword) ]
333+
334+ let promise = self . context. eventLoop. makePromise ( of: Int . self)
335+ self . pendingUnsubscribes. updateValue ( promise, forKey: target. unsubscribeAllKey)
336+
337+ return self . context
338+ . writeAndFlush ( self . wrapOutboundOut ( . array( command) ) )
339+ . flatMap { promise. futureResult }
340+ }
308341}
309342
310343// MARK: ChannelHandler
@@ -376,8 +409,19 @@ extension RedisPubSubHandler: ChannelInboundHandler {
376409 case " subscribe " , " psubscribe " :
377410 self . handleSubscribeMessage ( withSubscriptionKey: channelOrPattern, reportedSubscriptionCount: message. int!)
378411
379- case " unsubscribe " , " punsubscribe " :
380- self . handleUnsubscribeMessage ( withSubscriptionKey: channelOrPattern, reportedSubscriptionCount: message. int!)
412+ case " unsubscribe " :
413+ self . handleUnsubscribeMessage (
414+ withSubscriptionKey: channelOrPattern,
415+ reportedSubscriptionCount: message. int!,
416+ unsubscribeFromAllKey: kUnsubscribeAllChannelsKey
417+ )
418+
419+ case " punsubscribe " :
420+ self . handleUnsubscribeMessage (
421+ withSubscriptionKey: channelOrPattern,
422+ reportedSubscriptionCount: message. int!,
423+ unsubscribeFromAllKey: kUnsubscribeAllPatternsKey
424+ )
381425
382426 // if we don't have a match, fire a channel read to forward to the next handler
383427 default : context. fireChannelRead ( data)
@@ -419,6 +463,10 @@ extension RedisPubSubHandler: ChannelOutboundHandler {
419463
420464// MARK: Private Types
421465
466+ // keys used for the pendingUnsubscribes
467+ private let kUnsubscribeAllChannelsKey = " __RS_ALL_CHS "
468+ private let kUnsubscribeAllPatternsKey = " __RS_ALL_PNS "
469+
422470fileprivate enum SubscriptionType {
423471 case channel, pattern
424472
@@ -433,7 +481,7 @@ fileprivate enum SubscriptionType {
433481extension RedisPubSubHandler {
434482 private typealias PendingSubscriptionChangeQueue = [ String : EventLoopPromise < Int > ]
435483
436- private final class Subscription {
484+ fileprivate final class Subscription {
437485 let type : SubscriptionType
438486 let onMessage : RedisSubscriptionMessageReceiver
439487 var onSubscribe : RedisSubscriptionChangeHandler ? // will be set to nil after first call
@@ -460,6 +508,13 @@ extension RedisPubSubHandler {
460508// MARK: Subscription Management Helpers
461509
462510extension RedisSubscriptionTarget {
511+ fileprivate var unsubscribeAllKey : String {
512+ switch self {
513+ case . channels: return kUnsubscribeAllChannelsKey
514+ case . patterns: return kUnsubscribeAllPatternsKey
515+ }
516+ }
517+
463518 fileprivate var subscriptionType : SubscriptionType {
464519 switch self {
465520 case . channels: return . channel
@@ -480,3 +535,12 @@ extension RedisSubscriptionTarget {
480535 }
481536 }
482537}
538+
539+ extension Dictionary where Key == String , Value == RedisPubSubHandler . Subscription {
540+ func count( where isIncluded: ( Value ) -> Bool ) -> Int {
541+ self . reduce ( into: 0 ) {
542+ guard isIncluded ( $1. value) else { return }
543+ $0 += 1
544+ }
545+ }
546+ }
0 commit comments