Description
I'm seeing a serious performance gap between tf.data.Dataset reading TFRecords and Grain reading ArrayRecords when simply iterating over MNIST. The snippet below demonstrates it: Grain comes out 5-10x slower.
To make the comparison fair, I've disabled caching and prefetching on the TF side. I've also tried a variety of Grain features to improve speed, including mp_prefetch, varying num_threads, and using CacheIterDataset and ThreadPrefetchIterDataset, with no meaningful improvement; one such variant is sketched below.
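For reference, one of the mp_prefetch variants I tried looked roughly like this (a sketch from memory; `num_workers=4` is just one of the values I experimented with, and I may be misremembering the exact options class):

```python
from pathlib import Path

import grain.python as grain
import tensorflow_datasets as tfds


def make_grain_mp_dataset():
    # Same pipeline as the benchmark below, plus multiprocess prefetch.
    # num_workers=4 is illustrative; I varied it with no meaningful improvement.
    source = tfds.data_source(
        "mnist", split="train", data_dir=Path.home() / "datasets/array_record"
    )
    ds = grain.MapDataset.source(source)
    ds = ds.repeat(10)  # n_epoch
    ds = ds.batch(32)  # batch_size
    ds = ds.to_iter_dataset(grain.ReadOptions(num_threads=8))
    return ds.mp_prefetch(grain.MultiprocessingOptions(num_workers=4))
```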
FYI, I've also found one TF option that slows TF down to Grain-level performance: num_parallel_calls_for_decode in tfds.ReadConfig.
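E.g., swapping this ReadConfig into make_tf_dataset() below (the value 1 forces serial decode and is illustrative, not necessarily the exact value I used):

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Passing num_parallel_calls_for_decode slows TF down to Grain-level speed.
read_config = tfds.ReadConfig(
    options=tf.data.Options(),
    try_autocache=False,
    skip_prefetch=True,
    num_parallel_calls_for_decode=1,  # illustrative: serial decode
)
```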
Here's the full benchmark:

```python
import time
from pathlib import Path

import grain.python as grain
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm

# Hide the GPU so timings only measure the input pipeline.
tf.config.set_visible_devices([], device_type="GPU")

n_epoch = 10
batch_size = 32


def make_tf_dataset():
    ds = tfds.load(
        "mnist",
        split="train",
        data_dir=Path.home() / "datasets/tfrecord",
        # Autocaching and prefetching disabled to keep the comparison fair.
        read_config=tfds.ReadConfig(
            options=tf.data.Options(), try_autocache=False, skip_prefetch=True
        ),
    )
    ds = ds.repeat(n_epoch)
    ds = ds.batch(batch_size)
    ds = tfds.as_numpy(ds)
    return ds


def make_grain_dataset():
    source = tfds.data_source(
        "mnist", split="train", data_dir=Path.home() / "datasets/array_record"
    )
    ds = grain.MapDataset.source(source)
    ds = ds.repeat(n_epoch)
    ds = ds.batch(batch_size)
    ds = ds.to_iter_dataset(grain.ReadOptions(num_threads=8))
    return ds


def profile(ds):
    t0 = time.perf_counter()
    for _ in tqdm.tqdm(iter(ds)):
        pass
    t1 = time.perf_counter()
    return t1 - t0


if __name__ == "__main__":
    t_tf = profile(make_tf_dataset())
    print(f"TF: {t_tf:.3f} s")
    t_grain = profile(make_grain_dataset())
print(f"Grain: {t_grain:.3f} s")I keep trying to make Grain work for me every few months (since #569), but with no avail, and it seems like other performance issues are going largely unresolved. I love JAX and want to switch to Grain, but trying to get it to work is making me