grain/_src/python/dataset/transformations/prefetch.py (3 additions, 0 deletions)

@@ -534,6 +534,7 @@ def __call__(
worker_count: int,
start_profiling_event: synchronize.Event | None = None,
stop_profiling_event: synchronize.Event | None = None,
+profiling_timeout: Any | None = None,
stats_out_queue: queues.Queue | None = None,
) -> Iterator[tuple[T, Optional[dict[str, Any]]]]:
if worker_count > 1:
@@ -644,6 +645,7 @@ def __init__(
)
self._start_profiling_event = mp.get_context("spawn").Event()
self._stop_profiling_event = mp.get_context("spawn").Event()
+self._profiling_timeout = mp.get_context("spawn").Value("i", -1)

self._state: dict[str, dict[str, Any] | int] = {
_WORKERS_STATE: workers_state,
@@ -751,6 +753,7 @@ def _create_iterator_context(self) -> grain_pool.MultiProcessIterator[T]:
self._worker_init_fn,
self._start_profiling_event,
self._stop_profiling_event,
+self._profiling_timeout,
self._stats_in_queues,
)
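
The new plumbing follows the same pattern as the existing start/stop profiling events: a synchronized multiprocessing object is created with the "spawn" context in the parent process (here, mp.get_context("spawn").Value("i", -1)) and handed down to every worker, so the parent can publish a timeout that workers read without extra IPC. Below is a minimal, self-contained sketch of that shared-value pattern; it is not Grain code, and the worker function, the value 30, and reading -1 as "unset" are illustrative assumptions only.

# Sketch of the shared-value pattern relied on above: the parent creates a
# ctypes-backed Value with the "spawn" context, passes it to a worker process,
# and updates it; the worker reads it directly. Names here are illustrative.
import multiprocessing as mp


def worker(profiling_timeout, start_event):
  start_event.wait()  # wait until the parent has published a value
  with profiling_timeout.get_lock():  # guard the read against concurrent writes
    timeout = profiling_timeout.value  # -1 is the "nothing set yet" sentinel
  print(f"worker sees profiling timeout = {timeout}")


if __name__ == "__main__":
  ctx = mp.get_context("spawn")
  profiling_timeout = ctx.Value("i", -1)  # "i": signed int, initialized to -1
  start_event = ctx.Event()
  p = ctx.Process(target=worker, args=(profiling_timeout, start_event))
  p.start()
  with profiling_timeout.get_lock():
    profiling_timeout.value = 30  # parent publishes a timeout after spawning
  start_event.set()
  p.join()

A Value (as opposed to another queue) keeps the timeout cheap to read and lets the parent change it after the worker processes have already been spawned.
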
grain/_src/python/grain_pool.py (14 additions, 0 deletions)

@@ -137,6 +137,7 @@ def __call__(
worker_count: int,
start_profiling_event: synchronize.Event | None = None,
stop_profiling_event: synchronize.Event | None = None,
+profiling_timeout: Any | None = None,
stats_out_queue: queues.Queue | None = None,
) -> Iterator[T]:
"""Returns a generator of elements."""
@@ -188,6 +189,7 @@ def _initialize_and_get_element_producer(
worker_count: int,
start_profiling_event: synchronize.Event,
stop_profiling_event: synchronize.Event,
+profiling_timeout: Any,
stats_out_queue: queues.Queue,
) -> Iterator[Any]:
"""Unpickles the element producer from the args queue and closes the queue."""
@@ -214,6 +216,7 @@
worker_count=worker_count,
start_profiling_event=start_profiling_event,
stop_profiling_event=stop_profiling_event,
+profiling_timeout=profiling_timeout,
stats_out_queue=stats_out_queue,
)
# args_queue has only a single argument and thus can be safely closed.
@@ -229,6 +232,7 @@ def _worker_loop(
termination_event: synchronize.Event,
start_profiling_event: synchronize.Event,
stop_profiling_event: synchronize.Event,
+profiling_timeout: Any,
worker_index: int,
worker_count: int,
enable_profiling: bool,
@@ -250,6 +254,7 @@
worker_count=worker_count,
start_profiling_event=start_profiling_event,
stop_profiling_event=stop_profiling_event,
+profiling_timeout=profiling_timeout,
stats_out_queue=stats_out_queue,
)
profiling_enabled = enable_profiling and worker_index == 0
@@ -339,6 +344,7 @@ def __init__(
termination_event: threading.Event | None = None,
start_profiling_event: synchronize.Event | None = None,
stop_profiling_event: synchronize.Event | None = None,
+profiling_timeout: Any | None = None,
options: MultiprocessingOptions,
worker_init_fn: Callable[[int, int], None] | None = None,
stats_in_queues: tuple[queues.Queue, ...] | None = None,
@@ -356,6 +362,7 @@
all workers are done processing data. GrainPool will not set this event.
start_profiling_event: Event to start prism profiling.
stop_profiling_event: Event to stop prism profiling.
+profiling_timeout: Shared value for profiling timeout.
options: Options for multiprocessing. See MultiprocessingOptions.
worker_init_fn: Function to run in each worker process before the element
producer. The function takes two arguments: the current worker index and
@@ -409,6 +416,7 @@ def __init__(
termination_event=self._workers_termination_event,
start_profiling_event=start_profiling_event,
stop_profiling_event=stop_profiling_event,
+profiling_timeout=profiling_timeout,
worker_index=worker_index,
worker_count=options.num_workers,
enable_profiling=options.enable_profiling,
@@ -614,6 +622,7 @@ def _process_elements_in_grain_pool(
termination_event: threading.Event,
start_profiling_event: synchronize.Event | None,
stop_profiling_event: synchronize.Event | None,
+profiling_timeout: Any | None,
worker_index_to_start_reading: int,
worker_init_fn: Callable[[int, int], None] | None,
stats_in_queues: tuple[queues.Queue, ...] | None,
@@ -633,6 +642,7 @@ def read_thread_should_stop():
termination_event=termination_event,
start_profiling_event=start_profiling_event,
stop_profiling_event=stop_profiling_event,
+profiling_timeout=profiling_timeout,
options=multiprocessing_options,
worker_init_fn=worker_init_fn,
stats_in_queues=stats_in_queues,
@@ -691,6 +701,7 @@ def __init__(
worker_init_fn: Callable[[int, int], None] | None = None,
start_profiling_event: synchronize.Event | None = None,
stop_profiling_event: synchronize.Event | None = None,
+profiling_timeout: Any | None = None,
stats_in_queues: tuple[queues.Queue, ...] | None = None,
):
"""Initializes MultiProcessIterator.
@@ -706,6 +717,7 @@
the total worker count.
start_profiling_event: Event to start prism profiling.
stop_profiling_event: Event to stop prism profiling.
+profiling_timeout: Shared value for profiling timeout.
stats_in_queues: Queues to send execution summaries from worker processes
to the main process.
"""
@@ -720,6 +732,7 @@
self._stats_in_queues = stats_in_queues
self._start_profiling_event = start_profiling_event
self._stop_profiling_event = stop_profiling_event
+self._profiling_timeout = profiling_timeout

def __del__(self):
if self._reader_thread:
@@ -749,6 +762,7 @@ def start_prefetch(self) -> None:
termination_event=self._termination_event,
start_profiling_event=self._start_profiling_event,
stop_profiling_event=self._stop_profiling_event,
+profiling_timeout=self._profiling_timeout,
worker_index_to_start_reading=self._last_worker_index + 1,
worker_init_fn=self._worker_init_fn,
stats_in_queues=self._stats_in_queues,
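
The diff above only threads profiling_timeout from the prefetch iterator down to _worker_loop; how a worker consumes it is not shown. Purely as a hypothetical sketch, one way a worker-side profiler could combine the shared value with the start/stop events is below, using cProfile as a stand-in for the prism profiler mentioned in the docstrings and assuming the value is a wait time in seconds, with -1 meaning "wait for the stop event indefinitely".

# Hypothetical worker-side use of the plumbed-through objects; not Grain code.
import cProfile


def profile_until_stopped(start_profiling_event, stop_profiling_event,
                          profiling_timeout):
  start_profiling_event.wait()  # block until the main process requests profiling
  timeout = profiling_timeout.value if profiling_timeout is not None else -1
  profiler = cProfile.Profile()
  profiler.enable()
  # A non-positive timeout falls back to waiting for the stop event forever.
  stop_profiling_event.wait(timeout=timeout if timeout > 0 else None)
  profiler.disable()
  profiler.dump_stats("/tmp/worker_profile.prof")  # illustrative output path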