diff --git a/src/spatialdata/_core/operations/transform.py b/src/spatialdata/_core/operations/transform.py index 8340f23e..84258aae 100644 --- a/src/spatialdata/_core/operations/transform.py +++ b/src/spatialdata/_core/operations/transform.py @@ -381,6 +381,18 @@ def _( transformed_dask, raster_translation_single_scale = _transform_raster( data=xdata.data, axes=xdata.dims, transformation=composed, **kwargs ) + + # if a scale in the transformed data has zero shape, we skip it + if 0 in transformed_dask.shape: + if k == "scale0": + raise ValueError( + "The transformation leads to zero shaped data even at the highest resolution level. " + "Check the scaling component of the transformation." + ) + # no risk of skipping a scale (e.g. scale1) but not the next ones (e.g. scale2), because once a scale + # is skipped, all the lower scales are also skipped + continue + if raster_translation is None: raster_translation = raster_translation_single_scale # we set a dummy empty dict for the transformation that will be replaced with the correct transformation for diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 9c1c6823..1bb494fb 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -11,7 +11,7 @@ from spatialdata._core.data_extent import are_extents_equal, get_extent from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import unpad_raster -from spatialdata.models import PointsModel, ShapesModel, get_axes_names +from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names from spatialdata.transformations.operations import ( align_elements_using_landmarks, get_transformation, @@ -229,6 +229,37 @@ def test_transform_shapes(shapes: SpatialData): assert geom_almost_equals(p0["geometry"], p1["geometry"]) +def test_transform_datatree_scale_handling(): + """ + Test the cases in which the lowest and highest scale of the result of a + transformed multi-scale image would be zero shape. + """ + + test_image = Image2DModel.parse( + np.ones((1, 10, 10)), + dims=("c", "y", "x"), + scale_factors=[2, 4], + transformations={ + "cs1": Scale([0.5] * 2, axes=["y", "x"]), + "cs2": Scale([0.01] * 2, axes=["y", "x"]), + }, + ) + + # check that the transform doesn't raise an error and that it + # discards the lowest resolution level + test_image_t = transform(test_image, to_coordinate_system="cs1") + assert list(test_image.keys()) == ["scale0", "scale1", "scale2"] + assert list(test_image_t.keys()) == ["scale0", "scale1"] + + # check that a ValueError is raised when no resolution level + # is left after the transformation + with pytest.raises( + ValueError, + match="The transformation leads to zero shaped data even at the highest resolution level", + ): + transform(test_image, to_coordinate_system="cs2") + + def test_map_coordinate_systems_single_path(full_sdata: SpatialData): scale = Scale([2], axes=("x",)) translation = Translation([100], axes=("x",))