anemoi-datasets 0.5.0__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/inspect.py +1 -1
  3. anemoi/datasets/commands/publish.py +30 -0
  4. anemoi/datasets/create/__init__.py +42 -3
  5. anemoi/datasets/create/check.py +6 -0
  6. anemoi/datasets/create/functions/filters/rename.py +2 -3
  7. anemoi/datasets/create/functions/sources/__init__.py +7 -1
  8. anemoi/datasets/create/functions/sources/accumulations.py +2 -0
  9. anemoi/datasets/create/functions/sources/grib.py +1 -1
  10. anemoi/datasets/create/functions/sources/xarray/__init__.py +7 -2
  11. anemoi/datasets/create/functions/sources/xarray/coordinates.py +12 -1
  12. anemoi/datasets/create/functions/sources/xarray/field.py +13 -4
  13. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +16 -16
  14. anemoi/datasets/create/functions/sources/xarray/flavour.py +130 -13
  15. anemoi/datasets/create/functions/sources/xarray/grid.py +106 -17
  16. anemoi/datasets/create/functions/sources/xarray/metadata.py +3 -11
  17. anemoi/datasets/create/functions/sources/xarray/time.py +1 -5
  18. anemoi/datasets/create/functions/sources/xarray/variable.py +10 -10
  19. anemoi/datasets/create/input/__init__.py +69 -0
  20. anemoi/datasets/create/input/action.py +123 -0
  21. anemoi/datasets/create/input/concat.py +92 -0
  22. anemoi/datasets/create/input/context.py +59 -0
  23. anemoi/datasets/create/input/data_sources.py +71 -0
  24. anemoi/datasets/create/input/empty.py +42 -0
  25. anemoi/datasets/create/input/filter.py +76 -0
  26. anemoi/datasets/create/input/function.py +122 -0
  27. anemoi/datasets/create/input/join.py +57 -0
  28. anemoi/datasets/create/input/misc.py +85 -0
  29. anemoi/datasets/create/input/pipe.py +33 -0
  30. anemoi/datasets/create/input/repeated_dates.py +217 -0
  31. anemoi/datasets/create/input/result.py +413 -0
  32. anemoi/datasets/create/input/step.py +99 -0
  33. anemoi/datasets/create/{template.py → input/template.py} +0 -42
  34. anemoi/datasets/create/statistics/__init__.py +1 -1
  35. anemoi/datasets/create/zarr.py +4 -2
  36. anemoi/datasets/dates/__init__.py +1 -0
  37. anemoi/datasets/dates/groups.py +12 -4
  38. anemoi/datasets/fields.py +66 -0
  39. anemoi/datasets/utils/fields.py +47 -0
  40. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/METADATA +1 -1
  41. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/RECORD +46 -30
  42. anemoi/datasets/create/input.py +0 -1087
  43. /anemoi/datasets/create/{trace.py → input/trace.py} +0 -0
  44. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/LICENSE +0 -0
  45. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/WHEEL +0 -0
  46. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/entry_points.txt +0 -0
  47. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/top_level.txt +0 -0
@@ -1,1087 +0,0 @@
1
- # (C) Copyright 2023 ECMWF.
2
- #
3
- # This software is licensed under the terms of the Apache Licence Version 2.0
4
- # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
- # In applying this licence, ECMWF does not waive the privileges and immunities
6
- # granted to it by virtue of its status as an intergovernmental organisation
7
- # nor does it submit to any jurisdiction.
8
- #
9
- import datetime
10
- import itertools
11
- import logging
12
- import math
13
- import time
14
- from collections import defaultdict
15
- from copy import deepcopy
16
- from functools import cached_property
17
- from functools import wraps
18
-
19
- import numpy as np
20
- from anemoi.utils.humanize import seconds_to_human
21
- from anemoi.utils.humanize import shorten_list
22
- from earthkit.data.core.fieldlist import FieldList
23
- from earthkit.data.core.fieldlist import MultiFieldList
24
- from earthkit.data.core.order import build_remapping
25
-
26
- from anemoi.datasets.dates import DatesProvider
27
-
28
- from .functions import import_function
29
- from .template import Context
30
- from .template import notify_result
31
- from .template import resolve
32
- from .template import substitute
33
- from .trace import trace
34
- from .trace import trace_datasource
35
- from .trace import trace_select
36
-
37
- LOG = logging.getLogger(__name__)
38
-
39
-
40
- def parse_function_name(name):
41
-
42
- if name.endswith("h") and name[:-1].isdigit():
43
-
44
- if "-" in name:
45
- name, delta = name.split("-")
46
- sign = -1
47
-
48
- elif "+" in name:
49
- name, delta = name.split("+")
50
- sign = 1
51
-
52
- else:
53
- return name, None
54
-
55
- assert delta[-1] == "h", (name, delta)
56
- delta = sign * int(delta[:-1])
57
- return name, delta
58
-
59
- return name, None
60
-
61
-
62
- def time_delta_to_string(delta):
63
- assert isinstance(delta, datetime.timedelta), delta
64
- seconds = delta.total_seconds()
65
- hours = int(seconds // 3600)
66
- assert hours * 3600 == seconds, delta
67
- hours = abs(hours)
68
-
69
- if seconds > 0:
70
- return f"plus_{hours}h"
71
- if seconds == 0:
72
- return ""
73
- if seconds < 0:
74
- return f"minus_{hours}h"
75
-
76
-
77
- def is_function(name, kind):
78
- name, _ = parse_function_name(name)
79
- try:
80
- import_function(name, kind)
81
- return True
82
- except ImportError as e:
83
- print(e)
84
- return False
85
-
86
-
87
- def assert_fieldlist(method):
88
- @wraps(method)
89
- def wrapper(self, *args, **kwargs):
90
- result = method(self, *args, **kwargs)
91
- assert isinstance(result, FieldList), type(result)
92
- return result
93
-
94
- return wrapper
95
-
96
-
97
- def assert_is_fieldlist(obj):
98
- assert isinstance(obj, FieldList), type(obj)
99
-
100
-
101
- def _data_request(data):
102
- date = None
103
- params_levels = defaultdict(set)
104
- params_steps = defaultdict(set)
105
-
106
- area = grid = None
107
-
108
- for field in data:
109
- try:
110
- if date is None:
111
- date = field.datetime()["valid_time"]
112
-
113
- if field.datetime()["valid_time"] != date:
114
- continue
115
-
116
- as_mars = field.metadata(namespace="mars")
117
- if not as_mars:
118
- continue
119
- step = as_mars.get("step")
120
- levtype = as_mars.get("levtype", "sfc")
121
- param = as_mars["param"]
122
- levelist = as_mars.get("levelist", None)
123
- area = field.mars_area
124
- grid = field.mars_grid
125
-
126
- if levelist is None:
127
- params_levels[levtype].add(param)
128
- else:
129
- params_levels[levtype].add((param, levelist))
130
-
131
- if step:
132
- params_steps[levtype].add((param, step))
133
- except Exception:
134
- LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True)
135
-
136
- def sort(old_dic):
137
- new_dic = {}
138
- for k, v in old_dic.items():
139
- new_dic[k] = sorted(list(v))
140
- return new_dic
141
-
142
- params_steps = sort(params_steps)
143
- params_levels = sort(params_levels)
144
-
145
- return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid)
146
-
147
-
148
- class Action:
149
- def __init__(self, context, action_path, /, *args, **kwargs):
150
- if "args" in kwargs and "kwargs" in kwargs:
151
- """We have:
152
- args = []
153
- kwargs = {args: [...], kwargs: {...}}
154
- move the content of kwargs to args and kwargs.
155
- """
156
- assert len(kwargs) == 2, (args, kwargs)
157
- assert not args, (args, kwargs)
158
- args = kwargs.pop("args")
159
- kwargs = kwargs.pop("kwargs")
160
-
161
- assert isinstance(context, ActionContext), type(context)
162
- self.context = context
163
- self.kwargs = kwargs
164
- self.args = args
165
- self.action_path = action_path
166
-
167
- @classmethod
168
- def _short_str(cls, x):
169
- x = str(x)
170
- if len(x) < 1000:
171
- return x
172
- return x[:1000] + "..."
173
-
174
- def __repr__(self, *args, _indent_="\n", _inline_="", **kwargs):
175
- more = ",".join([str(a)[:5000] for a in args])
176
- more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])
177
-
178
- more = more[:5000]
179
- txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}"
180
- if _indent_:
181
- txt = txt.replace("\n", "\n ")
182
- return txt
183
-
184
- def select(self, dates, **kwargs):
185
- self._raise_not_implemented()
186
-
187
- def _raise_not_implemented(self):
188
- raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
189
-
190
- def _trace_select(self, dates):
191
- return f"{self.__class__.__name__}({shorten(dates)})"
192
-
193
-
194
- def shorten(dates):
195
- if isinstance(dates, (list, tuple)):
196
- dates = [d.isoformat() for d in dates]
197
- if len(dates) > 5:
198
- return f"{dates[0]}...{dates[-1]}"
199
- return dates
200
-
201
-
202
- class Result:
203
- empty = False
204
- _coords_already_built = False
205
-
206
- def __init__(self, context, action_path, dates):
207
- from anemoi.datasets.dates.groups import GroupOfDates
208
-
209
- assert isinstance(dates, GroupOfDates), dates
210
-
211
- assert isinstance(context, ActionContext), type(context)
212
- assert isinstance(action_path, list), action_path
213
-
214
- self.context = context
215
- self.group_of_dates = dates
216
- self.action_path = action_path
217
-
218
- @property
219
- @trace_datasource
220
- def datasource(self):
221
- self._raise_not_implemented()
222
-
223
- @property
224
- def data_request(self):
225
- """Returns a dictionary with the parameters needed to retrieve the data."""
226
- return _data_request(self.datasource)
227
-
228
- def get_cube(self):
229
- trace("🧊", f"getting cube from {self.__class__.__name__}")
230
- ds = self.datasource
231
-
232
- remapping = self.context.remapping
233
- order_by = self.context.order_by
234
- flatten_grid = self.context.flatten_grid
235
- start = time.time()
236
- LOG.debug("Sorting dataset %s %s", dict(order_by), remapping)
237
- assert order_by, order_by
238
-
239
- patches = {"number": {None: 0}}
240
-
241
- try:
242
- cube = ds.cube(
243
- order_by,
244
- remapping=remapping,
245
- flatten_values=flatten_grid,
246
- patches=patches,
247
- )
248
- cube = cube.squeeze()
249
- LOG.debug(f"Sorting done in {seconds_to_human(time.time()-start)}.")
250
- except ValueError:
251
- self.explain(ds, order_by, remapping=remapping, patches=patches)
252
- # raise ValueError(f"Error in {self}")
253
- exit(1)
254
-
255
- if LOG.isEnabledFor(logging.DEBUG):
256
- LOG.debug("Cube shape: %s", cube)
257
- for k, v in cube.user_coords.items():
258
- LOG.debug(" %s %s", k, shorten_list(v, max_length=10))
259
-
260
- return cube
261
-
262
- def explain(self, ds, *args, remapping, patches):
263
-
264
- METADATA = (
265
- "date",
266
- "time",
267
- "step",
268
- "hdate",
269
- "valid_datetime",
270
- "levtype",
271
- "levelist",
272
- "number",
273
- "level",
274
- "shortName",
275
- "paramId",
276
- "variable",
277
- )
278
-
279
- # We redo the logic here
280
- print()
281
- print("❌" * 40)
282
- print()
283
- if len(args) == 1 and isinstance(args[0], (list, tuple)):
284
- args = args[0]
285
-
286
- # print("Executing", self.action_path)
287
- # print("Dates:", compress_dates(self.dates))
288
-
289
- names = []
290
- for a in args:
291
- if isinstance(a, str):
292
- names.append(a)
293
- elif isinstance(a, dict):
294
- names += list(a.keys())
295
-
296
- print(f"Building a {len(names)}D hypercube using", names)
297
- ds = ds.order_by(*args, remapping=remapping, patches=patches)
298
- user_coords = ds.unique_values(*names, remapping=remapping, patches=patches, progress_bar=False)
299
-
300
- print()
301
- print("Number of unique values found for each coordinate:")
302
- for k, v in user_coords.items():
303
- print(f" {k:20}:", len(v), shorten_list(v, max_length=10))
304
- print()
305
- user_shape = tuple(len(v) for k, v in user_coords.items())
306
- print("Shape of the hypercube :", user_shape)
307
- print(
308
- "Number of expected fields :", math.prod(user_shape), "=", " x ".join([str(i) for i in user_shape])
309
- )
310
- print("Number of fields in the dataset :", len(ds))
311
- print("Difference :", abs(len(ds) - math.prod(user_shape)))
312
- print()
313
-
314
- remapping = build_remapping(remapping, patches)
315
- expected = set(itertools.product(*user_coords.values()))
316
- extra = set()
317
-
318
- if math.prod(user_shape) > len(ds):
319
- print(f"This means that all the fields in the datasets do not exists for all combinations of {names}.")
320
-
321
- for f in ds:
322
- metadata = remapping(f.metadata)
323
- key = tuple(metadata(n, default=None) for n in names)
324
- if key in expected:
325
- expected.remove(key)
326
- else:
327
- extra.add(key)
328
-
329
- print("Missing fields:")
330
- print()
331
- for i, f in enumerate(sorted(expected)):
332
- print(" ", f)
333
- if i >= 9 and len(expected) > 10:
334
- print("...", len(expected) - i - 1, "more")
335
- break
336
-
337
- print("Extra fields:")
338
- print()
339
- for i, f in enumerate(sorted(extra)):
340
- print(" ", f)
341
- if i >= 9 and len(extra) > 10:
342
- print("...", len(extra) - i - 1, "more")
343
- break
344
-
345
- print()
346
- print("Missing values:")
347
- per_name = defaultdict(set)
348
- for e in expected:
349
- for n, v in zip(names, e):
350
- per_name[n].add(v)
351
-
352
- for n, v in per_name.items():
353
- print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
354
- print()
355
-
356
- print("Extra values:")
357
- per_name = defaultdict(set)
358
- for e in extra:
359
- for n, v in zip(names, e):
360
- per_name[n].add(v)
361
-
362
- for n, v in per_name.items():
363
- print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
364
- print()
365
-
366
- print("To solve this issue, you can:")
367
- print(
368
- " - Provide a better selection, like 'step: 0' or 'level: 1000' to "
369
- "reduce the number of selected fields."
370
- )
371
- print(
372
- " - Split the 'input' part in smaller sections using 'join', "
373
- "making sure that each section represent a full hypercube."
374
- )
375
-
376
- else:
377
- print(f"More fields in dataset that expected for {names}. " "This means that some fields are duplicated.")
378
- duplicated = defaultdict(list)
379
- for f in ds:
380
- # print(f.metadata(namespace="default"))
381
- metadata = remapping(f.metadata)
382
- key = tuple(metadata(n, default=None) for n in names)
383
- duplicated[key].append(f)
384
-
385
- print("Duplicated fields:")
386
- print()
387
- duplicated = {k: v for k, v in duplicated.items() if len(v) > 1}
388
- for i, (k, v) in enumerate(sorted(duplicated.items())):
389
- print(" ", k)
390
- for f in v:
391
- x = {k: f.metadata(k, default=None) for k in METADATA if f.metadata(k, default=None) is not None}
392
- print(" ", f, x)
393
- if i >= 9 and len(duplicated) > 10:
394
- print("...", len(duplicated) - i - 1, "more")
395
- break
396
-
397
- print()
398
- print("To solve this issue, you can:")
399
- print(" - Provide a better selection, like 'step: 0' or 'level: 1000'")
400
- print(" - Change the way 'param' is computed using 'variable_naming' " "in the 'build' section.")
401
-
402
- print()
403
- print("❌" * 40)
404
- print()
405
- exit(1)
406
-
407
- def __repr__(self, *args, _indent_="\n", **kwargs):
408
- more = ",".join([str(a)[:5000] for a in args])
409
- more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])
410
-
411
- dates = " no-dates"
412
- if self.group_of_dates is not None:
413
- dates = f" {len(self.group_of_dates)} dates"
414
- dates += " ("
415
- dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.group_of_dates)
416
- if len(dates) > 100:
417
- dates = dates[:100] + "..."
418
- dates += ")"
419
-
420
- more = more[:5000]
421
- txt = f"{self.__class__.__name__}:{dates}{_indent_}{more}"
422
- if _indent_:
423
- txt = txt.replace("\n", "\n ")
424
- return txt
425
-
426
- def _raise_not_implemented(self):
427
- raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
428
-
429
- def _trace_datasource(self, *args, **kwargs):
430
- return f"{self.__class__.__name__}({self.group_of_dates})"
431
-
432
- def build_coords(self):
433
- if self._coords_already_built:
434
- return
435
- from_data = self.get_cube().user_coords
436
- from_config = self.context.order_by
437
-
438
- keys_from_config = list(from_config.keys())
439
- keys_from_data = list(from_data.keys())
440
- assert keys_from_data == keys_from_config, f"Critical error: {keys_from_data=} != {keys_from_config=}. {self=}"
441
-
442
- variables_key = list(from_config.keys())[1]
443
- ensembles_key = list(from_config.keys())[2]
444
-
445
- if isinstance(from_config[variables_key], (list, tuple)):
446
- assert all([v == w for v, w in zip(from_data[variables_key], from_config[variables_key])]), (
447
- from_data[variables_key],
448
- from_config[variables_key],
449
- )
450
-
451
- self._variables = from_data[variables_key] # "param_level"
452
- self._ensembles = from_data[ensembles_key] # "number"
453
-
454
- first_field = self.datasource[0]
455
- grid_points = first_field.grid_points()
456
-
457
- lats, lons = grid_points
458
-
459
- assert len(lats) == len(lons), (len(lats), len(lons), first_field)
460
- assert len(lats) == math.prod(first_field.shape), (len(lats), first_field.shape, first_field)
461
-
462
- north = np.amax(lats)
463
- south = np.amin(lats)
464
- east = np.amax(lons)
465
- west = np.amin(lons)
466
-
467
- assert -90 <= south <= north <= 90, (south, north, first_field)
468
- assert (-180 <= west <= east <= 180) or (0 <= west <= east <= 360), (
469
- west,
470
- east,
471
- first_field,
472
- )
473
-
474
- grid_values = list(range(len(grid_points[0])))
475
-
476
- self._grid_points = grid_points
477
- self._resolution = first_field.resolution
478
- self._grid_values = grid_values
479
- self._field_shape = first_field.shape
480
- self._proj_string = first_field.proj_string if hasattr(first_field, "proj_string") else None
481
-
482
- @property
483
- def variables(self):
484
- self.build_coords()
485
- return self._variables
486
-
487
- @property
488
- def ensembles(self):
489
- self.build_coords()
490
- return self._ensembles
491
-
492
- @property
493
- def resolution(self):
494
- self.build_coords()
495
- return self._resolution
496
-
497
- @property
498
- def grid_values(self):
499
- self.build_coords()
500
- return self._grid_values
501
-
502
- @property
503
- def grid_points(self):
504
- self.build_coords()
505
- return self._grid_points
506
-
507
- @property
508
- def field_shape(self):
509
- self.build_coords()
510
- return self._field_shape
511
-
512
- @property
513
- def proj_string(self):
514
- self.build_coords()
515
- return self._proj_string
516
-
517
- @cached_property
518
- def shape(self):
519
- return [
520
- len(self.group_of_dates),
521
- len(self.variables),
522
- len(self.ensembles),
523
- len(self.grid_values),
524
- ]
525
-
526
- @cached_property
527
- def coords(self):
528
- return {
529
- "dates": list(self.group_of_dates),
530
- "variables": self.variables,
531
- "ensembles": self.ensembles,
532
- "values": self.grid_values,
533
- }
534
-
535
-
536
- class EmptyResult(Result):
537
- empty = True
538
-
539
- def __init__(self, context, action_path, dates):
540
- super().__init__(context, action_path + ["empty"], dates)
541
-
542
- @cached_property
543
- @assert_fieldlist
544
- @trace_datasource
545
- def datasource(self):
546
- from earthkit.data import from_source
547
-
548
- return from_source("empty")
549
-
550
- @property
551
- def variables(self):
552
- return []
553
-
554
-
555
- def _flatten(ds):
556
- if isinstance(ds, MultiFieldList):
557
- return [_tidy(f) for s in ds._indexes for f in _flatten(s)]
558
- return [ds]
559
-
560
-
561
- def _tidy(ds, indent=0):
562
- if isinstance(ds, MultiFieldList):
563
-
564
- sources = [s for s in _flatten(ds) if len(s) > 0]
565
- if len(sources) == 1:
566
- return sources[0]
567
- return MultiFieldList(sources)
568
- return ds
569
-
570
-
571
- class FunctionResult(Result):
572
- def __init__(self, context, action_path, dates, action):
573
- super().__init__(context, action_path, dates)
574
- assert isinstance(action, Action), type(action)
575
- self.action = action
576
-
577
- self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs))
578
-
579
- def _trace_datasource(self, *args, **kwargs):
580
- return f"{self.action.name}({self.group_of_dates})"
581
-
582
- @cached_property
583
- @assert_fieldlist
584
- @notify_result
585
- @trace_datasource
586
- def datasource(self):
587
- args, kwargs = resolve(self.context, (self.args, self.kwargs))
588
-
589
- try:
590
- return _tidy(
591
- self.action.function(
592
- FunctionContext(self),
593
- list(self.group_of_dates), # Will provide a list of datetime objects
594
- *args,
595
- **kwargs,
596
- )
597
- )
598
- except Exception:
599
- LOG.error(f"Error in {self.action.function.__name__}", exc_info=True)
600
- raise
601
-
602
- def __repr__(self):
603
- try:
604
- return f"{self.action.name}({self.group_of_dates})"
605
- except Exception:
606
- return f"{self.__class__.__name__}(unitialised)"
607
-
608
- @property
609
- def function(self):
610
- raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
611
-
612
-
613
- class JoinResult(Result):
614
- def __init__(self, context, action_path, dates, results, **kwargs):
615
- super().__init__(context, action_path, dates)
616
- self.results = [r for r in results if not r.empty]
617
-
618
- @cached_property
619
- @assert_fieldlist
620
- @notify_result
621
- @trace_datasource
622
- def datasource(self):
623
- ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
624
- for i in self.results:
625
- ds += i.datasource
626
- return _tidy(ds)
627
-
628
- def __repr__(self):
629
- content = "\n".join([str(i) for i in self.results])
630
- return super().__repr__(content)
631
-
632
-
633
- class DateShiftAction(Action):
634
- def __init__(self, context, action_path, delta, **kwargs):
635
- super().__init__(context, action_path, **kwargs)
636
-
637
- if isinstance(delta, str):
638
- if delta[0] == "-":
639
- delta, sign = int(delta[1:]), -1
640
- else:
641
- delta, sign = int(delta), 1
642
- delta = datetime.timedelta(hours=sign * delta)
643
- assert isinstance(delta, int), delta
644
- delta = datetime.timedelta(hours=delta)
645
- self.delta = delta
646
-
647
- self.content = action_factory(kwargs, context, self.action_path + ["shift"])
648
-
649
- @trace_select
650
- def select(self, dates):
651
- shifted_dates = [d + self.delta for d in dates]
652
- result = self.content.select(shifted_dates)
653
- return UnShiftResult(self.context, self.action_path, dates, result, action=self)
654
-
655
- def __repr__(self):
656
- return super().__repr__(f"{self.delta}\n{self.content}")
657
-
658
-
659
- class UnShiftResult(Result):
660
- def __init__(self, context, action_path, dates, result, action):
661
- super().__init__(context, action_path, dates)
662
- # dates are the actual requested dates
663
- # result does not have the same dates
664
- self.action = action
665
- self.result = result
666
-
667
- def _trace_datasource(self, *args, **kwargs):
668
- return f"{self.action.delta}({shorten(self.dates)})"
669
-
670
- @cached_property
671
- @assert_fieldlist
672
- @notify_result
673
- @trace_datasource
674
- def datasource(self):
675
- from earthkit.data.indexing.fieldlist import FieldArray
676
-
677
- class DateShiftedField:
678
- def __init__(self, field, delta):
679
- self.field = field
680
- self.delta = delta
681
-
682
- def metadata(self, key):
683
- value = self.field.metadata(key)
684
- if key == "param":
685
- return value + "_" + time_delta_to_string(self.delta)
686
- if key == "valid_datetime":
687
- dt = datetime.datetime.fromisoformat(value)
688
- new_dt = dt - self.delta
689
- new_value = new_dt.isoformat()
690
- return new_value
691
- if key in ["date", "time", "step", "hdate"]:
692
- raise NotImplementedError(f"metadata {key} not implemented when shifting dates")
693
- return value
694
-
695
- def __getattr__(self, name):
696
- return getattr(self.field, name)
697
-
698
- ds = self.result.datasource
699
- ds = FieldArray([DateShiftedField(fs, self.action.delta) for fs in ds])
700
- return _tidy(ds)
701
-
702
-
703
- class FunctionAction(Action):
704
- def __init__(self, context, action_path, _name, **kwargs):
705
- super().__init__(context, action_path, **kwargs)
706
- self.name = _name
707
-
708
- @trace_select
709
- def select(self, dates):
710
- return FunctionResult(self.context, self.action_path, dates, action=self)
711
-
712
- @property
713
- def function(self):
714
- # name, delta = parse_function_name(self.name)
715
- return import_function(self.name, "sources")
716
-
717
- def __repr__(self):
718
- content = ""
719
- content += ",".join([self._short_str(a) for a in self.args])
720
- content += " ".join([self._short_str(f"{k}={v}") for k, v in self.kwargs.items()])
721
- content = self._short_str(content)
722
- return super().__repr__(_inline_=content, _indent_=" ")
723
-
724
- def _trace_select(self, dates):
725
- return f"{self.name}({shorten(dates)})"
726
-
727
-
728
- class PipeAction(Action):
729
- def __init__(self, context, action_path, *configs):
730
- super().__init__(context, action_path, *configs)
731
- assert len(configs) > 1, configs
732
- current = action_factory(configs[0], context, action_path + ["0"])
733
- for i, c in enumerate(configs[1:]):
734
- current = step_factory(c, context, action_path + [str(i + 1)], previous_step=current)
735
- self.last_step = current
736
-
737
- @trace_select
738
- def select(self, dates):
739
- return self.last_step.select(dates)
740
-
741
- def __repr__(self):
742
- return super().__repr__(self.last_step)
743
-
744
-
745
- class StepResult(Result):
746
- def __init__(self, context, action_path, dates, action, upstream_result):
747
- super().__init__(context, action_path, dates)
748
- assert isinstance(upstream_result, Result), type(upstream_result)
749
- self.upstream_result = upstream_result
750
- self.action = action
751
-
752
- @property
753
- @notify_result
754
- @trace_datasource
755
- def datasource(self):
756
- raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
757
-
758
-
759
- class StepAction(Action):
760
- result_class = None
761
-
762
- def __init__(self, context, action_path, previous_step, *args, **kwargs):
763
- super().__init__(context, action_path, *args, **kwargs)
764
- self.previous_step = previous_step
765
-
766
- @trace_select
767
- def select(self, dates):
768
- return self.result_class(
769
- self.context,
770
- self.action_path,
771
- dates,
772
- self,
773
- self.previous_step.select(dates),
774
- )
775
-
776
- def __repr__(self):
777
- return super().__repr__(self.previous_step, _inline_=str(self.kwargs))
778
-
779
-
780
- class StepFunctionResult(StepResult):
781
- @cached_property
782
- @assert_fieldlist
783
- @notify_result
784
- @trace_datasource
785
- def datasource(self):
786
- try:
787
- return _tidy(
788
- self.action.function(
789
- FunctionContext(self),
790
- self.upstream_result.datasource,
791
- *self.action.args[1:],
792
- **self.action.kwargs,
793
- )
794
- )
795
-
796
- except Exception:
797
- LOG.error(f"Error in {self.action.name}", exc_info=True)
798
- raise
799
-
800
- def _trace_datasource(self, *args, **kwargs):
801
- return f"{self.action.name}({shorten(self.dates)})"
802
-
803
-
804
- class FilterStepResult(StepResult):
805
- @property
806
- @notify_result
807
- @assert_fieldlist
808
- @trace_datasource
809
- def datasource(self):
810
- ds = self.upstream_result.datasource
811
- ds = ds.sel(**self.action.kwargs)
812
- return _tidy(ds)
813
-
814
-
815
- class FilterStepAction(StepAction):
816
- result_class = FilterStepResult
817
-
818
-
819
- class FunctionStepAction(StepAction):
820
- result_class = StepFunctionResult
821
-
822
- def __init__(self, context, action_path, previous_step, *args, **kwargs):
823
- super().__init__(context, action_path, previous_step, *args, **kwargs)
824
- self.name = args[0]
825
- self.function = import_function(self.name, "filters")
826
-
827
-
828
- class ConcatResult(Result):
829
- def __init__(self, context, action_path, dates, results, **kwargs):
830
- super().__init__(context, action_path, dates)
831
- self.results = [r for r in results if not r.empty]
832
-
833
- @cached_property
834
- @assert_fieldlist
835
- @notify_result
836
- @trace_datasource
837
- def datasource(self):
838
- ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
839
- for i in self.results:
840
- ds += i.datasource
841
- return _tidy(ds)
842
-
843
- @property
844
- def variables(self):
845
- """Check that all the results objects have the same variables."""
846
- variables = None
847
- for f in self.results:
848
- if f.empty:
849
- continue
850
- if variables is None:
851
- variables = f.variables
852
- assert variables == f.variables, (variables, f.variables)
853
- assert variables is not None, self.results
854
- return variables
855
-
856
- def __repr__(self):
857
- content = "\n".join([str(i) for i in self.results])
858
- return super().__repr__(content)
859
-
860
-
861
- class DataSourcesResult(Result):
862
- def __init__(self, context, action_path, dates, input_result, sources_results):
863
- super().__init__(context, action_path, dates)
864
- # result is the main input result
865
- self.input_result = input_result
866
- # sources_results is the list of the sources_results
867
- self.sources_results = sources_results
868
-
869
- @cached_property
870
- def datasource(self):
871
- for i in self.sources_results:
872
- # for each result trigger the datasource to be computed
873
- # and saved in context
874
- self.context.notify_result(i.action_path[:-1], i.datasource)
875
- # then return the input result
876
- # which can use the datasources of the included results
877
- return _tidy(self.input_result.datasource)
878
-
879
-
880
- class DataSourcesAction(Action):
881
- def __init__(self, context, action_path, sources, input):
882
- super().__init__(context, ["data_sources"], *sources)
883
- if isinstance(sources, dict):
884
- configs = [(str(k), c) for k, c in sources.items()]
885
- elif isinstance(sources, list):
886
- configs = [(str(i), c) for i, c in enumerate(sources)]
887
- else:
888
- raise ValueError(f"Invalid data_sources, expecting list or dict, got {type(sources)}: {sources}")
889
-
890
- self.sources = [action_factory(config, context, ["data_sources"] + [a_path]) for a_path, config in configs]
891
- self.input = action_factory(input, context, ["input"])
892
-
893
- def select(self, dates):
894
- sources_results = [a.select(dates) for a in self.sources]
895
- return DataSourcesResult(
896
- self.context,
897
- self.action_path,
898
- dates,
899
- self.input.select(dates),
900
- sources_results,
901
- )
902
-
903
- def __repr__(self):
904
- content = "\n".join([str(i) for i in self.sources])
905
- return super().__repr__(content)
906
-
907
-
908
- class ConcatAction(Action):
909
- def __init__(self, context, action_path, *configs):
910
- super().__init__(context, action_path, *configs)
911
- parts = []
912
- for i, cfg in enumerate(configs):
913
- if "dates" not in cfg:
914
- raise ValueError(f"Missing 'dates' in {cfg}")
915
- cfg = deepcopy(cfg)
916
- dates_cfg = cfg.pop("dates")
917
- assert isinstance(dates_cfg, dict), dates_cfg
918
- filtering_dates = DatesProvider.from_config(**dates_cfg)
919
- action = action_factory(cfg, context, action_path + [str(i)])
920
- parts.append((filtering_dates, action))
921
- self.parts = parts
922
-
923
- def __repr__(self):
924
- content = "\n".join([str(i) for i in self.parts])
925
- return super().__repr__(content)
926
-
927
- @trace_select
928
- def select(self, dates):
929
- from anemoi.datasets.dates.groups import GroupOfDates
930
-
931
- results = []
932
- for filtering_dates, action in self.parts:
933
- newdates = GroupOfDates(sorted(set(dates) & set(filtering_dates)), dates.provider)
934
- if newdates:
935
- results.append(action.select(newdates))
936
- if not results:
937
- return EmptyResult(self.context, self.action_path, dates)
938
-
939
- return ConcatResult(self.context, self.action_path, dates, results)
940
-
941
-
942
- class JoinAction(Action):
943
- def __init__(self, context, action_path, *configs):
944
- super().__init__(context, action_path, *configs)
945
- self.actions = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)]
946
-
947
- def __repr__(self):
948
- content = "\n".join([str(i) for i in self.actions])
949
- return super().__repr__(content)
950
-
951
- @trace_select
952
- def select(self, dates):
953
- results = [a.select(dates) for a in self.actions]
954
- return JoinResult(self.context, self.action_path, dates, results)
955
-
956
-
957
- def action_factory(config, context, action_path):
958
- assert isinstance(context, Context), (type, context)
959
- if not isinstance(config, dict):
960
- raise ValueError(f"Invalid input config {config}")
961
- if len(config) != 1:
962
- raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}")
963
-
964
- config = deepcopy(config)
965
- key = list(config.keys())[0]
966
-
967
- if isinstance(config[key], list):
968
- args, kwargs = config[key], {}
969
- elif isinstance(config[key], dict):
970
- args, kwargs = [], config[key]
971
- else:
972
- raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}")
973
-
974
- cls = {
975
- # "date_shift": DateShiftAction,
976
- # "date_filter": DateFilterAction,
977
- "data_sources": DataSourcesAction,
978
- "concat": ConcatAction,
979
- "join": JoinAction,
980
- "pipe": PipeAction,
981
- "function": FunctionAction,
982
- }.get(key)
983
-
984
- if cls is None:
985
- if not is_function(key, "sources"):
986
- raise ValueError(f"Unknown action '{key}' in {config}")
987
- cls = FunctionAction
988
- args = [key] + args
989
-
990
- return cls(context, action_path + [key], *args, **kwargs)
991
-
992
-
993
- def step_factory(config, context, action_path, previous_step):
994
- assert isinstance(context, Context), (type, context)
995
- if not isinstance(config, dict):
996
- raise ValueError(f"Invalid input config {config}")
997
-
998
- config = deepcopy(config)
999
- assert len(config) == 1, config
1000
-
1001
- key = list(config.keys())[0]
1002
- cls = dict(
1003
- filter=FilterStepAction,
1004
- # rename=RenameAction,
1005
- # remapping=RemappingAction,
1006
- ).get(key)
1007
-
1008
- if isinstance(config[key], list):
1009
- args, kwargs = config[key], {}
1010
-
1011
- if isinstance(config[key], dict):
1012
- args, kwargs = [], config[key]
1013
-
1014
- if isinstance(config[key], str):
1015
- args, kwargs = [config[key]], {}
1016
-
1017
- if cls is None:
1018
- if not is_function(key, "filters"):
1019
- raise ValueError(f"Unknown step {key}")
1020
- cls = FunctionStepAction
1021
- args = [key] + args
1022
- # print("========", args)
1023
-
1024
- return cls(context, action_path, previous_step, *args, **kwargs)
1025
-
1026
-
1027
- class FunctionContext:
1028
- """A FunctionContext is passed to all functions, it will be used to pass information
1029
- to the functions from the other actions and filters and results.
1030
- """
1031
-
1032
- def __init__(self, owner):
1033
- self.owner = owner
1034
- self.use_grib_paramid = owner.context.use_grib_paramid
1035
-
1036
- def trace(self, emoji, *args):
1037
- trace(emoji, *args)
1038
-
1039
- def info(self, *args, **kwargs):
1040
- LOG.info(*args, **kwargs)
1041
-
1042
- @property
1043
- def dates_provider(self):
1044
- return self.owner.group_of_dates.provider
1045
-
1046
-
1047
- class ActionContext(Context):
1048
- def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid):
1049
- super().__init__()
1050
- self.order_by = order_by
1051
- self.flatten_grid = flatten_grid
1052
- self.remapping = build_remapping(remapping)
1053
- self.use_grib_paramid = use_grib_paramid
1054
-
1055
-
1056
- class InputBuilder:
1057
- def __init__(self, config, data_sources, **kwargs):
1058
- self.kwargs = kwargs
1059
-
1060
- config = deepcopy(config)
1061
- if data_sources:
1062
- config = dict(
1063
- data_sources=dict(
1064
- sources=data_sources,
1065
- input=config,
1066
- )
1067
- )
1068
- self.config = config
1069
- self.action_path = ["input"]
1070
-
1071
- @trace_select
1072
- def select(self, dates):
1073
- """This changes the context."""
1074
- context = ActionContext(**self.kwargs)
1075
- action = action_factory(self.config, context, self.action_path)
1076
- return action.select(dates)
1077
-
1078
- def __repr__(self):
1079
- context = ActionContext(**self.kwargs)
1080
- a = action_factory(self.config, context, self.action_path)
1081
- return repr(a)
1082
-
1083
- def _trace_select(self, dates):
1084
- return f"InputBuilder({shorten(dates)})"
1085
-
1086
-
1087
- build_input = InputBuilder