Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def _add_metadata_and_validate(
# It also just changes the state of the series, so it is not a big deal.
if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known:
try:
data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
data[c] = data[c].cat.set_categories(data[c].compute().cat.categories)
except ValueError:
logger.info(f"Column `{c}` contains unknown categories. Consider casting it.")

Expand Down
28 changes: 28 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,3 +830,31 @@ def test_warning_on_large_chunks():
assert len(w) == 1, "Warning should be raised for large chunk size"
assert issubclass(w[-1].category, UserWarning)
assert "Detected chunks larger than:" in str(w[-1].message)


def test_categories_on_partitioned_dataframe(sdata_blobs: SpatialData):
df = sdata_blobs["blobs_points"].compute()
df["genes"] = RNG.choice([f"gene_{i}" for i in range(200)], len(df))
N_PARTITIONS = 200
ddf = dd.from_pandas(df, npartitions=N_PARTITIONS)
ddf["genes"] = ddf["genes"].astype("category")

df["genes"] = df["genes"].astype("category")
df_parsed = PointsModel.parse(df, npartitions=N_PARTITIONS)
ddf_parsed = PointsModel.parse(ddf, npartitions=N_PARTITIONS)

assert df["genes"].equals(df_parsed["genes"].compute())
assert df["genes"].cat.categories.equals(df_parsed["genes"].compute().cat.categories)

assert np.array_equal(df["genes"].to_numpy(), ddf_parsed["genes"].compute().to_numpy())
assert set(df["genes"].cat.categories.tolist()) == set(ddf_parsed["genes"].compute().cat.categories.tolist())

# two behavior to investigate later/report to dask (they originate in dask)
# TODO: df['genes'].cat.categories has dtype 'object', while ddf_parsed['genes'].compute().cat.categories has dtype
# 'string'
# this problem should disappear after pandas 3.0 is released
assert df["genes"].cat.categories.dtype == "object"
assert ddf_parsed["genes"].compute().cat.categories.dtype == "string"

# TODO: the list of categories are not preserving the order
assert df["genes"].cat.categories.tolist() != ddf_parsed["genes"].compute().cat.categories.tolist()
Loading