anemoi-datasets 0.4.5__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their public registry.
Files changed (67)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/create.py +3 -2
  3. anemoi/datasets/commands/inspect.py +1 -1
  4. anemoi/datasets/commands/publish.py +30 -0
  5. anemoi/datasets/create/__init__.py +72 -35
  6. anemoi/datasets/create/check.py +6 -0
  7. anemoi/datasets/create/config.py +4 -3
  8. anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +57 -0
  9. anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +57 -0
  10. anemoi/datasets/create/functions/filters/rename.py +2 -3
  11. anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +54 -0
  12. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +59 -0
  13. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +115 -0
  14. anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +390 -0
  15. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +77 -0
  16. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +55 -0
  17. anemoi/datasets/create/functions/sources/__init__.py +7 -1
  18. anemoi/datasets/create/functions/sources/accumulations.py +2 -0
  19. anemoi/datasets/create/functions/sources/grib.py +87 -2
  20. anemoi/datasets/create/functions/sources/hindcasts.py +14 -73
  21. anemoi/datasets/create/functions/sources/mars.py +9 -3
  22. anemoi/datasets/create/functions/sources/xarray/__init__.py +6 -1
  23. anemoi/datasets/create/functions/sources/xarray/coordinates.py +6 -1
  24. anemoi/datasets/create/functions/sources/xarray/field.py +20 -5
  25. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +16 -16
  26. anemoi/datasets/create/functions/sources/xarray/flavour.py +126 -12
  27. anemoi/datasets/create/functions/sources/xarray/grid.py +106 -17
  28. anemoi/datasets/create/functions/sources/xarray/metadata.py +6 -12
  29. anemoi/datasets/create/functions/sources/xarray/time.py +1 -5
  30. anemoi/datasets/create/functions/sources/xarray/variable.py +10 -10
  31. anemoi/datasets/create/input/__init__.py +69 -0
  32. anemoi/datasets/create/input/action.py +123 -0
  33. anemoi/datasets/create/input/concat.py +92 -0
  34. anemoi/datasets/create/input/context.py +59 -0
  35. anemoi/datasets/create/input/data_sources.py +71 -0
  36. anemoi/datasets/create/input/empty.py +42 -0
  37. anemoi/datasets/create/input/filter.py +76 -0
  38. anemoi/datasets/create/input/function.py +122 -0
  39. anemoi/datasets/create/input/join.py +57 -0
  40. anemoi/datasets/create/input/misc.py +85 -0
  41. anemoi/datasets/create/input/pipe.py +33 -0
  42. anemoi/datasets/create/input/repeated_dates.py +217 -0
  43. anemoi/datasets/create/input/result.py +413 -0
  44. anemoi/datasets/create/input/step.py +99 -0
  45. anemoi/datasets/create/{template.py → input/template.py} +0 -42
  46. anemoi/datasets/create/persistent.py +1 -1
  47. anemoi/datasets/create/statistics/__init__.py +1 -1
  48. anemoi/datasets/create/utils.py +3 -0
  49. anemoi/datasets/create/zarr.py +4 -2
  50. anemoi/datasets/data/dataset.py +11 -1
  51. anemoi/datasets/data/debug.py +5 -1
  52. anemoi/datasets/data/masked.py +2 -2
  53. anemoi/datasets/data/rescale.py +147 -0
  54. anemoi/datasets/data/stores.py +20 -7
  55. anemoi/datasets/dates/__init__.py +113 -30
  56. anemoi/datasets/dates/groups.py +92 -19
  57. anemoi/datasets/fields.py +66 -0
  58. anemoi/datasets/utils/fields.py +47 -0
  59. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/METADATA +10 -19
  60. anemoi_datasets-0.5.5.dist-info/RECORD +121 -0
  61. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/WHEEL +1 -1
  62. anemoi/datasets/create/input.py +0 -1065
  63. anemoi_datasets-0.4.5.dist-info/RECORD +0 -96
  64. /anemoi/datasets/create/{trace.py → input/trace.py} +0 -0
  65. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/LICENSE +0 -0
  66. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/entry_points.txt +0 -0
  67. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/top_level.txt +0 -0
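The largest change in this release is the removal of the monolithic anemoi/datasets/create/input.py (entry 62), whose contents now live in the new anemoi/datasets/create/input/ package (entries 31 to 45); the full deleted module is shown below. As a quick orientation, the removed module's action_factory() dispatched on a configuration that must be a mapping with exactly one key. The sketch below restates only that dispatch rule; the "mars" requests and the dispatch_key helper are illustrative assumptions, not part of the package's API.

# Illustrative sketch only: the one-key config shape that the removed
# input.py's action_factory() dispatched on (see the deleted hunk below).
# The "mars" requests shown here are hypothetical example values.
example_input_config = {
    "join": [
        {"mars": {"param": ["2t", "10u", "10v"], "levtype": "sfc"}},
        {"mars": {"param": ["q", "t"], "levtype": "pl", "levelist": [500, 850]}},
    ]
}

KNOWN_ACTIONS = {"data_sources", "concat", "join", "pipe", "function"}

def dispatch_key(config: dict) -> str:
    """Mimic action_factory's first step: the config must be a mapping with
    exactly one key, and that key selects the action class (a key not in
    KNOWN_ACTIONS is tried as a source function name)."""
    if not isinstance(config, dict) or len(config) != 1:
        raise ValueError(f"Invalid input config {config}")
    return next(iter(config))

assert dispatch_key(example_input_config) == "join"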
anemoi/datasets/create/input.py (deleted)
@@ -1,1065 +0,0 @@
- # (C) Copyright 2023 ECMWF.
- #
- # This software is licensed under the terms of the Apache Licence Version 2.0
- # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
- # In applying this licence, ECMWF does not waive the privileges and immunities
- # granted to it by virtue of its status as an intergovernmental organisation
- # nor does it submit to any jurisdiction.
- #
- import datetime
- import itertools
- import logging
- import math
- import time
- from collections import defaultdict
- from copy import deepcopy
- from functools import cached_property
- from functools import wraps
-
- import numpy as np
- from anemoi.utils.humanize import seconds_to_human
- from anemoi.utils.humanize import shorten_list
- from earthkit.data.core.fieldlist import FieldList
- from earthkit.data.core.fieldlist import MultiFieldList
- from earthkit.data.core.order import build_remapping
-
- from anemoi.datasets.dates import Dates
-
- from .functions import import_function
- from .template import Context
- from .template import notify_result
- from .template import resolve
- from .template import substitute
- from .trace import trace
- from .trace import trace_datasource
- from .trace import trace_select
-
- LOG = logging.getLogger(__name__)
-
-
- def parse_function_name(name):
-
-     if name.endswith("h") and name[:-1].isdigit():
-
-         if "-" in name:
-             name, delta = name.split("-")
-             sign = -1
-
-         elif "+" in name:
-             name, delta = name.split("+")
-             sign = 1
-
-         else:
-             return name, None
-
-         assert delta[-1] == "h", (name, delta)
-         delta = sign * int(delta[:-1])
-         return name, delta
-
-     return name, None
-
-
- def time_delta_to_string(delta):
-     assert isinstance(delta, datetime.timedelta), delta
-     seconds = delta.total_seconds()
-     hours = int(seconds // 3600)
-     assert hours * 3600 == seconds, delta
-     hours = abs(hours)
-
-     if seconds > 0:
-         return f"plus_{hours}h"
-     if seconds == 0:
-         return ""
-     if seconds < 0:
-         return f"minus_{hours}h"
-
-
- def is_function(name, kind):
-     name, delta = parse_function_name(name) # noqa
-     try:
-         import_function(name, kind)
-         return True
-     except ImportError as e:
-         print(e)
-         return False
-
-
- def assert_fieldlist(method):
-     @wraps(method)
-     def wrapper(self, *args, **kwargs):
-         result = method(self, *args, **kwargs)
-         assert isinstance(result, FieldList), type(result)
-         return result
-
-     return wrapper
-
-
- def assert_is_fieldlist(obj):
-     assert isinstance(obj, FieldList), type(obj)
-
-
- def _data_request(data):
-     date = None
-     params_levels = defaultdict(set)
-     params_steps = defaultdict(set)
-
-     area = grid = None
-
-     for field in data:
-         try:
-             if date is None:
-                 date = field.datetime()["valid_time"]
-
-             if field.datetime()["valid_time"] != date:
-                 continue
-
-             as_mars = field.metadata(namespace="mars")
-             if not as_mars:
-                 continue
-             step = as_mars.get("step")
-             levtype = as_mars.get("levtype", "sfc")
-             param = as_mars["param"]
-             levelist = as_mars.get("levelist", None)
-             area = field.mars_area
-             grid = field.mars_grid
-
-             if levelist is None:
-                 params_levels[levtype].add(param)
-             else:
-                 params_levels[levtype].add((param, levelist))
-
-             if step:
-                 params_steps[levtype].add((param, step))
-         except Exception:
-             LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True)
-
-     def sort(old_dic):
-         new_dic = {}
-         for k, v in old_dic.items():
-             new_dic[k] = sorted(list(v))
-         return new_dic
-
-     params_steps = sort(params_steps)
-     params_levels = sort(params_levels)
-
-     return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid)
-
-
- class Action:
-     def __init__(self, context, action_path, /, *args, **kwargs):
-         if "args" in kwargs and "kwargs" in kwargs:
-             """We have:
-             args = []
-             kwargs = {args: [...], kwargs: {...}}
-             move the content of kwargs to args and kwargs.
-             """
-             assert len(kwargs) == 2, (args, kwargs)
-             assert not args, (args, kwargs)
-             args = kwargs.pop("args")
-             kwargs = kwargs.pop("kwargs")
-
-         assert isinstance(context, ActionContext), type(context)
-         self.context = context
-         self.kwargs = kwargs
-         self.args = args
-         self.action_path = action_path
-
-     @classmethod
-     def _short_str(cls, x):
-         x = str(x)
-         if len(x) < 1000:
-             return x
-         return x[:1000] + "..."
-
-     def __repr__(self, *args, _indent_="\n", _inline_="", **kwargs):
-         more = ",".join([str(a)[:5000] for a in args])
-         more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])
-
-         more = more[:5000]
-         txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}"
-         if _indent_:
-             txt = txt.replace("\n", "\n ")
-         return txt
-
-     def select(self, dates, **kwargs):
-         self._raise_not_implemented()
-
-     def _raise_not_implemented(self):
-         raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
-
-     def _trace_select(self, dates):
-         return f"{self.__class__.__name__}({shorten(dates)})"
-
-
- def shorten(dates):
-     if isinstance(dates, (list, tuple)):
-         dates = [d.isoformat() for d in dates]
-         if len(dates) > 5:
-             return f"{dates[0]}...{dates[-1]}"
-     return dates
-
-
- class Result:
-     empty = False
-     _coords_already_built = False
-
-     def __init__(self, context, action_path, dates):
-         assert isinstance(context, ActionContext), type(context)
-         assert isinstance(action_path, list), action_path
-
-         self.context = context
-         self.dates = dates
-         self.action_path = action_path
-
-     @property
-     @trace_datasource
-     def datasource(self):
-         self._raise_not_implemented()
-
-     @property
-     def data_request(self):
-         """Returns a dictionary with the parameters needed to retrieve the data."""
-         return _data_request(self.datasource)
-
-     def get_cube(self):
-         trace("🧊", f"getting cube from {self.__class__.__name__}")
-         ds = self.datasource
-
-         remapping = self.context.remapping
-         order_by = self.context.order_by
-         flatten_grid = self.context.flatten_grid
-         start = time.time()
-         LOG.debug("Sorting dataset %s %s", dict(order_by), remapping)
-         assert order_by, order_by
-
-         patches = {"number": {None: 0}}
-
-         try:
-             cube = ds.cube(
-                 order_by,
-                 remapping=remapping,
-                 flatten_values=flatten_grid,
-                 patches=patches,
-             )
-             cube = cube.squeeze()
-             LOG.debug(f"Sorting done in {seconds_to_human(time.time()-start)}.")
-         except ValueError:
-             self.explain(ds, order_by, remapping=remapping, patches=patches)
-             # raise ValueError(f"Error in {self}")
-             exit(1)
-
-         if LOG.isEnabledFor(logging.DEBUG):
-             LOG.debug("Cube shape: %s", cube)
-             for k, v in cube.user_coords.items():
-                 LOG.debug(" %s %s", k, shorten_list(v, max_length=10))
-
-         return cube
-
-     def explain(self, ds, *args, remapping, patches):
-
-         METADATA = (
-             "date",
-             "time",
-             "step",
-             "hdate",
-             "valid_datetime",
-             "levtype",
-             "levelist",
-             "number",
-             "level",
-             "shortName",
-             "paramId",
-             "variable",
-         )
-
-         # We redo the logic here
-         print()
-         print("❌" * 40)
-         print()
-         if len(args) == 1 and isinstance(args[0], (list, tuple)):
-             args = args[0]
-
-         # print("Executing", self.action_path)
-         # print("Dates:", compress_dates(self.dates))
-
-         names = []
-         for a in args:
-             if isinstance(a, str):
-                 names.append(a)
-             elif isinstance(a, dict):
-                 names += list(a.keys())
-
-         print(f"Building a {len(names)}D hypercube using", names)
-         ds = ds.order_by(*args, remapping=remapping, patches=patches)
-         user_coords = ds.unique_values(*names, remapping=remapping, patches=patches, progress_bar=False)
-
-         print()
-         print("Number of unique values found for each coordinate:")
-         for k, v in user_coords.items():
-             print(f" {k:20}:", len(v), shorten_list(v, max_length=10))
-         print()
-         user_shape = tuple(len(v) for k, v in user_coords.items())
-         print("Shape of the hypercube :", user_shape)
-         print(
-             "Number of expected fields :", math.prod(user_shape), "=", " x ".join([str(i) for i in user_shape])
-         )
-         print("Number of fields in the dataset :", len(ds))
-         print("Difference :", abs(len(ds) - math.prod(user_shape)))
-         print()
-
-         remapping = build_remapping(remapping, patches)
-         expected = set(itertools.product(*user_coords.values()))
-         extra = set()
-
-         if math.prod(user_shape) > len(ds):
-             print(f"This means that all the fields in the datasets do not exists for all combinations of {names}.")
-
-             for f in ds:
-                 metadata = remapping(f.metadata)
-                 key = tuple(metadata(n, default=None) for n in names)
-                 if key in expected:
-                     expected.remove(key)
-                 else:
-                     extra.add(key)
-
-             print("Missing fields:")
-             print()
-             for i, f in enumerate(sorted(expected)):
-                 print(" ", f)
-                 if i >= 9 and len(expected) > 10:
-                     print("...", len(expected) - i - 1, "more")
-                     break
-
-             print("Extra fields:")
-             print()
-             for i, f in enumerate(sorted(extra)):
-                 print(" ", f)
-                 if i >= 9 and len(extra) > 10:
-                     print("...", len(extra) - i - 1, "more")
-                     break
-
-             print()
-             print("Missing values:")
-             per_name = defaultdict(set)
-             for e in expected:
-                 for n, v in zip(names, e):
-                     per_name[n].add(v)
-
-             for n, v in per_name.items():
-                 print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
-             print()
-
-             print("Extra values:")
-             per_name = defaultdict(set)
-             for e in extra:
-                 for n, v in zip(names, e):
-                     per_name[n].add(v)
-
-             for n, v in per_name.items():
-                 print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
-             print()
-
-             print("To solve this issue, you can:")
-             print(
-                 " - Provide a better selection, like 'step: 0' or 'level: 1000' to "
-                 "reduce the number of selected fields."
-             )
-             print(
-                 " - Split the 'input' part in smaller sections using 'join', "
-                 "making sure that each section represent a full hypercube."
-             )
-
-         else:
-             print(f"More fields in dataset that expected for {names}. " "This means that some fields are duplicated.")
-             duplicated = defaultdict(list)
-             for f in ds:
-                 # print(f.metadata(namespace="default"))
-                 metadata = remapping(f.metadata)
-                 key = tuple(metadata(n, default=None) for n in names)
-                 duplicated[key].append(f)
-
-             print("Duplicated fields:")
-             print()
-             duplicated = {k: v for k, v in duplicated.items() if len(v) > 1}
-             for i, (k, v) in enumerate(sorted(duplicated.items())):
-                 print(" ", k)
-                 for f in v:
-                     x = {k: f.metadata(k, default=None) for k in METADATA if f.metadata(k, default=None) is not None}
-                     print(" ", f, x)
-                 if i >= 9 and len(duplicated) > 10:
-                     print("...", len(duplicated) - i - 1, "more")
-                     break
-
-             print()
-             print("To solve this issue, you can:")
-             print(" - Provide a better selection, like 'step: 0' or 'level: 1000'")
-             print(" - Change the way 'param' is computed using 'variable_naming' " "in the 'build' section.")
-
-         print()
-         print("❌" * 40)
-         print()
-         exit(1)
-
-     def __repr__(self, *args, _indent_="\n", **kwargs):
-         more = ",".join([str(a)[:5000] for a in args])
-         more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])
-
-         dates = " no-dates"
-         if self.dates is not None:
-             dates = f" {len(self.dates)} dates"
-             dates += " ("
-             dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.dates)
-             if len(dates) > 100:
-                 dates = dates[:100] + "..."
-             dates += ")"
-
-         more = more[:5000]
-         txt = f"{self.__class__.__name__}:{dates}{_indent_}{more}"
-         if _indent_:
-             txt = txt.replace("\n", "\n ")
-         return txt
-
-     def _raise_not_implemented(self):
-         raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
-
-     def _trace_datasource(self, *args, **kwargs):
-         return f"{self.__class__.__name__}({shorten(self.dates)})"
-
-     def build_coords(self):
-         if self._coords_already_built:
-             return
-         from_data = self.get_cube().user_coords
-         from_config = self.context.order_by
-
-         keys_from_config = list(from_config.keys())
-         keys_from_data = list(from_data.keys())
-         assert keys_from_data == keys_from_config, f"Critical error: {keys_from_data=} != {keys_from_config=}. {self=}"
-
-         variables_key = list(from_config.keys())[1]
-         ensembles_key = list(from_config.keys())[2]
-
-         if isinstance(from_config[variables_key], (list, tuple)):
-             assert all([v == w for v, w in zip(from_data[variables_key], from_config[variables_key])]), (
-                 from_data[variables_key],
-                 from_config[variables_key],
-             )
-
-         self._variables = from_data[variables_key] # "param_level"
-         self._ensembles = from_data[ensembles_key] # "number"
-
-         first_field = self.datasource[0]
-         grid_points = first_field.grid_points()
-
-         lats, lons = grid_points
-
-         assert len(lats) == len(lons), (len(lats), len(lons), first_field)
-         assert len(lats) == math.prod(first_field.shape), (len(lats), first_field.shape, first_field)
-
-         north = np.amax(lats)
-         south = np.amin(lats)
-         east = np.amax(lons)
-         west = np.amin(lons)
-
-         assert -90 <= south <= north <= 90, (south, north, first_field)
-         assert (-180 <= west <= east <= 180) or (0 <= west <= east <= 360), (
-             west,
-             east,
-             first_field,
-         )
-
-         grid_values = list(range(len(grid_points[0])))
-
-         self._grid_points = grid_points
-         self._resolution = first_field.resolution
-         self._grid_values = grid_values
-         self._field_shape = first_field.shape
-         self._proj_string = first_field.proj_string if hasattr(first_field, "proj_string") else None
-
-     @property
-     def variables(self):
-         self.build_coords()
-         return self._variables
-
-     @property
-     def ensembles(self):
-         self.build_coords()
-         return self._ensembles
-
-     @property
-     def resolution(self):
-         self.build_coords()
-         return self._resolution
-
-     @property
-     def grid_values(self):
-         self.build_coords()
-         return self._grid_values
-
-     @property
-     def grid_points(self):
-         self.build_coords()
-         return self._grid_points
-
-     @property
-     def field_shape(self):
-         self.build_coords()
-         return self._field_shape
-
-     @property
-     def proj_string(self):
-         self.build_coords()
-         return self._proj_string
-
-     @cached_property
-     def shape(self):
-         return [
-             len(self.dates),
-             len(self.variables),
-             len(self.ensembles),
-             len(self.grid_values),
-         ]
-
-     @cached_property
-     def coords(self):
-         return {
-             "dates": self.dates,
-             "variables": self.variables,
-             "ensembles": self.ensembles,
-             "values": self.grid_values,
-         }
-
-
- class EmptyResult(Result):
-     empty = True
-
-     def __init__(self, context, action_path, dates):
-         super().__init__(context, action_path + ["empty"], dates)
-
-     @cached_property
-     @assert_fieldlist
-     @trace_datasource
-     def datasource(self):
-         from earthkit.data import from_source
-
-         return from_source("empty")
-
-     @property
-     def variables(self):
-         return []
-
-
- def _flatten(ds):
-     if isinstance(ds, MultiFieldList):
-         return [_tidy(f) for s in ds._indexes for f in _flatten(s)]
-     return [ds]
-
-
- def _tidy(ds, indent=0):
-     if isinstance(ds, MultiFieldList):
-
-         sources = [s for s in _flatten(ds) if len(s) > 0]
-         if len(sources) == 1:
-             return sources[0]
-         return MultiFieldList(sources)
-     return ds
-
-
- class FunctionResult(Result):
-     def __init__(self, context, action_path, dates, action):
-         super().__init__(context, action_path, dates)
-         assert isinstance(action, Action), type(action)
-         self.action = action
-
-         self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs))
-
-     def _trace_datasource(self, *args, **kwargs):
-         return f"{self.action.name}({shorten(self.dates)})"
-
-     @cached_property
-     @assert_fieldlist
-     @notify_result
-     @trace_datasource
-     def datasource(self):
-         args, kwargs = resolve(self.context, (self.args, self.kwargs))
-
-         try:
-             return _tidy(self.action.function(FunctionContext(self), self.dates, *args, **kwargs))
-         except Exception:
-             LOG.error(f"Error in {self.action.function.__name__}", exc_info=True)
-             raise
-
-     def __repr__(self):
-         try:
-             return f"{self.action.name}({shorten(self.dates)})"
-         except Exception:
-             return f"{self.__class__.__name__}(unitialised)"
-
-     @property
-     def function(self):
-         raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
-
-
- class JoinResult(Result):
-     def __init__(self, context, action_path, dates, results, **kwargs):
-         super().__init__(context, action_path, dates)
-         self.results = [r for r in results if not r.empty]
-
-     @cached_property
-     @assert_fieldlist
-     @notify_result
-     @trace_datasource
-     def datasource(self):
-         ds = EmptyResult(self.context, self.action_path, self.dates).datasource
-         for i in self.results:
-             ds += i.datasource
-         return _tidy(ds)
-
-     def __repr__(self):
-         content = "\n".join([str(i) for i in self.results])
-         return super().__repr__(content)
-
-
- class DateShiftAction(Action):
-     def __init__(self, context, action_path, delta, **kwargs):
-         super().__init__(context, action_path, **kwargs)
-
-         if isinstance(delta, str):
-             if delta[0] == "-":
-                 delta, sign = int(delta[1:]), -1
-             else:
-                 delta, sign = int(delta), 1
-             delta = datetime.timedelta(hours=sign * delta)
-         assert isinstance(delta, int), delta
-         delta = datetime.timedelta(hours=delta)
-         self.delta = delta
-
-         self.content = action_factory(kwargs, context, self.action_path + ["shift"])
-
-     @trace_select
-     def select(self, dates):
-         shifted_dates = [d + self.delta for d in dates]
-         result = self.content.select(shifted_dates)
-         return UnShiftResult(self.context, self.action_path, dates, result, action=self)
-
-     def __repr__(self):
-         return super().__repr__(f"{self.delta}\n{self.content}")
-
-
- class UnShiftResult(Result):
-     def __init__(self, context, action_path, dates, result, action):
-         super().__init__(context, action_path, dates)
-         # dates are the actual requested dates
-         # result does not have the same dates
-         self.action = action
-         self.result = result
-
-     def _trace_datasource(self, *args, **kwargs):
-         return f"{self.action.delta}({shorten(self.dates)})"
-
-     @cached_property
-     @assert_fieldlist
-     @notify_result
-     @trace_datasource
-     def datasource(self):
-         from earthkit.data.indexing.fieldlist import FieldArray
-
-         class DateShiftedField:
-             def __init__(self, field, delta):
-                 self.field = field
-                 self.delta = delta
-
-             def metadata(self, key):
-                 value = self.field.metadata(key)
-                 if key == "param":
-                     return value + "_" + time_delta_to_string(self.delta)
-                 if key == "valid_datetime":
-                     dt = datetime.datetime.fromisoformat(value)
-                     new_dt = dt - self.delta
-                     new_value = new_dt.isoformat()
-                     return new_value
-                 if key in ["date", "time", "step", "hdate"]:
-                     raise NotImplementedError(f"metadata {key} not implemented when shifting dates")
-                 return value
-
-             def __getattr__(self, name):
-                 return getattr(self.field, name)
-
-         ds = self.result.datasource
-         ds = FieldArray([DateShiftedField(fs, self.action.delta) for fs in ds])
-         return _tidy(ds)
-
-
- class FunctionAction(Action):
-     def __init__(self, context, action_path, _name, **kwargs):
-         super().__init__(context, action_path, **kwargs)
-         self.name = _name
-
-     @trace_select
-     def select(self, dates):
-         return FunctionResult(self.context, self.action_path, dates, action=self)
-
-     @property
-     def function(self):
-         # name, delta = parse_function_name(self.name)
-         return import_function(self.name, "sources")
-
-     def __repr__(self):
-         content = ""
-         content += ",".join([self._short_str(a) for a in self.args])
-         content += " ".join([self._short_str(f"{k}={v}") for k, v in self.kwargs.items()])
-         content = self._short_str(content)
-         return super().__repr__(_inline_=content, _indent_=" ")
-
-     def _trace_select(self, dates):
-         return f"{self.name}({shorten(dates)})"
-
-
- class PipeAction(Action):
-     def __init__(self, context, action_path, *configs):
-         super().__init__(context, action_path, *configs)
-         assert len(configs) > 1, configs
-         current = action_factory(configs[0], context, action_path + ["0"])
-         for i, c in enumerate(configs[1:]):
-             current = step_factory(c, context, action_path + [str(i + 1)], previous_step=current)
-         self.last_step = current
-
-     @trace_select
-     def select(self, dates):
-         return self.last_step.select(dates)
-
-     def __repr__(self):
-         return super().__repr__(self.last_step)
-
-
- class StepResult(Result):
-     def __init__(self, context, action_path, dates, action, upstream_result):
-         super().__init__(context, action_path, dates)
-         assert isinstance(upstream_result, Result), type(upstream_result)
-         self.upstream_result = upstream_result
-         self.action = action
-
-     @property
-     @notify_result
-     @trace_datasource
-     def datasource(self):
-         raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
-
-
- class StepAction(Action):
-     result_class = None
-
-     def __init__(self, context, action_path, previous_step, *args, **kwargs):
-         super().__init__(context, action_path, *args, **kwargs)
-         self.previous_step = previous_step
-
-     @trace_select
-     def select(self, dates):
-         return self.result_class(
-             self.context,
-             self.action_path,
-             dates,
-             self,
-             self.previous_step.select(dates),
-         )
-
-     def __repr__(self):
-         return super().__repr__(self.previous_step, _inline_=str(self.kwargs))
-
-
- class StepFunctionResult(StepResult):
-     @cached_property
-     @assert_fieldlist
-     @notify_result
-     @trace_datasource
-     def datasource(self):
-         try:
-             return _tidy(
-                 self.action.function(
-                     FunctionContext(self),
-                     self.upstream_result.datasource,
-                     *self.action.args[1:],
-                     **self.action.kwargs,
-                 )
-             )
-
-         except Exception:
-             LOG.error(f"Error in {self.action.name}", exc_info=True)
-             raise
-
-     def _trace_datasource(self, *args, **kwargs):
-         return f"{self.action.name}({shorten(self.dates)})"
-
-
- class FilterStepResult(StepResult):
-     @property
-     @notify_result
-     @assert_fieldlist
-     @trace_datasource
-     def datasource(self):
-         ds = self.upstream_result.datasource
-         ds = ds.sel(**self.action.kwargs)
-         return _tidy(ds)
-
-
- class FilterStepAction(StepAction):
-     result_class = FilterStepResult
-
-
- class FunctionStepAction(StepAction):
-     result_class = StepFunctionResult
-
-     def __init__(self, context, action_path, previous_step, *args, **kwargs):
-         super().__init__(context, action_path, previous_step, *args, **kwargs)
-         self.name = args[0]
-         self.function = import_function(self.name, "filters")
-
-
- class ConcatResult(Result):
-     def __init__(self, context, action_path, dates, results, **kwargs):
-         super().__init__(context, action_path, dates)
-         self.results = [r for r in results if not r.empty]
-
-     @cached_property
-     @assert_fieldlist
-     @notify_result
-     @trace_datasource
-     def datasource(self):
-         ds = EmptyResult(self.context, self.action_path, self.dates).datasource
-         for i in self.results:
-             ds += i.datasource
-         return _tidy(ds)
-
-     @property
-     def variables(self):
-         """Check that all the results objects have the same variables."""
-         variables = None
-         for f in self.results:
-             if f.empty:
-                 continue
-             if variables is None:
-                 variables = f.variables
-             assert variables == f.variables, (variables, f.variables)
-         assert variables is not None, self.results
-         return variables
-
-     def __repr__(self):
-         content = "\n".join([str(i) for i in self.results])
-         return super().__repr__(content)
-
-
- class DataSourcesResult(Result):
-     def __init__(self, context, action_path, dates, input_result, sources_results):
-         super().__init__(context, action_path, dates)
-         # result is the main input result
-         self.input_result = input_result
-         # sources_results is the list of the sources_results
-         self.sources_results = sources_results
-
-     @cached_property
-     def datasource(self):
-         for i in self.sources_results:
-             # for each result trigger the datasource to be computed
-             # and saved in context
-             self.context.notify_result(i.action_path[:-1], i.datasource)
-         # then return the input result
-         # which can use the datasources of the included results
-         return _tidy(self.input_result.datasource)
-
-
- class DataSourcesAction(Action):
-     def __init__(self, context, action_path, sources, input):
-         super().__init__(context, ["data_sources"], *sources)
-         if isinstance(sources, dict):
-             configs = [(str(k), c) for k, c in sources.items()]
-         elif isinstance(sources, list):
-             configs = [(str(i), c) for i, c in enumerate(sources)]
-         else:
-             raise ValueError(f"Invalid data_sources, expecting list or dict, got {type(sources)}: {sources}")
-
-         self.sources = [action_factory(config, context, ["data_sources"] + [a_path]) for a_path, config in configs]
-         self.input = action_factory(input, context, ["input"])
-
-     def select(self, dates):
-         sources_results = [a.select(dates) for a in self.sources]
-         return DataSourcesResult(
-             self.context,
-             self.action_path,
-             dates,
-             self.input.select(dates),
-             sources_results,
-         )
-
-     def __repr__(self):
-         content = "\n".join([str(i) for i in self.sources])
-         return super().__repr__(content)
-
-
- class ConcatAction(Action):
-     def __init__(self, context, action_path, *configs):
-         super().__init__(context, action_path, *configs)
-         parts = []
-         for i, cfg in enumerate(configs):
-             if "dates" not in cfg:
-                 raise ValueError(f"Missing 'dates' in {cfg}")
-             cfg = deepcopy(cfg)
-             dates_cfg = cfg.pop("dates")
-             assert isinstance(dates_cfg, dict), dates_cfg
-             filtering_dates = Dates.from_config(**dates_cfg)
-             action = action_factory(cfg, context, action_path + [str(i)])
-             parts.append((filtering_dates, action))
-         self.parts = parts
-
-     def __repr__(self):
-         content = "\n".join([str(i) for i in self.parts])
-         return super().__repr__(content)
-
-     @trace_select
-     def select(self, dates):
-         results = []
-         for filtering_dates, action in self.parts:
-             newdates = sorted(set(dates) & set(filtering_dates))
-             if newdates:
-                 results.append(action.select(newdates))
-         if not results:
-             return EmptyResult(self.context, self.action_path, dates)
-
-         return ConcatResult(self.context, self.action_path, dates, results)
-
-
- class JoinAction(Action):
-     def __init__(self, context, action_path, *configs):
-         super().__init__(context, action_path, *configs)
-         self.actions = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)]
-
-     def __repr__(self):
-         content = "\n".join([str(i) for i in self.actions])
-         return super().__repr__(content)
-
-     @trace_select
-     def select(self, dates):
-         results = [a.select(dates) for a in self.actions]
-         return JoinResult(self.context, self.action_path, dates, results)
-
-
- def action_factory(config, context, action_path):
-     assert isinstance(context, Context), (type, context)
-     if not isinstance(config, dict):
-         raise ValueError(f"Invalid input config {config}")
-     if len(config) != 1:
-         raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}")
-
-     config = deepcopy(config)
-     key = list(config.keys())[0]
-
-     if isinstance(config[key], list):
-         args, kwargs = config[key], {}
-     if isinstance(config[key], dict):
-         args, kwargs = [], config[key]
-
-     cls = {
-         # "date_shift": DateShiftAction,
-         # "date_filter": DateFilterAction,
-         "data_sources": DataSourcesAction,
-         "concat": ConcatAction,
-         "join": JoinAction,
-         "pipe": PipeAction,
-         "function": FunctionAction,
-     }.get(key)
-
-     if cls is None:
-         if not is_function(key, "sources"):
-             raise ValueError(f"Unknown action '{key}' in {config}")
-         cls = FunctionAction
-         args = [key] + args
-
-     return cls(context, action_path + [key], *args, **kwargs)
-
-
- def step_factory(config, context, action_path, previous_step):
-     assert isinstance(context, Context), (type, context)
-     if not isinstance(config, dict):
-         raise ValueError(f"Invalid input config {config}")
-
-     config = deepcopy(config)
-     assert len(config) == 1, config
-
-     key = list(config.keys())[0]
-     cls = dict(
-         filter=FilterStepAction,
-         # rename=RenameAction,
-         # remapping=RemappingAction,
-     ).get(key)
-
-     if isinstance(config[key], list):
-         args, kwargs = config[key], {}
-
-     if isinstance(config[key], dict):
-         args, kwargs = [], config[key]
-
-     if isinstance(config[key], str):
-         args, kwargs = [config[key]], {}
-
-     if cls is None:
-         if not is_function(key, "filters"):
-             raise ValueError(f"Unknown step {key}")
-         cls = FunctionStepAction
-         args = [key] + args
-         # print("========", args)
-
-     return cls(context, action_path, previous_step, *args, **kwargs)
-
-
- class FunctionContext:
-     """A FunctionContext is passed to all functions, it will be used to pass information
-     to the functions from the other actions and filters and results.
-     """
-
-     def __init__(self, owner):
-         self.owner = owner
-         self.use_grib_paramid = owner.context.use_grib_paramid
-
-     def trace(self, emoji, *args):
-         trace(emoji, *args)
-
-
- class ActionContext(Context):
-     def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid):
-         super().__init__()
-         self.order_by = order_by
-         self.flatten_grid = flatten_grid
-         self.remapping = build_remapping(remapping)
-         self.use_grib_paramid = use_grib_paramid
-
-
- class InputBuilder:
-     def __init__(self, config, data_sources, **kwargs):
-         self.kwargs = kwargs
-
-         config = deepcopy(config)
-         if data_sources:
-             config = dict(
-                 data_sources=dict(
-                     sources=data_sources,
-                     input=config,
-                 )
-             )
-         self.config = config
-         self.action_path = ["input"]
-
-     @trace_select
-     def select(self, dates):
-         """This changes the context."""
-         context = ActionContext(**self.kwargs)
-         action = action_factory(self.config, context, self.action_path)
-         return action.select(dates)
-
-     def __repr__(self):
-         context = ActionContext(**self.kwargs)
-         a = action_factory(self.config, context, self.action_path)
-         return repr(a)
-
-     def _trace_select(self, dates):
-         return f"InputBuilder({shorten(dates)})"
-
-
- build_input = InputBuilder
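The removed module's DateShiftAction/UnShiftResult renamed shifted parameters by appending a suffix produced by time_delta_to_string (DateShiftedField returns value + "_" + time_delta_to_string(delta) for the "param" key). A minimal, self-contained restatement of that helper, copied from the deleted code above, with its expected outputs; this is shown only to document the naming convention, not as the new package's API.

import datetime

def time_delta_to_string(delta):
    # Copied from the deleted module above: accepts whole-hour deltas only.
    assert isinstance(delta, datetime.timedelta), delta
    seconds = delta.total_seconds()
    hours = int(seconds // 3600)
    assert hours * 3600 == seconds, delta
    hours = abs(hours)
    if seconds > 0:
        return f"plus_{hours}h"
    if seconds == 0:
        return ""
    if seconds < 0:
        return f"minus_{hours}h"

assert time_delta_to_string(datetime.timedelta(hours=6)) == "plus_6h"
assert time_delta_to_string(datetime.timedelta(hours=-6)) == "minus_6h"
assert time_delta_to_string(datetime.timedelta(0)) == ""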