anemoi-datasets 0.5.12__py3-none-any.whl → 0.5.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/create.py +4 -2
  3. anemoi/datasets/create/__init__.py +22 -6
  4. anemoi/datasets/create/check.py +1 -1
  5. anemoi/datasets/create/functions/__init__.py +15 -1
  6. anemoi/datasets/create/functions/filters/orog_to_z.py +58 -0
  7. anemoi/datasets/create/functions/filters/sum.py +71 -0
  8. anemoi/datasets/create/functions/filters/wz_to_w.py +79 -0
  9. anemoi/datasets/create/functions/sources/accumulations.py +7 -2
  10. anemoi/datasets/create/functions/sources/eccc_fstd.py +16 -0
  11. anemoi/datasets/create/functions/sources/mars.py +5 -1
  12. anemoi/datasets/create/functions/sources/xarray/__init__.py +3 -3
  13. anemoi/datasets/create/functions/sources/xarray/field.py +5 -1
  14. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +10 -1
  15. anemoi/datasets/create/functions/sources/xarray/metadata.py +5 -11
  16. anemoi/datasets/create/functions/sources/xarray/patch.py +44 -0
  17. anemoi/datasets/create/functions/sources/xarray/time.py +15 -0
  18. anemoi/datasets/create/functions/sources/xarray/variable.py +18 -2
  19. anemoi/datasets/create/input/repeated_dates.py +18 -0
  20. anemoi/datasets/create/input/result.py +1 -1
  21. anemoi/datasets/create/statistics/__init__.py +7 -4
  22. anemoi/datasets/create/utils.py +4 -0
  23. anemoi/datasets/data/complement.py +164 -0
  24. anemoi/datasets/data/dataset.py +68 -5
  25. anemoi/datasets/data/ensemble.py +55 -0
  26. anemoi/datasets/data/join.py +1 -2
  27. anemoi/datasets/data/merge.py +3 -0
  28. anemoi/datasets/data/misc.py +34 -1
  29. anemoi/datasets/grids.py +29 -10
  30. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/METADATA +2 -2
  31. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/RECORD +35 -29
  32. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/WHEEL +1 -1
  33. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/LICENSE +0 -0
  34. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/entry_points.txt +0 -0
  35. {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/top_level.txt +0 -0
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.5.12'
16
- __version_tuple__ = version_tuple = (0, 5, 12)
15
+ __version__ = version = '0.5.14'
16
+ __version_tuple__ = version_tuple = (0, 5, 14)
@@ -83,11 +83,12 @@ class Create(Command):
83
83
  task("load", options)
84
84
  task("finalise", options)
85
85
 
86
- task("patch", options)
87
-
88
86
  task("init_additions", options)
89
87
  task("run_additions", options)
90
88
  task("finalise_additions", options)
89
+
90
+ task("patch", options)
91
+
91
92
  task("cleanup", options)
92
93
  task("verify", options)
93
94
 
@@ -153,6 +154,7 @@ class Create(Command):
153
154
 
154
155
  with ExecutorClass(max_workers=1) as executor:
155
156
  executor.submit(task, "finalise-additions", options).result()
157
+ executor.submit(task, "patch", options).result()
156
158
  executor.submit(task, "cleanup", options).result()
157
159
  executor.submit(task, "verify", options).result()
158
160
 
@@ -79,7 +79,10 @@ def json_tidy(o):
79
79
  )
80
80
  return o.isoformat()
81
81
 
82
- raise TypeError(repr(o) + " is not JSON serializable")
82
+ if isinstance(o, (np.float32, np.float64)):
83
+ return float(o)
84
+
85
+ raise TypeError(f"{repr(o)} is not JSON serializable {type(o)}")
83
86
 
84
87
 
85
88
  def build_statistics_dates(dates, start, end):
@@ -596,6 +599,8 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
596
599
  # There is one cube to load for each result.
597
600
  dates = list(result.group_of_dates)
598
601
 
602
+ LOG.debug(f"Loading cube for {len(dates)} dates")
603
+
599
604
  cube = result.get_cube()
600
605
  shape = cube.extended_user_shape
601
606
  dates_in_data = cube.user_coords["valid_datetime"]
@@ -622,10 +627,14 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
622
627
 
623
628
  check_shape(cube, dates, dates_in_data)
624
629
 
625
- def check_dates_in_data(lst, lst2):
626
- lst2 = [np.datetime64(_) for _ in lst2]
627
- lst = [np.datetime64(_) for _ in lst]
628
- assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2)
630
+ def check_dates_in_data(dates_in_data, requested_dates):
631
+ requested_dates = [np.datetime64(_) for _ in requested_dates]
632
+ dates_in_data = [np.datetime64(_) for _ in dates_in_data]
633
+ assert dates_in_data == requested_dates, (
634
+ "Dates in data are not the requested ones:",
635
+ dates_in_data,
636
+ requested_dates,
637
+ )
629
638
 
630
639
  check_dates_in_data(dates_in_data, dates)
631
640
 
@@ -638,12 +647,14 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
638
647
  indexes = dates_to_indexes(self.dates, dates_in_data)
639
648
 
640
649
  array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes)
650
+ LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}")
641
651
  self.load_cube(cube, array)
642
652
 
643
653
  stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans())
644
654
  self.tmp_statistics.write(indexes, stats, dates=dates_in_data)
645
-
655
+ LOG.info("Flush data array")
646
656
  array.flush()
657
+ LOG.info("Flushed data array")
647
658
 
648
659
  def _get_allow_nans(self):
649
660
  config = self.main_config
@@ -736,6 +747,11 @@ class AdditionsMixin:
736
747
  if not self.delta.total_seconds() % frequency.total_seconds() == 0:
737
748
  LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.")
738
749
  return True
750
+
751
+ if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False:
752
+ LOG.warning(f"Additions are disabled for {self.path} in the recipe.")
753
+ return True
754
+
739
755
  return False
740
756
 
741
757
  @cached_property
@@ -58,7 +58,7 @@ class DatasetName:
58
58
  raise ValueError(self.error_message)
59
59
 
60
60
  def _parse(self, name):
61
- pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?([a-zA-Z0-9-]+)?$"
61
+ pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h|\d+m)-v(\d+)-?([a-zA-Z0-9-]+)?$"
62
62
  match = re.match(pattern, name)
63
63
 
64
64
  if not match:
@@ -22,6 +22,7 @@ def assert_is_fieldlist(obj):
22
22
  def import_function(name, kind):
23
23
 
24
24
  from anemoi.transform.filters import filter_registry
25
+ from anemoi.transform.sources import source_registry
25
26
 
26
27
  name = name.replace("-", "_")
27
28
 
@@ -45,7 +46,20 @@ def import_function(name, kind):
45
46
  if filter_registry.lookup(name, return_none=True):
46
47
 
47
48
  def proc(context, data, *args, **kwargs):
48
- return filter_registry.create(name, *args, **kwargs)(data)
49
+ filter = filter_registry.create(name, *args, **kwargs)
50
+ filter.context = context
51
+ # filter = filter_registry.create(context, name, *args, **kwargs)
52
+ return filter.forward(data)
53
+
54
+ return proc
55
+
56
+ if kind == "sources":
57
+ if source_registry.lookup(name, return_none=True):
58
+
59
+ def proc(context, data, *args, **kwargs):
60
+ source = source_registry.create(name, *args, **kwargs)
61
+ # source = source_registry.create(context, name, *args, **kwargs)
62
+ return source.forward(data)
49
63
 
50
64
  return proc
51
65
 
@@ -0,0 +1,58 @@
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+
11
+ from collections import defaultdict
12
+
13
+ from earthkit.data.indexing.fieldlist import FieldArray
14
+
15
+
16
+ class NewDataField:
17
+ def __init__(self, field, data, new_name):
18
+ self.field = field
19
+ self.data = data
20
+ self.new_name = new_name
21
+
22
+ def to_numpy(self, *args, **kwargs):
23
+ return self.data
24
+
25
+ def metadata(self, key=None, **kwargs):
26
+ if key is None:
27
+ return self.field.metadata(**kwargs)
28
+
29
+ value = self.field.metadata(key, **kwargs)
30
+ if key == "param":
31
+ return self.new_name
32
+ return value
33
+
34
+ def __getattr__(self, name):
35
+ return getattr(self.field, name)
36
+
37
+
38
+ def execute(context, input, orog, z="z"):
39
+ """Convert orography [m] to z (geopotential height)"""
40
+ result = FieldArray()
41
+
42
+ processed_fields = defaultdict(dict)
43
+
44
+ for f in input:
45
+ key = f.metadata(namespace="mars")
46
+ param = key.pop("param")
47
+ if param == orog:
48
+ key = tuple(key.items())
49
+
50
+ if param in processed_fields[key]:
51
+ raise ValueError(f"Duplicate field {param} for {key}")
52
+
53
+ output = f.to_numpy(flatten=True) * 9.80665
54
+ result.append(NewDataField(f, output, z))
55
+ else:
56
+ result.append(f)
57
+
58
+ return result
@@ -0,0 +1,71 @@
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+
11
+ from collections import defaultdict
12
+
13
+ from earthkit.data.indexing.fieldlist import FieldArray
14
+
15
+
16
+ class NewDataField:
17
+ def __init__(self, field, data, new_name):
18
+ self.field = field
19
+ self.data = data
20
+ self.new_name = new_name
21
+
22
+ def to_numpy(self, *args, **kwargs):
23
+ return self.data
24
+
25
+ def metadata(self, key=None, **kwargs):
26
+ if key is None:
27
+ return self.field.metadata(**kwargs)
28
+
29
+ value = self.field.metadata(key, **kwargs)
30
+ if key == "param":
31
+ return self.new_name
32
+ return value
33
+
34
+ def __getattr__(self, name):
35
+ return getattr(self.field, name)
36
+
37
+
38
+ def execute(context, input, params, output):
39
+ """Computes the sum over a set of variables"""
40
+ result = FieldArray()
41
+
42
+ needed_fields = defaultdict(dict)
43
+
44
+ for f in input:
45
+ key = f.metadata(namespace="mars")
46
+ param = key.pop("param")
47
+ if param in params:
48
+ key = tuple(key.items())
49
+
50
+ if param in needed_fields[key]:
51
+ raise ValueError(f"Duplicate field {param} for {key}")
52
+
53
+ needed_fields[key][param] = f
54
+ else:
55
+ result.append(f)
56
+
57
+ for keys, values in needed_fields.items():
58
+
59
+ if len(values) != len(params):
60
+ raise ValueError("Missing fields")
61
+
62
+ s = None
63
+ for k, v in values.items():
64
+ c = v.to_numpy(flatten=True)
65
+ if s is None:
66
+ s = c
67
+ else:
68
+ s += c
69
+ result.append(NewDataField(values[list(values.keys())[0]], s, output))
70
+
71
+ return result
@@ -0,0 +1,79 @@
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+
11
+ from collections import defaultdict
12
+
13
+ from earthkit.data.indexing.fieldlist import FieldArray
14
+
15
+
16
+ class NewDataField:
17
+ def __init__(self, field, data, new_name):
18
+ self.field = field
19
+ self.data = data
20
+ self.new_name = new_name
21
+
22
+ def to_numpy(self, *args, **kwargs):
23
+ return self.data
24
+
25
+ def metadata(self, key=None, **kwargs):
26
+ if key is None:
27
+ return self.field.metadata(**kwargs)
28
+
29
+ value = self.field.metadata(key, **kwargs)
30
+ if key == "param":
31
+ return self.new_name
32
+ return value
33
+
34
+ def __getattr__(self, name):
35
+ return getattr(self.field, name)
36
+
37
+
38
+ def execute(context, input, wz, t, w="w"):
39
+ """Convert geometric vertical velocity (m/s) to vertical velocity (Pa / s)"""
40
+ result = FieldArray()
41
+
42
+ params = (wz, t)
43
+ pairs = defaultdict(dict)
44
+
45
+ for f in input:
46
+ key = f.metadata(namespace="mars")
47
+ param = key.pop("param")
48
+ if param in params:
49
+ key = tuple(key.items())
50
+
51
+ if param in pairs[key]:
52
+ raise ValueError(f"Duplicate field {param} for {key}")
53
+
54
+ pairs[key][param] = f
55
+ if param == t:
56
+ result.append(f)
57
+ else:
58
+ result.append(f)
59
+
60
+ for keys, values in pairs.items():
61
+
62
+ if len(values) != 2:
63
+ raise ValueError("Missing fields")
64
+
65
+ wz_pl = values[wz].to_numpy(flatten=True)
66
+ t_pl = values[t].to_numpy(flatten=True)
67
+ pressure = keys[4][1] * 100 # TODO: REMOVE HARDCODED INDICES
68
+
69
+ w_pl = wz_to_w(wz_pl, t_pl, pressure)
70
+ result.append(NewDataField(values[wz], w_pl, w))
71
+
72
+ return result
73
+
74
+
75
+ def wz_to_w(wz, t, pressure):
76
+ g = 9.81
77
+ Rd = 287.058
78
+
79
+ return -wz * g * pressure / (t * Rd)
@@ -253,6 +253,7 @@ def _compute_accumulations(
253
253
  data_accumulation_period=None,
254
254
  patch=_identity,
255
255
  base_times=None,
256
+ use_cdsapi_dataset=None,
256
257
  ):
257
258
  adjust_step = isinstance(user_accumulation_period, int)
258
259
 
@@ -311,7 +312,9 @@ def _compute_accumulations(
311
312
 
312
313
  requests.append(patch(r))
313
314
 
314
- ds = mars(context, dates, *requests, request_already_using_valid_datetime=True)
315
+ ds = mars(
316
+ context, dates, *requests, request_already_using_valid_datetime=True, use_cdsapi_dataset=use_cdsapi_dataset
317
+ )
315
318
 
316
319
  accumulations = {}
317
320
  for a in [AccumulationClass(out, frequency=frequency, **r) for r in requests]:
@@ -366,7 +369,7 @@ def _scda(request):
366
369
  return request
367
370
 
368
371
 
369
- def accumulations(context, dates, **request):
372
+ def accumulations(context, dates, use_cdsapi_dataset=None, **request):
370
373
  _to_list(request["param"])
371
374
  class_ = request.get("class", "od")
372
375
  stream = request.get("stream", "oper")
@@ -379,6 +382,7 @@ def accumulations(context, dates, **request):
379
382
  KWARGS = {
380
383
  ("od", "oper"): dict(patch=_scda),
381
384
  ("od", "elda"): dict(base_times=(6, 18)),
385
+ ("od", "enfo"): dict(base_times=(0, 6, 12, 18)),
382
386
  ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)),
383
387
  ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)),
384
388
  ("rr", "oper"): dict(base_times=(0, 3, 6, 9, 12, 15, 18, 21)),
@@ -394,6 +398,7 @@ def accumulations(context, dates, **request):
394
398
  dates,
395
399
  request,
396
400
  user_accumulation_period=user_accumulation_period,
401
+ use_cdsapi_dataset=use_cdsapi_dataset,
397
402
  **kwargs,
398
403
  )
399
404
 
@@ -0,0 +1,16 @@
1
+ # (C) Copyright 2025 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+
11
+ from .xarray import load_many
12
+
13
+
14
+ def execute(context, dates, path, *args, **kwargs):
15
+ options = {"engine": "fstd"}
16
+ return load_many("🍁", context, dates, path, *args, options=options, **kwargs)
@@ -246,6 +246,7 @@ def mars(
246
246
  *requests,
247
247
  request_already_using_valid_datetime=False,
248
248
  date_key="date",
249
+ use_cdsapi_dataset=None,
249
250
  **kwargs,
250
251
  ):
251
252
 
@@ -305,7 +306,10 @@ def mars(
305
306
  f"⚠️ Unknown key {k}={v} in MARS request. Did you mean '{did_you_mean(k, MARS_KEYS)}' ?"
306
307
  )
307
308
  try:
308
- ds = ds + from_source("mars", **r)
309
+ if use_cdsapi_dataset:
310
+ ds = ds + from_source("cds", use_cdsapi_dataset, r)
311
+ else:
312
+ ds = ds + from_source("mars", **r)
309
313
  except Exception as e:
310
314
  if "File is empty:" not in str(e):
311
315
  raise
@@ -29,7 +29,7 @@ def check(what, ds, paths, **kwargs):
29
29
  raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})")
30
30
 
31
31
 
32
- def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs):
32
+ def load_one(emoji, context, dates, dataset, *, options={}, flavour=None, patch=None, **kwargs):
33
33
  import xarray as xr
34
34
 
35
35
  """
@@ -54,10 +54,10 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs)
54
54
  else:
55
55
  data = xr.open_dataset(dataset, **options)
56
56
 
57
- fs = XarrayFieldList.from_xarray(data, flavour)
57
+ fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)
58
58
 
59
59
  if len(dates) == 0:
60
- return fs.sel(**kwargs)
60
+ result = fs.sel(**kwargs)
61
61
  else:
62
62
  result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])
63
63
 
@@ -92,6 +92,10 @@ class XArrayField(Field):
92
92
  def grid_points(self):
93
93
  return self.owner.grid_points()
94
94
 
95
+ def to_latlon(self, flatten=True):
96
+ assert flatten
97
+ return dict(lat=self.latitudes, lon=self.longitudes)
98
+
95
99
  @property
96
100
  def resolution(self):
97
101
  return None
@@ -120,6 +124,6 @@ class XArrayField(Field):
120
124
  def __repr__(self):
121
125
  return repr(self._metadata)
122
126
 
123
- def _values(self):
127
+ def _values(self, dtype=None):
124
128
  # we don't use .values as this will download the data
125
129
  return self.selection
@@ -16,6 +16,7 @@ from earthkit.data.core.fieldlist import FieldList
16
16
 
17
17
  from .field import EmptyFieldList
18
18
  from .flavour import CoordinateGuesser
19
+ from .patch import patch_dataset
19
20
  from .time import Time
20
21
  from .variable import FilteredVariable
21
22
  from .variable import Variable
@@ -49,7 +50,11 @@ class XarrayFieldList(FieldList):
49
50
  raise IndexError(k)
50
51
 
51
52
  @classmethod
52
- def from_xarray(cls, ds, flavour=None):
53
+ def from_xarray(cls, ds, *, flavour=None, patch=None):
54
+
55
+ if patch is not None:
56
+ ds = patch_dataset(ds, patch)
57
+
53
58
  variables = []
54
59
 
55
60
  if isinstance(flavour, str):
@@ -83,6 +88,8 @@ class XarrayFieldList(FieldList):
83
88
  _skip_attr(variable, "bounds")
84
89
  _skip_attr(variable, "grid_mapping")
85
90
 
91
+ LOG.debug("Xarray data_vars: %s", ds.data_vars)
92
+
86
93
  # Select only geographical variables
87
94
  for name in ds.data_vars:
88
95
 
@@ -97,6 +104,7 @@ class XarrayFieldList(FieldList):
97
104
  c = guess.guess(ds[coord], coord)
98
105
  assert c, f"Could not guess coordinate for {coord}"
99
106
  if coord not in variable.dims:
107
+ LOG.debug("%s: coord=%s (not a dimension): dims=%s", variable, coord, variable.dims)
100
108
  c.is_dim = False
101
109
  coordinates.append(c)
102
110
 
@@ -104,6 +112,7 @@ class XarrayFieldList(FieldList):
104
112
  assert grid_coords <= 2
105
113
 
106
114
  if grid_coords < 2:
115
+ LOG.debug("Skipping %s (not 2D): %s", variable, [(c, c.is_grid, c.is_dim) for c in coordinates])
107
116
  continue
108
117
 
109
118
  v = Variable(
@@ -24,6 +24,7 @@ class _MDMapping:
24
24
  def __init__(self, variable):
25
25
  self.variable = variable
26
26
  self.time = variable.time
27
+ # Aliases
27
28
  self.mapping = dict(param="variable")
28
29
  for c in variable.coordinates:
29
30
  for v in c.mars_names:
@@ -34,7 +35,6 @@ class _MDMapping:
34
35
  return self.mapping.get(key, key)
35
36
 
36
37
  def from_user(self, kwargs):
37
- print("from_user", kwargs, self)
38
38
  return {self._from_user(k): v for k, v in kwargs.items()}
39
39
 
40
40
  def __repr__(self):
@@ -81,22 +81,16 @@ class XArrayMetadata(RawMetadata):
81
81
  def _valid_datetime(self):
82
82
  return self._get("valid_datetime")
83
83
 
84
- def _get(self, key, **kwargs):
84
+ def get(self, key, astype=None, **kwargs):
85
85
 
86
86
  if key in self._d:
87
+ if astype is not None:
88
+ return astype(self._d[key])
87
89
  return self._d[key]
88
90
 
89
- if key.startswith("mars."):
90
- key = key[5:]
91
- if key not in self.MARS_KEYS:
92
- if kwargs.get("raise_on_missing", False):
93
- raise KeyError(f"Invalid key '{key}' in namespace='mars'")
94
- else:
95
- return kwargs.get("default", None)
96
-
97
91
  key = self._mapping._from_user(key)
98
92
 
99
- return super()._get(key, **kwargs)
93
+ return super().get(key, astype=astype, **kwargs)
100
94
 
101
95
 
102
96
  class XArrayFieldGeography(Geography):
@@ -0,0 +1,44 @@
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+
11
+ import logging
12
+
13
+ LOG = logging.getLogger(__name__)
14
+
15
+
16
+ def patch_attributes(ds, attributes):
17
+ for name, value in attributes.items():
18
+ variable = ds[name]
19
+ variable.attrs.update(value)
20
+
21
+ return ds
22
+
23
+
24
+ def patch_coordinates(ds, coordinates):
25
+ for name in coordinates:
26
+ ds = ds.assign_coords({name: ds[name]})
27
+
28
+ return ds
29
+
30
+
31
+ PATCHES = {
32
+ "attributes": patch_attributes,
33
+ "coordinates": patch_coordinates,
34
+ }
35
+
36
+
37
+ def patch_dataset(ds, patch):
38
+ for what, values in patch.items():
39
+ if what not in PATCHES:
40
+ raise ValueError(f"Unknown patch type {what!r}")
41
+
42
+ ds = PATCHES[what](ds, values)
43
+
44
+ return ds
@@ -62,12 +62,18 @@ class Time:
62
62
 
63
63
  raise NotImplementedError(f"{len(date_coordinate)=} {len(time_coordinate)=} {len(step_coordinate)=}")
64
64
 
65
+ def select_valid_datetime(self, variable):
66
+ raise NotImplementedError(f"{self.__class__.__name__}.select_valid_datetime()")
67
+
65
68
 
66
69
  class Constant(Time):
67
70
 
68
71
  def fill_time_metadata(self, coords_values, metadata):
69
72
  return None
70
73
 
74
+ def select_valid_datetime(self, variable):
75
+ return None
76
+
71
77
 
72
78
  class Analysis(Time):
73
79
 
@@ -83,6 +89,9 @@ class Analysis(Time):
83
89
 
84
90
  return valid_datetime
85
91
 
92
+ def select_valid_datetime(self, variable):
93
+ return self.time_coordinate_name
94
+
86
95
 
87
96
  class ForecastFromValidTimeAndStep(Time):
88
97
 
@@ -116,6 +125,9 @@ class ForecastFromValidTimeAndStep(Time):
116
125
 
117
126
  return valid_datetime
118
127
 
128
+ def select_valid_datetime(self, variable):
129
+ return self.time_coordinate_name
130
+
119
131
 
120
132
  class ForecastFromValidTimeAndBaseTime(Time):
121
133
 
@@ -138,6 +150,9 @@ class ForecastFromValidTimeAndBaseTime(Time):
138
150
 
139
151
  return valid_datetime
140
152
 
153
+ def select_valid_datetime(self, variable):
154
+ return self.time_coordinate_name
155
+
141
156
 
142
157
  class ForecastFromBaseTimeAndDate(Time):
143
158