From 74cc2cbce1352644db4b8107e9024df2c3d27fc9 Mon Sep 17 00:00:00 2001 From: salex Date: Tue, 30 Dec 2025 11:51:45 -0800 Subject: [PATCH] Fix WindowShuffleDatasetIterator checkpoint restoration bug This fixes a bug in _WindowShuffleDatasetIterator where the _init flag was not included in get_state()/set_state(), causing incorrect behavior when restoring checkpoints to a fresh iterator. The bug manifested as: - When creating a fresh iterator and restoring a checkpoint to it, the _init flag would remain True (from initialization) - On the next window fill after restoration, _maybe_update_window_index() would see _init=True and not increment window_index - This caused the same window_index to be used twice, leading to incorrect shuffle seeds and data mismatch The fix: - Include _init in the state dict returned by get_state() - Restore _init in set_state() with backwards compatibility for old checkpoints (defaults to False if not present) Added test test_checkpoint_restore_on_fresh_iterator that: - Creates an iterator and checkpoints partway through a window - Restores the checkpoint to a fresh iterator (not the same instance) - Verifies data and window_index match between original and restored runs --- .../python/dataset/transformations/shuffle.py | 2 + .../dataset/transformations/shuffle_test.py | 77 +++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/grain/_src/python/dataset/transformations/shuffle.py b/grain/_src/python/dataset/transformations/shuffle.py index 3d78dbc0..f70d18ee 100644 --- a/grain/_src/python/dataset/transformations/shuffle.py +++ b/grain/_src/python/dataset/transformations/shuffle.py @@ -248,6 +248,7 @@ def get_state(self): window_index=self._window_index, pos_in_window=self._pos_in_window, parent_exhausted=self._parent_exhausted, + init=self._init, ) def set_state(self, state): @@ -256,6 +257,7 @@ def set_state(self, state): self._window_index = state["window_index"] self._pos_in_window = state["pos_in_window"] self._parent_exhausted = state["parent_exhausted"] + self._init = state.get("init", False) self._fill_and_shuffle_window() # Removed previously processed elements from the window. for _ in range(min(self._pos_in_window, len(self._window))): diff --git a/grain/_src/python/dataset/transformations/shuffle_test.py b/grain/_src/python/dataset/transformations/shuffle_test.py index 54a78e78..8e7441ab 100644 --- a/grain/_src/python/dataset/transformations/shuffle_test.py +++ b/grain/_src/python/dataset/transformations/shuffle_test.py @@ -246,6 +246,83 @@ def test_element_spec(self): self.assertEqual(spec.dtype, np.int64) self.assertEqual(spec.shape, ()) + def test_checkpoint_restore_on_fresh_iterator(self): + """Test that checkpoint restore works correctly on a fresh iterator. + + This test verifies that when we restore a checkpoint to a fresh iterator + (not the same instance that created the checkpoint), the _init flag is + properly handled. Without the fix, the fresh iterator has _init=True, which + causes the window_index to not increment on the first window fill after + restoration, leading to incorrect shuffling and state mismatch. + """ + window_size = 10 + num_elements = 1000 + seed = 42 + num_full_windows_before_checkpoint = 5 + num_elements_to_verify = 20 + + # Checkpoint position: halfway through a window to trigger the bug + # (need to be partway through a window so next() triggers window refill) + checkpoint_position = ( + num_full_windows_before_checkpoint * window_size + window_size // 2 + ) + + # Create dataset with enough elements to span multiple windows + ds = dataset.MapDataset.range(num_elements).to_iter_dataset() + ds = shuffle.WindowShuffleIterDataset( + ds, window_size=window_size, seed=seed + ) + + # Original continuous run: consume to checkpoint position + it1 = iter(ds) + for _ in range(checkpoint_position): + next(it1) + checkpoint_state = it1.get_state() + + # Continue and record data for verification + elements_after_checkpoint_original = [ + next(it1) for _ in range(num_elements_to_verify) + ] + state_after_verification_original = it1.get_state() + + # Now simulate checkpoint restore from a fresh iterator + ds2 = dataset.MapDataset.range(num_elements).to_iter_dataset() + ds2 = shuffle.WindowShuffleIterDataset( + ds2, window_size=window_size, seed=seed + ) + it2 = iter(ds2) + + # Restore state at checkpoint position + it2.set_state(checkpoint_state) + + # Continue from checkpoint and verify data matches + elements_after_checkpoint_restored = [ + next(it2) for _ in range(num_elements_to_verify) + ] + state_after_verification_restored = it2.get_state() + + # Verify data matches + self.assertEqual( + elements_after_checkpoint_original, + elements_after_checkpoint_restored, + msg=( + "Data mismatch after checkpoint restore! This indicates the" + " window_index bug." + ), + ) + + # Verify window_index matches + self.assertEqual( + state_after_verification_original["window_index"], + state_after_verification_restored["window_index"], + msg=( + "window_index mismatch: " + f"original={state_after_verification_original['window_index']}, " + f"restored={state_after_verification_restored['window_index']}. " + "This indicates the _init flag was not reset properly." + ), + ) + if __name__ == "__main__": absltest.main()