anemoi-datasets 0.5.7__py3-none-any.whl → 0.5.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. anemoi/datasets/__init__.py +11 -3
  2. anemoi/datasets/__main__.py +2 -3
  3. anemoi/datasets/_version.py +2 -2
  4. anemoi/datasets/commands/__init__.py +2 -3
  5. anemoi/datasets/commands/cleanup.py +9 -0
  6. anemoi/datasets/commands/compare.py +3 -3
  7. anemoi/datasets/commands/copy.py +38 -68
  8. anemoi/datasets/commands/create.py +20 -5
  9. anemoi/datasets/commands/finalise-additions.py +9 -0
  10. anemoi/datasets/commands/finalise.py +9 -0
  11. anemoi/datasets/commands/init-additions.py +9 -0
  12. anemoi/datasets/commands/init.py +9 -0
  13. anemoi/datasets/commands/inspect.py +3 -1
  14. anemoi/datasets/commands/load-additions.py +9 -0
  15. anemoi/datasets/commands/load.py +9 -0
  16. anemoi/datasets/commands/patch.py +9 -0
  17. anemoi/datasets/commands/publish.py +9 -0
  18. anemoi/datasets/commands/scan.py +9 -0
  19. anemoi/datasets/compute/__init__.py +8 -0
  20. anemoi/datasets/compute/recentre.py +3 -2
  21. anemoi/datasets/create/__init__.py +62 -12
  22. anemoi/datasets/create/check.py +4 -3
  23. anemoi/datasets/create/chunks.py +3 -2
  24. anemoi/datasets/create/config.py +5 -5
  25. anemoi/datasets/create/functions/__init__.py +22 -7
  26. anemoi/datasets/create/functions/filters/__init__.py +2 -1
  27. anemoi/datasets/create/functions/filters/empty.py +3 -2
  28. anemoi/datasets/create/functions/filters/noop.py +2 -2
  29. anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +3 -2
  30. anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +3 -2
  31. anemoi/datasets/create/functions/filters/rename.py +16 -11
  32. anemoi/datasets/create/functions/filters/rotate_winds.py +3 -2
  33. anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +3 -2
  34. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +3 -2
  35. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +2 -2
  36. anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +2 -2
  37. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +3 -2
  38. anemoi/datasets/create/functions/filters/unrotate_winds.py +3 -2
  39. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +3 -2
  40. anemoi/datasets/create/functions/sources/__init__.py +2 -2
  41. anemoi/datasets/create/functions/sources/accumulations.py +10 -4
  42. anemoi/datasets/create/functions/sources/constants.py +3 -2
  43. anemoi/datasets/create/functions/sources/empty.py +3 -2
  44. anemoi/datasets/create/functions/sources/forcings.py +3 -2
  45. anemoi/datasets/create/functions/sources/grib.py +8 -2
  46. anemoi/datasets/create/functions/sources/hindcasts.py +3 -2
  47. anemoi/datasets/create/functions/sources/mars.py +97 -17
  48. anemoi/datasets/create/functions/sources/netcdf.py +3 -2
  49. anemoi/datasets/create/functions/sources/opendap.py +2 -2
  50. anemoi/datasets/create/functions/sources/recentre.py +3 -2
  51. anemoi/datasets/create/functions/sources/source.py +3 -2
  52. anemoi/datasets/create/functions/sources/tendencies.py +3 -2
  53. anemoi/datasets/create/functions/sources/xarray/__init__.py +8 -3
  54. anemoi/datasets/create/functions/sources/xarray/coordinates.py +3 -2
  55. anemoi/datasets/create/functions/sources/xarray/field.py +6 -5
  56. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +12 -4
  57. anemoi/datasets/create/functions/sources/xarray/flavour.py +2 -2
  58. anemoi/datasets/create/functions/sources/xarray/grid.py +2 -2
  59. anemoi/datasets/create/functions/sources/xarray/metadata.py +3 -2
  60. anemoi/datasets/create/functions/sources/xarray/time.py +2 -2
  61. anemoi/datasets/create/functions/sources/xarray/variable.py +6 -9
  62. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +2 -2
  63. anemoi/datasets/create/functions/sources/xarray_zarr.py +2 -2
  64. anemoi/datasets/create/functions/sources/zenodo.py +2 -2
  65. anemoi/datasets/create/input/__init__.py +3 -17
  66. anemoi/datasets/create/input/action.py +3 -8
  67. anemoi/datasets/create/input/concat.py +3 -2
  68. anemoi/datasets/create/input/context.py +3 -8
  69. anemoi/datasets/create/input/data_sources.py +3 -9
  70. anemoi/datasets/create/input/empty.py +3 -9
  71. anemoi/datasets/create/input/filter.py +3 -9
  72. anemoi/datasets/create/input/function.py +3 -9
  73. anemoi/datasets/create/input/join.py +3 -2
  74. anemoi/datasets/create/input/misc.py +3 -8
  75. anemoi/datasets/create/input/pipe.py +9 -3
  76. anemoi/datasets/create/input/repeated_dates.py +14 -8
  77. anemoi/datasets/create/input/result.py +154 -12
  78. anemoi/datasets/create/input/step.py +4 -9
  79. anemoi/datasets/create/input/template.py +3 -2
  80. anemoi/datasets/create/input/trace.py +3 -2
  81. anemoi/datasets/create/patch.py +9 -1
  82. anemoi/datasets/create/persistent.py +3 -2
  83. anemoi/datasets/create/size.py +3 -2
  84. anemoi/datasets/create/statistics/__init__.py +3 -2
  85. anemoi/datasets/create/statistics/summary.py +3 -2
  86. anemoi/datasets/create/utils.py +15 -2
  87. anemoi/datasets/create/writer.py +3 -2
  88. anemoi/datasets/create/zarr.py +3 -2
  89. anemoi/datasets/data/__init__.py +27 -1
  90. anemoi/datasets/data/concat.py +5 -1
  91. anemoi/datasets/data/dataset.py +216 -37
  92. anemoi/datasets/data/debug.py +4 -1
  93. anemoi/datasets/data/ensemble.py +4 -1
  94. anemoi/datasets/data/fill_missing.py +165 -0
  95. anemoi/datasets/data/forwards.py +23 -1
  96. anemoi/datasets/data/grids.py +236 -58
  97. anemoi/datasets/data/indexing.py +4 -1
  98. anemoi/datasets/data/interpolate.py +4 -1
  99. anemoi/datasets/data/join.py +12 -9
  100. anemoi/datasets/data/masked.py +36 -10
  101. anemoi/datasets/data/merge.py +180 -0
  102. anemoi/datasets/data/misc.py +18 -3
  103. anemoi/datasets/data/missing.py +4 -1
  104. anemoi/datasets/data/rescale.py +4 -1
  105. anemoi/datasets/data/select.py +4 -1
  106. anemoi/datasets/data/statistics.py +4 -1
  107. anemoi/datasets/data/stores.py +66 -3
  108. anemoi/datasets/data/subset.py +6 -1
  109. anemoi/datasets/data/unchecked.py +4 -1
  110. anemoi/datasets/data/xy.py +20 -5
  111. anemoi/datasets/dates/__init__.py +9 -7
  112. anemoi/datasets/dates/groups.py +4 -2
  113. anemoi/datasets/grids.py +86 -2
  114. anemoi/datasets/testing.py +3 -2
  115. anemoi/datasets/utils/__init__.py +8 -0
  116. anemoi/datasets/utils/fields.py +2 -2
  117. {anemoi_datasets-0.5.7.dist-info → anemoi_datasets-0.5.11.dist-info}/METADATA +11 -29
  118. anemoi_datasets-0.5.11.dist-info/RECORD +123 -0
  119. {anemoi_datasets-0.5.7.dist-info → anemoi_datasets-0.5.11.dist-info}/WHEEL +1 -1
  120. anemoi/datasets/fields.py +0 -66
  121. anemoi_datasets-0.5.7.dist-info/RECORD +0 -122
  122. {anemoi_datasets-0.5.7.dist-info → anemoi_datasets-0.5.11.dist-info}/LICENSE +0 -0
  123. {anemoi_datasets-0.5.7.dist-info → anemoi_datasets-0.5.11.dist-info}/entry_points.txt +0 -0
  124. {anemoi_datasets-0.5.7.dist-info → anemoi_datasets-0.5.11.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,8 @@
1
- # (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
2
3
  # This software is licensed under the terms of the Apache Licence Version 2.0
3
4
  # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
4
6
  # In applying this licence, ECMWF does not waive the privileges and immunities
5
7
  # granted to it by virtue of its status as an intergovernmental organisation
6
8
  # nor does it submit to any jurisdiction.
@@ -25,7 +27,31 @@ class MissingDateError(Exception):
25
27
  pass
26
28
 
27
29
 
30
+ def _convert(x):
31
+
32
+ if isinstance(x, list):
33
+ return [_convert(a) for a in x]
34
+
35
+ if isinstance(x, tuple):
36
+ return tuple(_convert(a) for a in x)
37
+
38
+ if isinstance(x, dict):
39
+ return {k: _convert(v) for k, v in x.items()}
40
+
41
+ if x.__class__.__name__ in ("DictConfig", "ListConfig"):
42
+ from omegaconf import OmegaConf
43
+
44
+ return OmegaConf.to_container(x, resolve=True)
45
+
46
+ return x
47
+
48
+
28
49
  def open_dataset(*args, **kwargs):
50
+
51
+ # That will get rid of OmegaConf objects
52
+
53
+ args, kwargs = _convert(args), _convert(kwargs)
54
+
29
55
  ds = _open_dataset(*args, **kwargs)
30
56
  ds = ds.mutate()
31
57
  ds.arguments = {"args": args, "kwargs": kwargs}
@@ -1,10 +1,13 @@
1
- # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
2
3
  # This software is licensed under the terms of the Apache Licence Version 2.0
3
4
  # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
4
6
  # In applying this licence, ECMWF does not waive the privileges and immunities
5
7
  # granted to it by virtue of its status as an intergovernmental organisation
6
8
  # nor does it submit to any jurisdiction.
7
9
 
10
+
8
11
  import logging
9
12
  from functools import cached_property
10
13
 
@@ -148,6 +151,7 @@ def concat_factory(args, kwargs):
148
151
 
149
152
  datasets = kwargs.pop("concat")
150
153
  fill_missing_gaps = kwargs.pop("fill_missing_gaps", False)
154
+
151
155
  assert isinstance(datasets, (list, tuple))
152
156
  assert len(args) == 0
153
157
 
@@ -1,14 +1,16 @@
1
- # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
2
3
  # This software is licensed under the terms of the Apache Licence Version 2.0
3
4
  # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
4
6
  # In applying this licence, ECMWF does not waive the privileges and immunities
5
7
  # granted to it by virtue of its status as an intergovernmental organisation
6
8
  # nor does it submit to any jurisdiction.
7
9
 
10
+
8
11
  import datetime
9
12
  import json
10
13
  import logging
11
- import os
12
14
  import pprint
13
15
  import warnings
14
16
  from functools import cached_property
@@ -20,14 +22,38 @@ from anemoi.utils.dates import frequency_to_timedelta
20
22
  LOG = logging.getLogger(__name__)
21
23
 
22
24
 
25
def _tidy(v):
    """Recursively convert `v` into JSON-serialisable values for metadata."""
    if isinstance(v, (list, tuple, set)):
        return [_tidy(item) for item in v]

    if isinstance(v, dict):
        return {key: _tidy(value) for key, value in v.items()}

    # Dates and times become ISO strings (datetime is a date subclass, so
    # one combined check dispatches to the right isoformat()).
    if isinstance(v, (datetime.datetime, datetime.date)):
        return v.isoformat()

    # Frequencies become their compact string form (e.g. "6h").
    if isinstance(v, datetime.timedelta):
        return frequency_to_string(v)

    if isinstance(v, Dataset):
        # That can happen in the `arguments`
        # if a dataset is passed as an argument
        return repr(v)

    if isinstance(v, slice):
        return (v.start, v.stop, v.step)

    return v
46
+
47
+
23
48
  class Dataset:
24
49
  arguments = {}
50
+ _name = None
25
51
 
26
52
  def mutate(self) -> "Dataset":
27
- """
28
- Give an opportunity to a subclass to return a new Dataset
53
+ """Give an opportunity to a subclass to return a new Dataset
29
54
  object of a different class, if needed.
30
55
  """
56
+
31
57
  return self
32
58
 
33
59
  def swap_with_parent(self, parent):
@@ -38,9 +64,32 @@ class Dataset:
38
64
  return len(self)
39
65
 
40
66
def _subset(self, **kwargs):
    """Apply subsetting options, capturing an optional ``name`` for the result."""
    if not kwargs:
        return self.mutate()

    # `name` is a label for the resulting dataset, not a subsetting option,
    # so it is stripped before delegating to the real subsetting logic.
    label = kwargs.pop("name", None)
    subset = self.__subset(**kwargs)
    subset._name = label

    return subset
76
+
77
+ @property
78
+ def name(self):
79
+ return self._name
80
+
81
+ def __subset(self, **kwargs):
41
82
  if not kwargs:
42
83
  return self.mutate()
43
84
 
85
+ # This one must be first
86
+ if "fill_missing_dates" in kwargs:
87
+ from .fill_missing import fill_missing_dates_factory
88
+
89
+ fill_missing_dates = kwargs.pop("fill_missing_dates")
90
+ ds = fill_missing_dates_factory(self, fill_missing_dates, kwargs)
91
+ return ds._subset(**kwargs).mutate()
92
+
44
93
  if "start" in kwargs or "end" in kwargs:
45
94
  start = kwargs.pop("start", None)
46
95
  end = kwargs.pop("end", None)
@@ -64,12 +113,6 @@ class Dataset:
64
113
  .mutate()
65
114
  )
66
115
 
67
- if "interpolate_frequency" in kwargs:
68
- from .interpolate import InterpolateFrequency
69
-
70
- interpolate_frequency = kwargs.pop("interpolate_frequency")
71
- return InterpolateFrequency(self, interpolate_frequency)._subset(**kwargs).mutate()
72
-
73
116
  if "select" in kwargs:
74
117
  from .select import Select
75
118
 
@@ -121,11 +164,11 @@ class Dataset:
121
164
  bbox = kwargs.pop("area")
122
165
  return Cropping(self, bbox)._subset(**kwargs).mutate()
123
166
 
124
- if "missing_dates" in kwargs:
167
+ if "set_missing_dates" in kwargs:
125
168
  from .missing import MissingDates
126
169
 
127
- missing_dates = kwargs.pop("missing_dates")
128
- return MissingDates(self, missing_dates)._subset(**kwargs).mutate()
170
+ set_missing_dates = kwargs.pop("set_missing_dates")
171
+ return MissingDates(self, set_missing_dates)._subset(**kwargs).mutate()
129
172
 
130
173
  if "skip_missing_dates" in kwargs:
131
174
  from .missing import SkipMissingDates
@@ -139,6 +182,12 @@ class Dataset:
139
182
  if skip_missing_dates:
140
183
  return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate()
141
184
 
185
+ if "interpolate_frequency" in kwargs:
186
+ from .interpolate import InterpolateFrequency
187
+
188
+ interpolate_frequency = kwargs.pop("interpolate_frequency")
189
+ return InterpolateFrequency(self, interpolate_frequency)._subset(**kwargs).mutate()
190
+
142
191
  # Keep last
143
192
  if "shuffle" in kwargs:
144
193
  from .subset import Subset
@@ -222,41 +271,53 @@ class Dataset:
222
271
  shape.pop(drop_axis)
223
272
  return tuple(shape)
224
273
 
274
@property
def typed_variables(self):
    """Variables exposed as anemoi-transform ``Variable`` objects, keyed by name."""
    from anemoi.transform.variables import Variable

    constants = self.constant_fields

    result = {}
    for name, metadata in self.variables_metadata.items():

        # TODO: Once all datasets are updated, we can remove this
        metadata = metadata.copy()
        if name in constants:
            metadata["constant_in_time"] = True

        # Drop the legacy key if present.
        metadata.pop("is_constant_in_time", None)

        result[name] = Variable.from_dict(name, metadata)

    return result
294
+
295
def _input_sources(self):
    """Return the list of input sources collected from the dataset tree."""
    found = []
    self.collect_input_sources(found)
    return found
299
+
225
300
  def metadata(self):
226
301
  import anemoi
227
302
 
228
- def tidy(v):
229
- if isinstance(v, (list, tuple, set)):
230
- return [tidy(i) for i in v]
231
- if isinstance(v, dict):
232
- return {k: tidy(v) for k, v in v.items()}
233
- if isinstance(v, str) and v.startswith("/"):
234
- return os.path.basename(v)
235
- if isinstance(v, datetime.datetime):
236
- return v.isoformat()
237
- if isinstance(v, datetime.date):
238
- return v.isoformat()
239
- if isinstance(v, datetime.timedelta):
240
- return frequency_to_string(v)
241
-
242
- if isinstance(v, Dataset):
243
- # That can happen in the `arguments`
244
- # if a dataset is passed as an argument
245
- return repr(v)
246
-
247
- if isinstance(v, slice):
248
- return (v.start, v.stop, v.step)
249
-
250
- return v
303
+ _, source_to_arrays = self._supporting_arrays_and_sources()
304
+
305
+ sources = []
306
+ for i, source in enumerate(self._input_sources()):
307
+ source_metadata = source.dataset_metadata().copy()
308
+ source_metadata["supporting_arrays"] = source_to_arrays[id(source)]
309
+ sources.append(source_metadata)
251
310
 
252
311
  md = dict(
253
312
  version=anemoi.datasets.__version__,
254
313
  arguments=self.arguments,
255
314
  **self.dataset_metadata(),
315
+ sources=sources,
316
+ supporting_arrays=source_to_arrays[id(self)],
256
317
  )
257
318
 
258
319
  try:
259
- return json.loads(json.dumps(tidy(md)))
320
+ return json.loads(json.dumps(_tidy(md)))
260
321
  except Exception:
261
322
  LOG.exception("Failed to serialize metadata")
262
323
  pprint.pprint(md)
@@ -276,11 +337,72 @@ class Dataset:
276
337
  specific=self.metadata_specific(),
277
338
  frequency=self.frequency,
278
339
  variables=self.variables,
340
+ variables_metadata=self.variables_metadata,
279
341
  shape=self.shape,
342
+ dtype=str(self.dtype),
280
343
  start_date=self.start_date.astype(str),
281
344
  end_date=self.end_date.astype(str),
345
+ name=self.name,
282
346
  )
283
347
 
348
def _supporting_arrays(self, *path):
    """Collect the named supporting arrays of this dataset.

    Returns a dict mapping a '/'-joined path (e.g. "source0/latitudes") to
    a numpy array: always latitudes/longitudes, plus whatever subclasses
    contribute through `collect_supporting_arrays`.
    """
    import numpy as np

    def full_name(prefix, name):
        return "/".join(str(part) for part in [*prefix, name])

    result = {
        full_name(path, "latitudes"): self.latitudes,
        full_name(path, "longitudes"): self.longitudes,
    }

    collected = []
    self.collect_supporting_arrays(collected, *path)

    for prefix, name, array in collected:
        assert isinstance(prefix, tuple) and isinstance(name, str)
        assert isinstance(array, np.ndarray)

        name = full_name(prefix, name)

        if name in result:
            raise ValueError(f"Duplicate key {name}")

        result[name] = array

    return result
375
+
376
def supporting_arrays(self):
    """Arrays to be saved in the checkpoints."""
    arrays, _sources = self._supporting_arrays_and_sources()
    return arrays
380
+
381
def _supporting_arrays_and_sources(self):
    """Return (all supporting arrays, mapping of id(source) -> its array names)."""
    source_to_arrays = {}

    # Arrays owned by this (top-level) dataset.
    arrays = self._supporting_arrays()
    source_to_arrays[id(self)] = sorted(arrays)

    # Arrays contributed by each input source, namespaced by the source's
    # name (or a positional fallback).
    for index, source in enumerate(self._input_sources()):
        label = source.name if source.name is not None else f"source{index}"
        contributed = source._supporting_arrays(label)
        source_to_arrays[id(source)] = sorted(contributed)

        for key in contributed:
            assert key not in arrays

        arrays.update(contributed)

    return arrays, source_to_arrays
401
+
402
+ def collect_supporting_arrays(self, collected, *path):
403
+ # Override this method to add more arrays
404
+ pass
405
+
284
406
  def metadata_specific(self, **kwargs):
285
407
  action = self.__class__.__name__.lower()
286
408
  # assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action)
@@ -318,3 +440,60 @@ class Dataset:
318
440
 
319
441
  def get_dataset_names(self, names):
320
442
  raise NotImplementedError(self)
443
+
444
def computed_constant_fields(self):
    """Return the sorted names of constant variables (callers should use `constant_fields`)."""
    try:
        # When statistics tendencies were computed, they give the answer cheaply.
        names = self._compute_constant_fields_from_statistics()
    except KeyError:
        # Tendencies unavailable: fall back to comparing a few data samples.
        names = self._compute_constant_fields_from_a_few_samples()
    return sorted(names)
454
+
455
def _compute_constant_fields_from_a_few_samples(self):
    """Detect constant-in-time variables by comparing a few (non-missing) samples.

    Returns the list of variable names whose values are identical across the
    sampled dates. Samples are spread evenly over the valid indices and always
    include the last one.
    """
    import numpy as np

    # Valid (non-missing) date indices, sorted for a deterministic sample
    # choice (iterating a raw set is not guaranteed to be ordered).
    indices = sorted(set(range(len(self.dates))) - set(self.missing))
    count = len(indices)

    # BUG FIX: with fewer than two valid samples the original
    # `count // (sample_count - 1)` raised ZeroDivisionError.
    if count == 0:
        # Nothing to compare: no variable can be shown to be constant.
        return []
    if count == 1:
        # A single sample: every variable is trivially constant.
        return list(self.variables)

    sample_count = min(4, count)

    # Spread sample positions evenly, always including the last index.
    step = count // (sample_count - 1)
    positions = sorted(set(list(range(0, count, step)) + [count - 1]))
    samples = [indices[i] for i in positions]

    assert set(samples) <= set(indices)  # Make sure we have the samples

    first = self[samples.pop(0)]
    constants = [True] * len(self.variables)

    for sample in samples:
        row = self[sample]
        for i, (a, b) in enumerate(zip(row, first)):
            if np.any(a != b):
                constants[i] = False

    return [v for i, v in enumerate(self.variables) if constants[i]]
489
+
490
def _compute_constant_fields_from_statistics(self):
    """Variables whose tendency mean and stdev are both zero are constant in time.

    Raises KeyError when tendencies have not been computed (handled by the caller).
    """
    tendencies = self.statistics_tendencies()
    means = tendencies["mean"]
    stdevs = tendencies["stdev"]

    return [v for i, v in enumerate(self.variables) if means[i] == 0 and stdevs[i] == 0]
@@ -1,10 +1,13 @@
1
- # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
2
3
  # This software is licensed under the terms of the Apache Licence Version 2.0
3
4
  # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
4
6
  # In applying this licence, ECMWF does not waive the privileges and immunities
5
7
  # granted to it by virtue of its status as an intergovernmental organisation
6
8
  # nor does it submit to any jurisdiction.
7
9
 
10
+
8
11
  import logging
9
12
  import os
10
13
  import textwrap
@@ -1,10 +1,13 @@
1
- # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
2
3
  # This software is licensed under the terms of the Apache Licence Version 2.0
3
4
  # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
4
6
  # In applying this licence, ECMWF does not waive the privileges and immunities
5
7
  # granted to it by virtue of its status as an intergovernmental organisation
6
8
  # nor does it submit to any jurisdiction.
7
9
 
10
+
8
11
  import logging
9
12
 
10
13
  from .debug import Node
@@ -0,0 +1,165 @@
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+
11
+ import logging
12
+
13
+ import numpy as np
14
+
15
+ from anemoi.datasets.data import MissingDateError
16
+
17
+ from .debug import Node
18
+ from .debug import debug_indexing
19
+ from .forwards import Forwards
20
+ from .indexing import apply_index_to_slices_changes
21
+ from .indexing import expand_list_indexing
22
+ from .indexing import index_to_slices
23
+ from .indexing import update_tuple
24
+
25
+ LOG = logging.getLogger(__name__)
26
+
27
+
28
class MissingDatesFill(Forwards):
    """Base class for views that transparently fill missing dates.

    Subclasses implement `_fill_missing(n, a, b)`, where `a` and `b` are the
    nearest non-missing indices at-or-below / at-or-above `n` (or None when
    no such index exists).
    """

    def __init__(self, dataset):
        super().__init__(dataset)
        self._missing = set(dataset.missing)
        # Indices already reported, so each missing date is logged only once.
        self._warnings = set()

    @debug_indexing
    @expand_list_indexing
    def _get_tuple(self, index):
        # Normalise the tuple index, fetch everything along the date axis,
        # then re-apply the remaining index components.
        index, changes = index_to_slices(index, self.shape)
        index, previous = update_tuple(index, 0, slice(None))
        result = self._get_slice(previous)
        return apply_index_to_slices_changes(result[index], changes)

    def _get_slice(self, s):
        return np.stack([self[i] for i in range(*s.indices(self._len))])

    @property
    def missing(self):
        # From the caller's point of view, no dates are missing any more.
        return set()

    @debug_indexing
    def __getitem__(self, n):

        try:
            return self.forward[n]
        except MissingDateError:
            pass

        if isinstance(n, tuple):
            return self._get_tuple(n)

        if isinstance(n, slice):
            return self._get_slice(n)

        if n < 0:
            n += self._len

        # Nearest available index at or below n (None if all below are missing).
        a = None
        i = n
        while a is None and i >= 0:
            if i in self._missing:
                i -= 1
            else:
                a = i

        # Nearest available index at or above n (None if all above are missing).
        # BUG FIX: the loop condition previously tested `n < len` (with `n`
        # never changing), so `i` could run past the end of the dataset and
        # `b` could be set to an out-of-range index. Test `i` instead, and
        # avoid shadowing the builtin `len`.
        size = self._len
        b = None
        i = n
        while b is None and i < size:
            if i in self._missing:
                i += 1
            else:
                b = i

        return self._fill_missing(n, a, b)
84
+
85
+
86
class MissingDatesClosest(MissingDatesFill):
    """Fill each missing date with the value of the closest available date."""

    def __init__(self, dataset, closest):
        super().__init__(dataset)
        # Tie-break direction when both neighbours are equidistant: "up" or "down".
        self.closest = closest
        # Cache: missing index -> index of the substitute date.
        self._closest = {}

    def _fill_missing(self, n, a, b):

        if n not in self._warnings:
            LOG.warning(f"Missing date at index {n} ({self.dates[n]})")

            # BUG FIX: guard against absent neighbours (e.g. a missing first
            # or last date); previously `abs(n - a)` raised a TypeError when
            # `a` or `b` was None.
            if a is None and b is None:
                raise MissingDateError(f"Cannot find a date to replace index {n} ({self.dates[n]})")

            if a is None:
                u = b
            elif b is None:
                u = a
            elif abs(n - a) == abs(b - n):
                # Equidistant: honour the configured tie-break direction.
                u = b if self.closest == "up" else a
            else:
                u = a if abs(n - a) < abs(b - n) else b

            LOG.warning(f"Using closest date {u} ({self.dates[u]})")

            self._closest[n] = u
            self._warnings.add(n)

        return self.forward[self._closest[n]]

    def subclass_metadata_specific(self):
        return {"closest": self.closest}

    def tree(self):
        return Node(self, [self.forward.tree()], closest=self.closest)
119
+
120
+
121
class MissingDatesInterpolate(MissingDatesFill):
    """Fill each missing date by linear interpolation between its neighbours."""

    def __init__(self, dataset):
        super().__init__(dataset)
        # Cache of interpolation weights, keyed by missing index.
        self._alpha = {}

    def _fill_missing(self, n, a, b):
        if n not in self._warnings:
            LOG.warning(f"Missing date at index {n} ({self.dates[n]})")

            if a is None or b is None:
                raise MissingDateError(
                    f"Cannot interpolate at index {n} ({self.dates[n]}). Are the first or last date missing?"
                )

            assert a < n < b, (a, n, b)

            # Weight of the upper neighbour; strictly between 0 and 1.
            alpha = (n - a) / (b - a)
            assert 0 < alpha < 1, alpha

            LOG.warning(f"Interpolating between index {a} ({self.dates[a]}) and {b} ({self.dates[b]})")
            LOG.warning(f"Interpolation {1 - alpha:g} * ({self.dates[a]}) + {alpha:g} * ({self.dates[b]})")

            self._alpha[n] = alpha
            self._warnings.add(n)

        weight = self._alpha[n]
        return self.forward[a] * (1 - weight) + self.forward[b] * weight

    def subclass_metadata_specific(self):
        # No subclass-specific settings to record.
        return {}

    def tree(self):
        return Node(self, [self.forward.tree()])
155
+
156
+
157
def fill_missing_dates_factory(dataset, method, kwargs):
    """Wrap `dataset` so that missing dates are filled using `method`.

    `method` is "closest" (use the nearest available date, with an optional
    "closest" tie-break option of "up"/"down") or "interpolate" (linear
    interpolation between neighbours). Method-specific options are consumed
    from `kwargs`.

    Raises ValueError for an unknown method.
    """
    if method == "closest":
        # BUG FIX: use `pop`, not `get` — the caller passes the remaining
        # kwargs on to `_subset(**kwargs)`, where a leftover "closest" key
        # is not a valid subsetting option.
        closest = kwargs.pop("closest", "up")
        return MissingDatesClosest(dataset, closest=closest)

    if method == "interpolate":
        return MissingDatesInterpolate(dataset)

    raise ValueError(f"Invalid `fill_missing_dates` method '{method}'")
@@ -1,11 +1,15 @@
1
- # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
2
3
  # This software is licensed under the terms of the Apache Licence Version 2.0
3
4
  # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
4
6
  # In applying this licence, ECMWF does not waive the privileges and immunities
5
7
  # granted to it by virtue of its status as an intergovernmental organisation
6
8
  # nor does it submit to any jurisdiction.
7
9
 
10
+
8
11
  import logging
12
+ import warnings
9
13
  from functools import cached_property
10
14
 
11
15
  import numpy as np
@@ -31,6 +35,12 @@ class Forwards(Dataset):
31
35
  def __getitem__(self, n):
32
36
  return self.forward[n]
33
37
 
38
@property
def name(self):
    """This dataset's own name, falling back to the wrapped dataset's name."""
    own = self._name
    return own if own is not None else self.forward.name
43
+
34
44
  @property
35
45
  def dates(self):
36
46
  return self.forward.dates
@@ -99,6 +109,12 @@ class Forwards(Dataset):
99
109
  **kwargs,
100
110
  )
101
111
 
112
+ def collect_supporting_arrays(self, collected, *path):
113
+ self.forward.collect_supporting_arrays(collected, *path)
114
+
115
+ def collect_input_sources(self, collected):
116
+ self.forward.collect_input_sources(collected)
117
+
102
118
  def source(self, index):
103
119
  return self.forward.source(index)
104
120
 
@@ -194,6 +210,12 @@ class Combined(Forwards):
194
210
  **kwargs,
195
211
  )
196
212
 
213
def collect_supporting_arrays(self, collected, *path):
    """Gather supporting arrays from each combined dataset, namespaced by its name (or position)."""
    warnings.warn(f"The behaviour of {self.__class__.__name__}.collect_supporting_arrays() is not well defined")
    for position, dataset in enumerate(self.datasets):
        label = dataset.name
        if label is None:
            label = position
        dataset.collect_supporting_arrays(collected, *path, label)
218
+
197
219
  @property
198
220
  def missing(self):
199
221
  raise NotImplementedError("missing() not implemented for Combined")