From efb204103a9db463179205580f9a258156a3a27e Mon Sep 17 00:00:00 2001 From: Grain Team Date: Fri, 21 Nov 2025 14:08:55 -0800 Subject: [PATCH] Internal PiperOrigin-RevId: 835353176 --- grain/_src/python/dataset/dataset.py | 8 +++-- grain/_src/python/shared_memory_array_test.py | 30 ++++++++++++++----- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index a0de96248..137ef7eef 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -50,6 +50,7 @@ from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Sequence import functools import json +import threading from typing import Any, Generic, TypeVar, Union, cast, overload import warnings @@ -67,6 +68,7 @@ from grain._src.core import monitoring +_usage_logging_lock = threading.Lock() _api_usage_counter = monitoring.Counter( "/grain/python/lazy_dataset/api", metadata=monitoring.Metadata( @@ -358,7 +360,8 @@ def __init__(self, parents: MapDataset | Sequence[MapDataset] = ()): parents = tuple(parents) super().__init__(parents) self._parents = cast(Sequence[MapDataset], self._parents) - usage_logging.log_event("MapDataset", tag_3="PyGrain") + with _usage_logging_lock: + usage_logging.log_event("MapDataset", tag_3="PyGrain") _api_usage_counter.Increment("MapDataset") @property @@ -977,7 +980,8 @@ def __init__( self._parents = cast( Sequence[Union[MapDataset, IterDataset]], self._parents ) - usage_logging.log_event("IterDataset", tag_3="PyGrain") + with _usage_logging_lock: + usage_logging.log_event("IterDataset", tag_3="PyGrain") _api_usage_counter.Increment("IterDataset") @property diff --git a/grain/_src/python/shared_memory_array_test.py b/grain/_src/python/shared_memory_array_test.py index db48afe12..a38508fe7 100644 --- a/grain/_src/python/shared_memory_array_test.py +++ b/grain/_src/python/shared_memory_array_test.py @@ -153,21 +153,36 @@ def test_del_many_async_reuse_pool(self): ) original_close_shm_async = SharedMemoryArray.close_shm_async + # Use a semaphore to track completed async deletions. + completed_sem = threading.Semaphore(0) + # Use a thread-safe counter because mock.call_count is not thread-safe in + # free-threaded Python. + call_count = 0 + count_lock = threading.Lock() + def my_close_shm_async(shm, unlink_on_del): original_close_shm_async(shm, unlink_on_del) + with count_lock: + nonlocal call_count + call_count += 1 + completed_sem.release() with mock.patch.object( SharedMemoryArray, "close_shm_async", side_effect=my_close_shm_async - ) as mock_close_shm_async: + ): with self.subTest("first_round_of_requests"): shm_metadatas = [ _create_and_delete_shm() for _ in range(max_outstanding_requests) ] for metadata in shm_metadatas: _wait_for_deletion(metadata) - self.assertEqual( - max_outstanding_requests, mock_close_shm_async.call_count - ) + # Wait for all async deletions to complete to ensure the semaphore in + # SharedMemoryArray is released. + for _ in range(max_outstanding_requests): + completed_sem.acquire() + with count_lock: + self.assertEqual(max_outstanding_requests, call_count) + with self.subTest("second_round_of_requests"): # Do it again to make sure the pool is reused. shm_metadatas = [ @@ -175,9 +190,10 @@ def my_close_shm_async(shm, unlink_on_del): ] for metadata in shm_metadatas: _wait_for_deletion(metadata) - self.assertEqual( - 2 * max_outstanding_requests, mock_close_shm_async.call_count - ) + for _ in range(max_outstanding_requests): + completed_sem.acquire() + with count_lock: + self.assertEqual(2 * max_outstanding_requests, call_count) if __name__ == "__main__":