-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Use cumsum from flox #10987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Use cumsum from flox #10987
Changes from all commits
776bc5a
ae27632
a5f9326
50ccca4
f55531e
06ac372
31244e6
dd47536
e867f12
88e0ebc
181d4a3
a82ec39
6c6abed
24c3f1d
d8d0eaa
55ff46a
33d1360
c97ae98
06b52ae
84f9b44
2978877
0a9adee
ae9a3d8
c056d1f
d4873b9
21cbde2
4aebc47
f4cab24
23d9d50
9b64db2
928b158
130f98e
5a3e754
d912cda
3bc8dc7
ec8ffd6
b0cf8c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| from packaging.version import Version | ||
|
|
||
| from xarray.computation import ops | ||
| from xarray.computation.apply_ufunc import apply_ufunc | ||
| from xarray.computation.arithmetic import ( | ||
| DataArrayGroupbyArithmetic, | ||
| DatasetGroupbyArithmetic, | ||
|
|
@@ -1028,6 +1029,26 @@ def _maybe_unstack(self, obj): | |
|
|
||
| return obj | ||
|
|
||
| def _parse_dim(self, dim: Dims) -> tuple[Hashable, ...]: | ||
| parsed_dim: tuple[Hashable, ...] | ||
| if isinstance(dim, str): | ||
| parsed_dim = (dim,) | ||
| elif dim is None: | ||
| parsed_dim_list = list() | ||
| # preserve order | ||
| for dim_ in itertools.chain( | ||
| *(grouper.codes.dims for grouper in self.groupers) | ||
| ): | ||
| if dim_ not in parsed_dim_list: | ||
| parsed_dim_list.append(dim_) | ||
| parsed_dim = tuple(parsed_dim_list) | ||
| elif dim is ...: | ||
| parsed_dim = tuple(self._original_obj.dims) | ||
| else: | ||
| parsed_dim = tuple(dim) | ||
|
|
||
| return parsed_dim | ||
|
|
||
| def _flox_reduce( | ||
| self, | ||
| dim: Dims, | ||
|
|
@@ -1088,22 +1109,7 @@ def _flox_reduce( | |
| # set explicitly to avoid unnecessarily accumulating count | ||
| kwargs["min_count"] = 0 | ||
|
|
||
| parsed_dim: tuple[Hashable, ...] | ||
| if isinstance(dim, str): | ||
| parsed_dim = (dim,) | ||
| elif dim is None: | ||
| parsed_dim_list = list() | ||
| # preserve order | ||
| for dim_ in itertools.chain( | ||
| *(grouper.codes.dims for grouper in self.groupers) | ||
| ): | ||
| if dim_ not in parsed_dim_list: | ||
| parsed_dim_list.append(dim_) | ||
| parsed_dim = tuple(parsed_dim_list) | ||
| elif dim is ...: | ||
| parsed_dim = tuple(obj.dims) | ||
| else: | ||
| parsed_dim = tuple(dim) | ||
| parsed_dim = self._parse_dim(dim) | ||
|
|
||
| # Do this so we raise the same error message whether flox is present or not. | ||
| # Better to control it here than in flox. | ||
|
|
@@ -1202,6 +1208,58 @@ def _flox_reduce( | |
|
|
||
| return result | ||
|
|
||
| def _flox_scan( | ||
| self, | ||
| dim: Dims, | ||
| *, | ||
| func: str, | ||
| skipna: bool | None = None, | ||
| keep_attrs: bool | None = None, | ||
| **kwargs: Any, | ||
| ) -> T_Xarray: | ||
| from flox import groupby_scan | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| parsed_dim = self._parse_dim(dim) | ||
| obj = self._original_obj.transpose(..., *parsed_dim) | ||
| axis = range(-len(parsed_dim), 0) | ||
| codes = tuple(g.codes for g in self.groupers) | ||
|
|
||
| def wrapper(array, *by, func: str, skipna: bool | None, **kwargs): | ||
| if skipna or (skipna is None and array.dtype.kind in "cfO"): | ||
| if "nan" not in func: | ||
| func = f"nan{func}" | ||
|
|
||
| return groupby_scan(array, *codes, func=func, **kwargs) | ||
|
|
||
| actual = apply_ufunc( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yes, this is the way. eventually I'd like the |
||
| wrapper, | ||
| obj, | ||
| *codes, | ||
| # input_core_dims=input_core_dims, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we don't need this because we just want the full array forwarded |
||
| # for xarray's test_groupby_duplicate_coordinate_labels | ||
| # exclude_dims=set(dim_tuple), | ||
| # output_core_dims=[output_core_dims], | ||
|
Comment on lines
+1239
to
+1241
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. please delete |
||
| dask="allowed", | ||
| # dask_gufunc_kwargs=dict( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. please delete. |
||
| # output_sizes=output_sizes, | ||
| # output_dtypes=[dtype] if dtype is not None else None, | ||
| # ), | ||
| keep_attrs=( | ||
| _get_keep_attrs(default=True) if keep_attrs is None else keep_attrs | ||
| ), | ||
| kwargs=dict( | ||
| func=func, | ||
| skipna=skipna, | ||
| expected_groups=None, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. should be the same as |
||
| axis=axis, | ||
| dtype=None, | ||
| method=None, | ||
| engine=None, | ||
|
Comment on lines
+1255
to
+1257
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These we should grab from kwargs and forward just like |
||
| ), | ||
| ) | ||
|
|
||
| return actual | ||
|
|
||
| def fillna(self, value: Any) -> T_Xarray: | ||
| """Fill missing values in this object by group. | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| from xarray import DataArray, Dataset, Variable, date_range | ||
| from xarray.core.groupby import _consolidate_slices | ||
| from xarray.core.types import InterpOptions, ResampleCompatible | ||
| from xarray.core.utils import module_available | ||
| from xarray.groupers import ( | ||
| BinGrouper, | ||
| EncodedGroups, | ||
|
|
@@ -2566,9 +2567,13 @@ def test_groupby_cumsum() -> None: | |
| "group_id": ds.group_id, | ||
| }, | ||
| ) | ||
| # TODO: Remove drop_vars when GH6528 is fixed | ||
| # when Dataset.cumsum propagates indexes, and the group variable? | ||
| assert_identical(expected.drop_vars(["x", "group_id"]), actual) | ||
|
|
||
| if xr.get_options()["use_flox"] and module_available("flox", minversion="0.10.5"): | ||
| assert_identical(expected, actual) | ||
| else: | ||
| # TODO: Remove drop_vars when GH6528 is fixed | ||
| # when Dataset.cumsum propagates indexes, and the group variable? | ||
| assert_identical(expected.drop_vars(["x", "group_id"]), actual) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Keeping this until min_version of flox is [version number truncated in page capture]
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not ok imo. I think it might be fixed by simply propagating coordinates in the non-flox branch of the templated code. Might be easy |
||
|
|
||
| actual = ds.foo.groupby("group_id").cumsum(dim="x") | ||
| expected.coords["group_id"] = ds.group_id | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've made these changes manually now.
I'm not getting pytest-accept to correctly fix the docstrings in _aggregations.py, it's for example not indenting correctly. I'm not sure if this is just a Windows 10 thing.