Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Operations on `SpatialData` objects.
.. autofunction:: match_element_to_table
.. autofunction:: match_table_to_element
.. autofunction:: match_sdata_to_table
.. autofunction:: filter_by_table_query
.. autofunction:: concatenate
.. autofunction:: transform
.. autofunction:: rasterize
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
"datatree": ("https://datatree.readthedocs.io/en/latest/", None),
"dask": ("https://docs.dask.org/en/latest/", None),
"shapely": ("https://shapely.readthedocs.io/en/stable", None),
"annsel": ("https://annsel.readthedocs.io/en/latest/", None),
}


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ license = {file = "LICENSE"}
readme = "README.md"
dependencies = [
"anndata>=0.9.1",
"annsel>=0.1.2",
"click",
"dask-image",
"dask>=2025.2.0",
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"match_element_to_table",
"match_table_to_element",
"match_sdata_to_table",
"filter_by_table_query",
"SpatialData",
"get_extent",
"get_centroids",
Expand Down Expand Up @@ -57,6 +58,7 @@
from spatialdata._core.operations.vectorize import to_circles, to_polygons
from spatialdata._core.query._utils import get_bounding_box_corners
from spatialdata._core.query.relational_query import (
filter_by_table_query,
get_element_annotators,
get_element_instances,
get_values,
Expand Down
87 changes: 87 additions & 0 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from annsel.core.typing import Predicates
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from xarray import DataArray, DataTree
Expand Down Expand Up @@ -650,6 +651,11 @@ def join_spatialelement_table(
ValueError
If an incorrect value is given for `match_rows`.

Notes
-----
For a graphical representation of the join operations, see the
`Tables tutorial <https://spatialdata.scverse.org/en/stable/tutorials/notebooks/notebooks/examples/tables.html>`_.

See Also
--------
match_element_to_table : Function to match elements to a table.
Expand Down Expand Up @@ -733,6 +739,11 @@ def match_table_to_element(sdata: SpatialData, element_name: str, table_name: st
-------
Table with the rows matching the instances of the element

Notes
-----
For a graphical representation of the join operations, see the
`Tables tutorial <https://spatialdata.scverse.org/en/stable/tutorials/notebooks/notebooks/examples/tables.html>`_.

See Also
--------
match_element_to_table : Function to match a spatial element to a table.
Expand Down Expand Up @@ -763,6 +774,11 @@ def match_element_to_table(
-------
A tuple containing the joined elements as a dictionary and the joined table as an AnnData object.

Notes
-----
For a graphical representation of the join operations, see the
`Tables tutorial <https://spatialdata.scverse.org/en/stable/tutorials/notebooks/notebooks/examples/tables.html>`_.

See Also
--------
match_table_to_element : Function to match a table to a spatial element.
Expand Down Expand Up @@ -795,6 +811,10 @@ def match_sdata_to_table(
how
The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".

Notes
-----
For a graphical representation of the join operations, see the
`Tables tutorial <https://spatialdata.scverse.org/en/stable/tutorials/notebooks/notebooks/examples/tables.html>`_.
"""
if table is None:
table = sdata[table_name]
Expand All @@ -813,6 +833,73 @@ def match_sdata_to_table(
return SpatialData.init_from_elements(filtered_elements | {table_name: filtered_table})


def filter_by_table_query(
sdata: SpatialData,
table_name: str,
filter_tables: bool = True,
element_names: list[str] | None = None,
obs_expr: Predicates | None = None,
var_expr: Predicates | None = None,
x_expr: Predicates | None = None,
obs_names_expr: Predicates | None = None,
var_names_expr: Predicates | None = None,
layer: str | None = None,
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
) -> SpatialData:
"""Filter the SpatialData object based on a set of table queries.

Parameters
----------
sdata
The SpatialData object to filter.
table_name
The name of the table to filter the SpatialData object by.
filter_tables
If True (default), the table is filtered to only contain rows that are annotating regions
contained within the element_names.
element_names
The names of the elements to filter the SpatialData object by.
obs_expr
A Predicate or an iterable of `annsel` `Predicates` to filter :attr:`anndata.AnnData.obs` by.
var_expr
A Predicate or an iterable of `annsel` `Predicates` to filter :attr:`anndata.AnnData.var` by.
x_expr
A Predicate or an iterable of `annsel` `Predicates` to filter :attr:`anndata.AnnData.X` by.
obs_names_expr
A Predicate or an iterable of `annsel` `Predicates` to filter :attr:`anndata.AnnData.obs_names` by.
var_names_expr
A Predicate or an iterable of `annsel` `Predicates` to filter :attr:`anndata.AnnData.var_names` by.
layer
The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`.
how
The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".

Returns
-------
The filtered SpatialData object.

Notes
-----
You can also use :func:`spatialdata.SpatialData.filter_by_table_query` with the convenience that `sdata` is the
current `SpatialData` object.

For a graphical representation of the join operations, see the
`Tables tutorial <https://spatialdata.scverse.org/en/stable/tutorials/notebooks/notebooks/examples/tables.html>`_.

For more examples on table queries, see the
`Table queries tutorial <https://spatialdata.scverse.org/en/stable/tutorials/notebooks/notebooks/examples/table_queries.html>`_.
"""
sdata_subset: SpatialData = (
sdata.subset(element_names=element_names, filter_tables=filter_tables) if element_names else sdata
)

filtered_table: AnnData = sdata_subset.tables[table_name].an.filter(
obs=obs_expr, var=var_expr, x=x_expr, obs_names=obs_names_expr, var_names=var_names_expr, layer=layer
)

return match_sdata_to_table(sdata=sdata_subset, table_name=table_name, table=filtered_table, how=how)


@dataclass
class _ValueOrigin:
origin: str
Expand Down
36 changes: 36 additions & 0 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pandas as pd
import zarr
from anndata import AnnData
from annsel.core.typing import Predicates
from dask.dataframe import DataFrame as DaskDataFrame
from dask.dataframe import Scalar, read_parquet
from geopandas import GeoDataFrame
Expand Down Expand Up @@ -2408,6 +2409,41 @@ def attrs(self, value: Mapping[Any, Any]) -> None:
else:
self._attrs = dict(value)

def filter_by_table_query(
self,
table_name: str,
filter_tables: bool = True,
element_names: list[str] | None = None,
obs_expr: Predicates | None = None,
var_expr: Predicates | None = None,
x_expr: Predicates | None = None,
obs_names_expr: Predicates | None = None,
var_names_expr: Predicates | None = None,
layer: str | None = None,
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
) -> SpatialData:
"""
Filter the SpatialData object based on a set of table queries.

Please see
:func:`query.relational_query.filter_by_table_query` for the complete docstring.
"""
from spatialdata._core.query.relational_query import filter_by_table_query

return filter_by_table_query(
self,
table_name=table_name,
filter_tables=filter_tables,
element_names=element_names,
obs_expr=obs_expr,
var_expr=var_expr,
x_expr=x_expr,
obs_names_expr=obs_names_expr,
var_names_expr=var_names_expr,
layer=layer,
how=how,
)


class QueryManager:
"""Perform queries on SpatialData objects."""
Expand Down
138 changes: 138 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,141 @@ def adata_labels() -> AnnData:
"tensor_copy": rng.integers(0, blobs.shape[0], size=(n_obs_labels, 2)),
}
return generate_adata(n_var, obs_labels, obsm_labels, uns_labels)


@pytest.fixture()
def complex_sdata() -> SpatialData:
"""
Create a complex SpatialData object with multiple data types for comprehensive testing.

Contains:
- Images (2D and 3D)
- Labels (2D and 3D)
- Shapes (polygons and circles)
- Points
- Multiple tables with different annotations
- Categorical and numerical values in both obs and var

Returns
-------
SpatialData
A complex SpatialData object for testing.
"""
RNG = np.random.default_rng(seed=SEED)

# Get basic components using existing functions
images = _get_images()
labels = _get_labels()
shapes = _get_shapes()
points = _get_points()

# Create tables with enhanced var data
n_var = 10

# Table 1: Basic table annotating labels2d
obs1 = pd.DataFrame(
{
"region": pd.Categorical(["labels2d"] * 50),
"instance_id": range(1, 51), # Skip background (0)
"cell_type": pd.Categorical(RNG.choice(["T cell", "B cell", "Macrophage"], size=50)),
"size": RNG.uniform(10, 100, size=50),
}
)

var1 = pd.DataFrame(
{
"feature_type": pd.Categorical(["gene", "protein", "gene", "protein", "gene"] * 2),
"importance": RNG.uniform(0, 10, size=n_var),
"is_marker": RNG.choice([True, False], size=n_var),
},
index=[f"feature_{i}" for i in range(n_var)],
)

X1 = RNG.normal(size=(50, n_var))
uns1 = {
"spatialdata_attrs": {
"region": "labels2d",
"region_key": "region",
"instance_key": "instance_id",
}
}

table1 = AnnData(X=X1, obs=obs1, var=var1, uns=uns1)

# Table 2: Annotating both polygons and circles from shapes
n_polygons = len(shapes["poly"])
n_circles = len(shapes["circles"])
total_items = n_polygons + n_circles

obs2 = pd.DataFrame(
{
"region": pd.Categorical(["poly"] * n_polygons + ["circles"] * n_circles),
"instance_id": np.concatenate([range(n_polygons), range(n_circles)]),
"category": pd.Categorical(RNG.choice(["A", "B", "C"], size=total_items)),
"value": RNG.normal(size=total_items),
"count": RNG.poisson(10, size=total_items),
}
)

var2 = pd.DataFrame(
{
"feature_type": pd.Categorical(
["feature_type1", "feature_type2", "feature_type1", "feature_type2", "feature_type1"] * 2
),
"score": RNG.exponential(2, size=n_var),
"detected": RNG.choice([True, False], p=[0.7, 0.3], size=n_var),
},
index=[f"metric_{i}" for i in range(n_var)],
)

X2 = RNG.normal(size=(total_items, n_var))
uns2 = {
"spatialdata_attrs": {
"region": ["poly", "circles"],
"region_key": "region",
"instance_key": "instance_id",
}
}

table2 = AnnData(X=X2, obs=obs2, var=var2, uns=uns2)

# Table 3: Orphan table not annotating any elements
obs3 = pd.DataFrame(
{
"cluster": pd.Categorical(RNG.choice(["cluster_1", "cluster_2", "cluster_3"], size=40)),
"sample": pd.Categorical(["sample_A"] * 20 + ["sample_B"] * 20),
"qc_pass": RNG.choice([True, False], p=[0.8, 0.2], size=40),
}
)

var3 = pd.DataFrame(
{
"feature_type": pd.Categorical(["gene", "protein", "gene", "protein", "gene"] * 2),
"mean_expression": RNG.uniform(0, 20, size=n_var),
"variance": RNG.gamma(2, 2, size=n_var),
},
index=[f"feature_{i}" for i in range(n_var)],
)

X3 = RNG.normal(size=(40, n_var))
table3 = AnnData(X=X3, obs=obs3, var=var3)

# Create additional coordinate system in one of the shapes for testing
# Modified copy of circles with an additional coordinate system
circles_alt_coords = shapes["circles"].copy()
circles_alt_coords["coordinate_system"] = "alt_system"

# Add everything to a SpatialData object
sdata = SpatialData(
images=images,
labels=labels,
shapes={**shapes, "circles_alt_coords": circles_alt_coords},
points=points,
tables={"labels_table": table1, "shapes_table": table2, "orphan_table": table3},
)

# Add layers to tables for testing layer-specific operations
sdata.tables["labels_table"].layers["scaled"] = sdata.tables["labels_table"].X * 2
sdata.tables["labels_table"].layers["log"] = np.log1p(np.abs(sdata.tables["labels_table"].X))

return sdata
Loading
Loading