import grain.python as grain
from grain._src.python.dataset.transformations import shuffle as grain_shuffle
# Checkpoint/restore check for WindowShuffleIterDataset.
#
# Build a windowed-shuffle pipeline over RangeDataSource(start=0, stop=10, step=1).
# Capture an iterator's state before any element has been consumed, then drain it.
# Restore that state into a fresh iterator over the same dataset and drain it too.
# Both runs must produce the identical element sequence.
ds = grain.MapDataset.source(grain.RangeDataSource(start=0, stop=10, step=1))
# WindowShuffleIterDataset wraps an IterDataset, so convert from MapDataset first.
ds = ds.to_iter_dataset()
ds = grain_shuffle.WindowShuffleIterDataset(ds, window_size=5, seed=42)

it1 = iter(ds)
# State is captured at the very start of iteration (nothing consumed yet).
checkpoint = it1.get_state()
original = list(it1)

# A fresh iterator restored from the checkpoint should replay the same sequence.
it2 = iter(ds)
it2.set_state(checkpoint)
restored = list(it2)

# Use an explicit raise rather than `assert`: `assert` statements are removed
# when Python runs with -O, which would silently disable this check. The
# exception type and payload match what the original `assert` produced.
if original != restored:
    raise AssertionError((original, restored))