diff --git a/livekit-rtc/livekit/rtc/audio_mixer.py b/livekit-rtc/livekit/rtc/audio_mixer.py index e2f28c6b..beab6040 100644 --- a/livekit-rtc/livekit/rtc/audio_mixer.py +++ b/livekit-rtc/livekit/rtc/audio_mixer.py @@ -50,7 +50,7 @@ def __init__( capacity (int, optional): The maximum number of mixed frames to store in the output queue. Defaults to 100. """ - self._streams: set[_Stream] = set() + self._streams: dict[_Stream, asyncio.Lock] = {} self._buffers: dict[_Stream, np.ndarray] = {} self._sample_rate: int = sample_rate self._num_channels: int = num_channels @@ -62,7 +62,7 @@ def __init__( self._ending: bool = False self._mixer_task: asyncio.Task = asyncio.create_task(self._mixer()) - def add_stream(self, stream: AsyncIterator[AudioFrame]) -> None: + def add_stream(self, stream: AsyncIterator[AudioFrame]) -> asyncio.Lock: """ Add an audio stream to the mixer. @@ -71,13 +71,17 @@ def add_stream(self, stream: AsyncIterator[AudioFrame]) -> None: Args: stream (AsyncIterator[AudioFrame]): An async iterator that produces AudioFrame objects. + + Returns: + asyncio.Lock: A lock that can be used to synchronize access to the stream. """ if self._ending: raise RuntimeError("Cannot add stream after mixer has been closed") - self._streams.add(stream) + self._streams[stream] = asyncio.Lock() if stream not in self._buffers: self._buffers[stream] = np.empty((0, self._num_channels), dtype=np.int16) + return self._streams[stream] def remove_stream(self, stream: AsyncIterator[AudioFrame]) -> None: """ @@ -88,7 +92,7 @@ def remove_stream(self, stream: AsyncIterator[AudioFrame]) -> None: Args: stream (AsyncIterator[AudioFrame]): The audio stream to remove. """ - self._streams.discard(stream) + self._streams.pop(stream, None) self._buffers.pop(stream, None) def __aiter__(self) -> "AudioMixer": @@ -133,9 +137,10 @@ async def _mixer(self) -> None: tasks = [ self._get_contribution( stream, + lock, self._buffers.get(stream, np.empty((0, self._num_channels), dtype=np.int16)), ) - for stream in list(self._streams) + for stream, lock in self._streams.items() ] results = await asyncio.gather(*tasks, return_exceptions=True) contributions = [] @@ -169,15 +174,18 @@ async def _mixer(self) -> None: await self._queue.put(None) async def _get_contribution( - self, stream: AsyncIterator[AudioFrame], buf: np.ndarray + self, stream: AsyncIterator[AudioFrame], lock: asyncio.Lock, buf: np.ndarray ) -> _Contribution: had_data = buf.shape[0] > 0 exhausted = False + + async def _get_frame() -> AudioFrame: + async with lock: + return await stream.__anext__() + while buf.shape[0] < self._chunk_size and not exhausted: try: - frame = await asyncio.wait_for( - stream.__anext__(), timeout=self._stream_timeout_ms / 1000 - ) + frame = await asyncio.wait_for(_get_frame(), timeout=self._stream_timeout_ms / 1000) except asyncio.TimeoutError: logger.warning(f"AudioMixer: stream {stream} timeout, ignoring") break