Performance gap with tf.data.Dataset on MNIST #1164

@rademacher-p

I'm seeing a serious performance gap between tf.data.Dataset reading TFRecord and Grain reading ArrayRecord. I'm simply trying to iterate over MNIST at comparable speed. In the snippet below, Grain seems to be 5-10x slower.

To make the comparison fair, I've disabled caching and prefetching for TF. I've also tried a variety of Grain features to improve speed, including mp_prefetch, varying num_threads, CacheIterDataset, ThreadPrefetchIterDataset, etc.

FYI, I've also found one TF option that slows TF down to Grain-level performance: num_parallel_calls_for_decode in tfds.ReadConfig.

import time
from pathlib import Path

import grain.python as grain
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm

# Hide GPUs from TensorFlow so the input pipeline runs on CPU only.
tf.config.set_visible_devices([], device_type="GPU")


n_epoch = 10
batch_size = 32


def make_tf_dataset():
    ds = tfds.load(
        "mnist",
        split="train",
        data_dir=Path.home() / "datasets/tfrecord",
        read_config=tfds.ReadConfig(
            options=tf.data.Options(), try_autocache=False, skip_prefetch=True
        ),
    )

    ds = ds.repeat(n_epoch)
    ds = ds.batch(batch_size)
    ds = tfds.as_numpy(ds)

    return ds


def make_grain_dataset():
    source = tfds.data_source(
        "mnist", split="train", data_dir=Path.home() / "datasets/array_record"
    )
    ds = grain.MapDataset.source(source)

    ds = ds.repeat(n_epoch)
    ds = ds.batch(batch_size)
    ds = ds.to_iter_dataset(grain.ReadOptions(num_threads=8))

    return ds


def profile(ds):
    t0 = time.perf_counter()
    for _ in tqdm.tqdm(iter(ds)):
        pass
    t1 = time.perf_counter()
    return t1 - t0


if __name__ == "__main__":
    t_tf = profile(make_tf_dataset())
    print(f"TF: {t_tf:.3f} s")

    t_grain = profile(make_grain_dataset())
    print(f"Grain: {t_grain:.3f} s")
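
A single wall-clock number mixes one-off startup cost with steady-state iteration speed. As a rough sketch only, the helper below times the first few batches separately from the rest. The names `profile_throughput`, `warmup`, and `max_batches` are hypothetical and not part of the snippet above, and whether startup cost matters for this gap is an assumption, not something measured here. It uses only the standard library.

```python
import itertools
import time
from typing import Iterable, Optional


def profile_throughput(
    ds: Iterable, warmup: int = 10, max_batches: Optional[int] = None
) -> dict:
    """Time iteration over `ds`, reporting warm-up and steady state separately.

    Hypothetical helper: `warmup` is the number of initial batches timed on
    their own; `max_batches` optionally caps the steady-state batches.
    """
    it = iter(ds)

    # Warm-up phase: the first batches may include one-off costs such as
    # opening files or starting worker threads.
    t0 = time.perf_counter()
    n_warm = sum(1 for _ in itertools.islice(it, warmup))
    t1 = time.perf_counter()

    # Steady-state phase: consume the rest (or at most `max_batches`).
    n_steady = sum(1 for _ in itertools.islice(it, max_batches))
    t2 = time.perf_counter()

    steady_s = t2 - t1
    return {
        "warmup_batches": n_warm,
        "warmup_s": t1 - t0,
        "steady_batches": n_steady,
        "steady_s": steady_s,
        "batches_per_s": n_steady / steady_s if steady_s > 0 else float("nan"),
    }
```

Calling `profile_throughput(make_tf_dataset())` and `profile_throughput(make_grain_dataset())` would then give batches per second for each pipeline, which is easier to compare across runs than total seconds.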

I've been trying to make Grain work for me every few months (since #569), but to no avail, and it seems like other performance issues are going largely unresolved. I love JAX and want to switch to Grain, but trying to get it to work is making me ☹️
