Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/transformations/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def get_state(self):
window_index=self._window_index,
pos_in_window=self._pos_in_window,
parent_exhausted=self._parent_exhausted,
init=self._init,
)

def set_state(self, state):
Expand All @@ -256,6 +257,7 @@ def set_state(self, state):
self._window_index = state["window_index"]
self._pos_in_window = state["pos_in_window"]
self._parent_exhausted = state["parent_exhausted"]
self._init = state.get("init", False)
self._fill_and_shuffle_window()
# Remove previously processed elements from the window.
for _ in range(min(self._pos_in_window, len(self._window))):
Expand Down
77 changes: 77 additions & 0 deletions grain/_src/python/dataset/transformations/shuffle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,83 @@ def test_element_spec(self):
self.assertEqual(spec.dtype, np.int64)
self.assertEqual(spec.shape, ())

def test_checkpoint_restore_on_fresh_iterator(self):
  """Restoring a checkpoint into a brand-new iterator resumes identically.

  A state captured from one iterator is loaded into a second, separately
  constructed iterator over an identically configured pipeline. Both iterators
  must then yield the same elements and report the same `window_index`.

  This is a regression test for the `_init` flag: if the flag is not carried
  through the checkpoint, the fresh iterator keeps its own initial value and
  the window index is expected to drift after the first refill, which changes
  the shuffle order.
  """
  window_size = 10
  num_elements = 1000
  seed = 42
  full_windows_consumed = 5
  num_to_compare = 20

  # Take the checkpoint in the middle of a window so that reading on from it
  # forces a window refill.
  checkpoint_position = full_windows_consumed * window_size + window_size // 2

  def make_iterator():
    # Each call builds an independent pipeline with the same configuration,
    # large enough to span many windows.
    source = dataset.MapDataset.range(num_elements).to_iter_dataset()
    shuffled = shuffle.WindowShuffleIterDataset(
        source, window_size=window_size, seed=seed
    )
    return iter(shuffled)

  # Uninterrupted run: advance to the checkpoint position and snapshot.
  original_it = make_iterator()
  for _ in range(checkpoint_position):
    next(original_it)
  checkpoint_state = original_it.get_state()

  # Keep reading to obtain the reference elements and the reference state.
  expected_elements = []
  for _ in range(num_to_compare):
    expected_elements.append(next(original_it))
  expected_window_index = original_it.get_state()["window_index"]

  # Restored run: a fresh iterator that never produced an element itself.
  restored_it = make_iterator()
  restored_it.set_state(checkpoint_state)

  actual_elements = []
  for _ in range(num_to_compare):
    actual_elements.append(next(restored_it))
  actual_window_index = restored_it.get_state()["window_index"]

  # The restored iterator must reproduce the original element sequence.
  self.assertEqual(
      expected_elements,
      actual_elements,
      msg=(
          "Data mismatch after checkpoint restore! This indicates the"
          " window_index bug."
      ),
  )

  # Both iterators must also agree on which window they are in.
  self.assertEqual(
      expected_window_index,
      actual_window_index,
      msg=(
          "window_index mismatch: "
          f"original={expected_window_index}, "
          f"restored={actual_window_index}. "
          "This indicates the _init flag was not reset properly."
      ),
  )


# Run the tests through absltest when this file is executed as a script.
if __name__ == "__main__":
absltest.main()