Skip to content

Commit 94255d2

Browse files
committed
add update and number-in options, reduce number of custom types replacing with custom callbacks
1 parent 6abd622 commit 94255d2

File tree

2 files changed

+153
-65
lines changed

2 files changed

+153
-65
lines changed

planet/cli/data.py

Lines changed: 123 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
"""The Planet Data CLI."""
1515
from datetime import datetime
1616
import json
17-
from typing import List
17+
from typing import List, Union
1818
from contextlib import asynccontextmanager
1919

2020
import click
2121

22-
from planet import data_filter, io, DataClient, Session
22+
from planet import data_filter, exceptions, io, DataClient, Session
2323

2424
from .cmds import coro, translate_exceptions
2525
from .io import echo_json
@@ -80,15 +80,6 @@ def parse_filter(ctx, param, value: str) -> dict:
8080
return json_value
8181

8282

83-
def assets_to_filter(ctx, param, value: str) -> dict:
84-
if value is None:
85-
return value
86-
87-
# manage assets as comma-separated names
88-
assets = [part.strip() for part in value.split(",")]
89-
return data_filter.asset_filter(assets)
90-
91-
9283
def geom_to_filter(ctx, param, value: str) -> dict:
9384
if value is None:
9485
return value
@@ -122,19 +113,15 @@ def _parse_geom(ctx, param, value: str) -> dict:
122113

123114

124115
class FieldType(click.ParamType):
116+
"""Clarify that this entry is for a field"""
125117
name = 'field'
126-
help = 'FIELD is the name of the field to filter on.'
127-
128-
def convert(self, value, param, ctx):
129-
return value
130118

131119

132120
class ComparisonType(click.ParamType):
133121
name = 'comp'
134122
valid = ['lt', 'lte', 'gt', 'gte']
135-
help = 'COMP can be lt, lte, gt, or gte.'
136123

137-
def convert(self, value, param, ctx):
124+
def convert(self, value, param, ctx) -> str:
138125
if value not in self.valid:
139126
self.fail(f'COMP ({value}) must be one of {",".join(self.valid)}',
140127
param,
@@ -145,106 +132,164 @@ def convert(self, value, param, ctx):
145132
class GTComparisonType(ComparisonType):
146133
"""Only support gt or gte comparison"""
147134
valid = ['gt', 'gte']
148-
help = 'COMP can be gt, or gte.'
149135

150136

151137
class DateTimeType(click.ParamType):
152138
name = 'datetime'
153-
help = 'DATETIME can be an RFC 3339 or ISO 8601 string.'
154139

155-
def convert(self, value, param, ctx):
140+
def convert(self, value, param, ctx) -> datetime:
156141
if isinstance(value, datetime):
157142
return value
158143
else:
159-
return io.str_to_datetime(value)
144+
try:
145+
return io.str_to_datetime(value)
146+
except exceptions.PlanetError as e:
147+
self.fail(str(e))
160148

161149

162-
class DateRangeFilter(click.Tuple):
163-
help = ('Filter by date range in field. ' +
164-
f'{FieldType.help} {ComparisonType.help} {DateTimeType.help}')
150+
class CommaSeparatedString(click.types.StringParamType):
151+
"""A list of strings that is extracted from a comma-separated string."""
165152

166-
def __init__(self) -> None:
167-
super().__init__([FieldType(), ComparisonType(), DateTimeType()])
153+
def convert(self, value, param, ctx) -> List[str]:
154+
value = super().convert(value, param, ctx)
155+
156+
if isinstance(value, list):
157+
return value
158+
else:
159+
return [part.strip() for part in value.split(",")]
160+
161+
162+
class CommaSeparatedFloat(CommaSeparatedString):
163+
"""A list of floats that is extracted from a comma-separated string."""
164+
name = 'VALUE'
168165

169166
def convert(self, value, param, ctx):
170-
vals = super().convert(value, param, ctx)
167+
values = super().convert(value, param, ctx)
168+
169+
try:
170+
return [float(v) for v in values]
171+
except ValueError:
172+
self.fail(f'Cound not convert all entries in {value} to float.')
173+
174+
175+
def assets_to_filter(ctx, param, assets: str) -> dict:
176+
if assets:
177+
# TODO: validate and normalize
178+
return data_filter.asset_filter(assets)
171179

172-
field, comp, value = vals
180+
181+
def date_range_to_filter(ctx, param, values) -> Union[List[dict], None]:
182+
183+
def _func(obj):
184+
field, comp, value = obj
173185
kwargs = {'field_name': field, comp: value}
174186
return data_filter.date_range_filter(**kwargs)
175187

188+
if values:
189+
return [_func(v) for v in values]
176190

177-
class RangeFilter(click.Tuple):
178-
help = ('Filter by number range in field. ' +
179-
f'{FieldType.help} {ComparisonType.help}')
180191

181-
def __init__(self) -> None:
182-
super().__init__([FieldType(), ComparisonType(), float])
192+
def range_to_filter(ctx, param, values) -> Union[List[dict], None]:
183193

184-
def convert(self, value, param, ctx):
185-
vals = super().convert(value, param, ctx)
186-
187-
field, comp, value = vals
194+
def _func(obj):
195+
field, comp, value = obj
188196
kwargs = {'field_name': field, comp: value}
189197
return data_filter.range_filter(**kwargs)
190198

199+
if values:
200+
return [_func(v) for v in values]
191201

192-
class UpdateFilter(click.Tuple):
193-
help = ('Filter to items with changes to a specified field value made ' +
194-
'after a specified date.' +
195-
f'{FieldType.help} {GTComparisonType.help} {DateTimeType.help}')
196-
197-
def __init__(self) -> None:
198-
super().__init__([FieldType(), GTComparisonType(), DateTimeType()])
199202

200-
def convert(self, value, param, ctx):
201-
vals = super().convert(value, param, ctx)
203+
def update_to_filter(ctx, param, values) -> Union[List[dict], None]:
202204

203-
field, comp, value = vals
205+
def _func(obj):
206+
field, comp, value = obj
204207
kwargs = {'field_name': field, comp: value}
205208
return data_filter.update_filter(**kwargs)
206209

210+
if values:
211+
return [_func(v) for v in values]
212+
213+
214+
def number_in_to_filter(ctx, param, values) -> Union[dict, None]:
215+
216+
def _func(obj):
217+
field, values = obj
218+
return data_filter.number_in_filter(field_name=field, values=values)
219+
220+
if values:
221+
return [_func(v) for v in values]
222+
207223

208224
@data.command()
209225
@click.pass_context
210226
@translate_exceptions
211227
@pretty
212228
@click.option('--asset',
213-
type=str,
229+
type=CommaSeparatedString(),
214230
default=None,
215231
callback=assets_to_filter,
216-
help='Filter to items with one or more of specified assets.')
232+
help="""Filter to items with one or more of specified assets.
233+
VALUE is a comma-separated list of entries.
234+
When multiple entries are specified an implicit 'or' logic is applied.""")
217235
@click.option('--date-range',
218-
type=DateRangeFilter(),
236+
type=click.Tuple([FieldType(), ComparisonType(),
237+
DateTimeType()]),
238+
callback=date_range_to_filter,
219239
multiple=True,
220-
help=DateRangeFilter.help)
240+
help="""Filter by date range in field.
241+
FIELD is the name of the field to filter on.
242+
COMP can be lt, lte, gt, or gte.
243+
DATETIME can be an RFC3339 or ISO 8601 string.""")
221244
@click.option('--geom',
222245
type=str,
223246
default=None,
224247
callback=geom_to_filter,
225248
help='Filter to items that overlap a given geometry.')
226-
# @click.option('--number-in',
227-
# type=RangeFilter(),
228-
# multiple=True,
229-
# help=RangeFilter.help)
249+
@click.option('--number-in',
250+
type=click.Tuple([FieldType(), CommaSeparatedFloat()]),
251+
multiple=True,
252+
callback=number_in_to_filter,
253+
help="""Filter field by numeric in.
254+
FIELD is the name of the field to filter on.
255+
VALUE is a comma-separated list of entries.
256+
When multiple entries are specified an implicit 'or' logic is applied.""")
230257
@click.option('--range',
231258
'nrange',
232-
type=RangeFilter(),
259+
type=click.Tuple([FieldType(), ComparisonType(), float]),
260+
callback=range_to_filter,
233261
multiple=True,
234-
help=RangeFilter.help)
262+
help="""Filter by date range in field.
263+
FIELD is the name of the field to filter on.
264+
COMP can be lt, lte, gt, or gte.
265+
DATETIME can be an RFC3339 or ISO 8601 string.""")
235266
# @click.option('--string-in',
236267
# type=RangeFilter(),
237268
# multiple=True,
238269
# help=RangeFilter.help)
239-
@click.option('--update',
240-
type=UpdateFilter(),
241-
multiple=True,
242-
help=UpdateFilter.help)
270+
@click.option(
271+
'--update',
272+
type=click.Tuple([FieldType(), GTComparisonType(), DateTimeType()]),
273+
callback=update_to_filter,
274+
multiple=True,
275+
help="""Filter to items with changes to a specified field value made after
276+
a specified date.
277+
FIELD is the name of the field to filter on.
278+
COMP can be gt or gte.
279+
DATETIME can be an RFC3339 or ISO 8601 string.""")
243280
@click.option('--permission',
244281
type=bool,
245282
default=True,
246283
help='Filter to assets with download permissions.')
247-
def filter(ctx, asset, date_range, geom, nrange, update, permission, pretty):
284+
def filter(ctx,
285+
asset,
286+
date_range,
287+
geom,
288+
number_in,
289+
nrange,
290+
update,
291+
permission,
292+
pretty):
248293
"""Create a structured item search filter.
249294
250295
This command provides basic functionality for specifying a filter by
@@ -255,11 +300,24 @@ def filter(ctx, asset, date_range, geom, nrange, update, permission, pretty):
255300
"""
256301
permission = data_filter.permission_filter() if permission else None
257302

303+
filter_options = (asset,
304+
date_range,
305+
geom,
306+
number_in,
307+
nrange,
308+
update,
309+
permission)
310+
258311
# options allowing multiples are broken up so one filter is created for
259312
# each time the option is specified
260-
filter_args = (asset, *date_range, geom, *nrange, *update, permission)
261-
262-
filters = [f for f in filter_args if f]
313+
# unspecified options are skipped
314+
filters = []
315+
for f in filter_options:
316+
if f:
317+
if isinstance(f, list):
318+
filters.extend(f)
319+
else:
320+
filters.append(f)
263321

264322
if filters:
265323
if len(filters) > 1:

tests/integration/test_data_cli.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,36 @@ def test_data_filter_geom(geom_fixture,
172172
assert_and_filters_equal(json.loads(result.output), expected_filt)
173173

174174

175+
# @pytest.mark.skip
176+
@respx.mock
177+
@pytest.mark.asyncio
178+
def test_data_filter_number_in(
179+
invoke,
180+
assert_and_filters_equal):
181+
runner = CliRunner()
182+
183+
result = invoke(["filter"] + '--number-in field 1'.split() +
184+
'--number-in field2 2,3.5'.split(),
185+
runner=runner)
186+
assert result.exit_code == 0
187+
188+
number_in_filter1 = {
189+
"type": "NumberInFilter", "field_name": "field", "config": [1.0]
190+
}
191+
number_in_filter2 = {
192+
"type": "NumberInFilter",
193+
"field_name": "field2",
194+
"config": [2.0, 3.5]
195+
}
196+
197+
expected_filt = {
198+
"type": "AndFilter",
199+
"config": [permission_filter, number_in_filter1, number_in_filter2]
200+
}
201+
202+
assert_and_filters_equal(json.loads(result.output), expected_filt)
203+
204+
175205
@respx.mock
176206
@pytest.mark.asyncio
177207
def test_data_filter_range(invoke, assert_and_filters_equal):

0 commit comments

Comments
 (0)