diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 01b34e76..92acbec4 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -891,7 +891,7 @@ def _add_metadata_and_validate( # It also just changes the state of the series, so it is not a big deal. if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known: try: - data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories) + data[c] = data[c].cat.set_categories(data[c].compute().cat.categories) except ValueError: logger.info(f"Column `{c}` contains unknown categories. Consider casting it.") diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 124933f4..1e82b698 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -830,3 +830,31 @@ def test_warning_on_large_chunks(): assert len(w) == 1, "Warning should be raised for large chunk size" assert issubclass(w[-1].category, UserWarning) assert "Detected chunks larger than:" in str(w[-1].message) + + +def test_categories_on_partitioned_dataframe(sdata_blobs: SpatialData): + df = sdata_blobs["blobs_points"].compute() + df["genes"] = RNG.choice([f"gene_{i}" for i in range(200)], len(df)) + N_PARTITIONS = 200 + ddf = dd.from_pandas(df, npartitions=N_PARTITIONS) + ddf["genes"] = ddf["genes"].astype("category") + + df["genes"] = df["genes"].astype("category") + df_parsed = PointsModel.parse(df, npartitions=N_PARTITIONS) + ddf_parsed = PointsModel.parse(ddf, npartitions=N_PARTITIONS) + + assert df["genes"].equals(df_parsed["genes"].compute()) + assert df["genes"].cat.categories.equals(df_parsed["genes"].compute().cat.categories) + + assert np.array_equal(df["genes"].to_numpy(), ddf_parsed["genes"].compute().to_numpy()) + assert set(df["genes"].cat.categories.tolist()) == set(ddf_parsed["genes"].compute().cat.categories.tolist()) + + # Two behaviors to investigate later and/or report to dask (they originate in dask): + # TODO: df['genes'].cat.categories 
has dtype 'object', while ddf_parsed['genes'].compute().cat.categories has dtype + # 'string' + # this discrepancy should disappear once pandas 3.0 is released + assert df["genes"].cat.categories.dtype == "object" + assert ddf_parsed["genes"].compute().cat.categories.dtype == "string" + + # TODO: the order of the categories is not preserved + assert df["genes"].cat.categories.tolist() != ddf_parsed["genes"].compute().cat.categories.tolist()