anemoi-datasets 0.5.16__py3-none-any.whl → 0.5.18__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (155)
  1. anemoi/datasets/__init__.py +4 -1
  2. anemoi/datasets/__main__.py +12 -2
  3. anemoi/datasets/_version.py +9 -4
  4. anemoi/datasets/commands/cleanup.py +17 -2
  5. anemoi/datasets/commands/compare.py +18 -2
  6. anemoi/datasets/commands/copy.py +196 -14
  7. anemoi/datasets/commands/create.py +50 -7
  8. anemoi/datasets/commands/finalise-additions.py +17 -2
  9. anemoi/datasets/commands/finalise.py +17 -2
  10. anemoi/datasets/commands/init-additions.py +17 -2
  11. anemoi/datasets/commands/init.py +16 -2
  12. anemoi/datasets/commands/inspect.py +283 -62
  13. anemoi/datasets/commands/load-additions.py +16 -2
  14. anemoi/datasets/commands/load.py +16 -2
  15. anemoi/datasets/commands/patch.py +17 -2
  16. anemoi/datasets/commands/publish.py +17 -2
  17. anemoi/datasets/commands/scan.py +31 -3
  18. anemoi/datasets/compute/recentre.py +47 -11
  19. anemoi/datasets/create/__init__.py +612 -85
  20. anemoi/datasets/create/check.py +142 -20
  21. anemoi/datasets/create/chunks.py +64 -4
  22. anemoi/datasets/create/config.py +185 -21
  23. anemoi/datasets/create/filter.py +50 -0
  24. anemoi/datasets/create/filters/__init__.py +33 -0
  25. anemoi/datasets/create/filters/empty.py +37 -0
  26. anemoi/datasets/create/filters/legacy.py +93 -0
  27. anemoi/datasets/create/filters/noop.py +37 -0
  28. anemoi/datasets/create/filters/orog_to_z.py +58 -0
  29. anemoi/datasets/create/{functions/filters → filters}/pressure_level_relative_humidity_to_specific_humidity.py +33 -10
  30. anemoi/datasets/create/{functions/filters → filters}/pressure_level_specific_humidity_to_relative_humidity.py +32 -8
  31. anemoi/datasets/create/filters/rename.py +205 -0
  32. anemoi/datasets/create/{functions/filters → filters}/rotate_winds.py +43 -28
  33. anemoi/datasets/create/{functions/filters → filters}/single_level_dewpoint_to_relative_humidity.py +32 -9
  34. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_dewpoint.py +33 -9
  35. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_specific_humidity.py +55 -7
  36. anemoi/datasets/create/{functions/filters → filters}/single_level_specific_humidity_to_relative_humidity.py +98 -37
  37. anemoi/datasets/create/filters/speeddir_to_uv.py +95 -0
  38. anemoi/datasets/create/{functions/filters → filters}/sum.py +24 -27
  39. anemoi/datasets/create/filters/transform.py +53 -0
  40. anemoi/datasets/create/{functions/filters → filters}/unrotate_winds.py +27 -18
  41. anemoi/datasets/create/filters/uv_to_speeddir.py +94 -0
  42. anemoi/datasets/create/{functions/filters → filters}/wz_to_w.py +51 -33
  43. anemoi/datasets/create/input/__init__.py +76 -5
  44. anemoi/datasets/create/input/action.py +149 -13
  45. anemoi/datasets/create/input/concat.py +81 -10
  46. anemoi/datasets/create/input/context.py +39 -4
  47. anemoi/datasets/create/input/data_sources.py +72 -6
  48. anemoi/datasets/create/input/empty.py +21 -3
  49. anemoi/datasets/create/input/filter.py +60 -12
  50. anemoi/datasets/create/input/function.py +154 -37
  51. anemoi/datasets/create/input/join.py +86 -14
  52. anemoi/datasets/create/input/misc.py +67 -17
  53. anemoi/datasets/create/input/pipe.py +33 -6
  54. anemoi/datasets/create/input/repeated_dates.py +189 -41
  55. anemoi/datasets/create/input/result.py +202 -87
  56. anemoi/datasets/create/input/step.py +119 -22
  57. anemoi/datasets/create/input/template.py +100 -13
  58. anemoi/datasets/create/input/trace.py +62 -7
  59. anemoi/datasets/create/patch.py +52 -4
  60. anemoi/datasets/create/persistent.py +134 -17
  61. anemoi/datasets/create/size.py +15 -1
  62. anemoi/datasets/create/source.py +51 -0
  63. anemoi/datasets/create/sources/__init__.py +36 -0
  64. anemoi/datasets/create/{functions/sources → sources}/accumulations.py +296 -30
  65. anemoi/datasets/create/{functions/sources → sources}/constants.py +27 -2
  66. anemoi/datasets/create/{functions/sources → sources}/eccc_fstd.py +7 -3
  67. anemoi/datasets/create/sources/empty.py +37 -0
  68. anemoi/datasets/create/{functions/sources → sources}/forcings.py +25 -1
  69. anemoi/datasets/create/sources/grib.py +297 -0
  70. anemoi/datasets/create/{functions/sources → sources}/hindcasts.py +38 -4
  71. anemoi/datasets/create/sources/legacy.py +93 -0
  72. anemoi/datasets/create/{functions/sources → sources}/mars.py +168 -20
  73. anemoi/datasets/create/sources/netcdf.py +42 -0
  74. anemoi/datasets/create/sources/opendap.py +43 -0
  75. anemoi/datasets/create/{functions/sources/__init__.py → sources/patterns.py} +35 -4
  76. anemoi/datasets/create/sources/recentre.py +150 -0
  77. anemoi/datasets/create/{functions/sources → sources}/source.py +27 -5
  78. anemoi/datasets/create/{functions/sources → sources}/tendencies.py +64 -7
  79. anemoi/datasets/create/sources/xarray.py +92 -0
  80. anemoi/datasets/create/sources/xarray_kerchunk.py +36 -0
  81. anemoi/datasets/create/sources/xarray_support/README.md +1 -0
  82. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/__init__.py +109 -8
  83. anemoi/datasets/create/sources/xarray_support/coordinates.py +442 -0
  84. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/field.py +94 -16
  85. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/fieldlist.py +90 -25
  86. anemoi/datasets/create/sources/xarray_support/flavour.py +1036 -0
  87. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/grid.py +92 -31
  88. anemoi/datasets/create/sources/xarray_support/metadata.py +395 -0
  89. anemoi/datasets/create/sources/xarray_support/patch.py +91 -0
  90. anemoi/datasets/create/sources/xarray_support/time.py +391 -0
  91. anemoi/datasets/create/sources/xarray_support/variable.py +331 -0
  92. anemoi/datasets/create/sources/xarray_zarr.py +41 -0
  93. anemoi/datasets/create/{functions/sources → sources}/zenodo.py +34 -5
  94. anemoi/datasets/create/statistics/__init__.py +233 -44
  95. anemoi/datasets/create/statistics/summary.py +52 -6
  96. anemoi/datasets/create/testing.py +76 -0
  97. anemoi/datasets/create/{functions/filters/noop.py → typing.py} +6 -3
  98. anemoi/datasets/create/utils.py +97 -6
  99. anemoi/datasets/create/writer.py +26 -4
  100. anemoi/datasets/create/zarr.py +170 -23
  101. anemoi/datasets/data/__init__.py +51 -4
  102. anemoi/datasets/data/complement.py +191 -40
  103. anemoi/datasets/data/concat.py +141 -16
  104. anemoi/datasets/data/dataset.py +558 -62
  105. anemoi/datasets/data/debug.py +197 -26
  106. anemoi/datasets/data/ensemble.py +93 -8
  107. anemoi/datasets/data/fill_missing.py +165 -18
  108. anemoi/datasets/data/forwards.py +428 -56
  109. anemoi/datasets/data/grids.py +323 -97
  110. anemoi/datasets/data/indexing.py +112 -19
  111. anemoi/datasets/data/interpolate.py +92 -12
  112. anemoi/datasets/data/join.py +158 -19
  113. anemoi/datasets/data/masked.py +129 -15
  114. anemoi/datasets/data/merge.py +137 -23
  115. anemoi/datasets/data/misc.py +172 -16
  116. anemoi/datasets/data/missing.py +233 -29
  117. anemoi/datasets/data/rescale.py +111 -10
  118. anemoi/datasets/data/select.py +168 -26
  119. anemoi/datasets/data/statistics.py +67 -6
  120. anemoi/datasets/data/stores.py +149 -64
  121. anemoi/datasets/data/subset.py +159 -25
  122. anemoi/datasets/data/unchecked.py +168 -57
  123. anemoi/datasets/data/xy.py +168 -25
  124. anemoi/datasets/dates/__init__.py +191 -16
  125. anemoi/datasets/dates/groups.py +189 -47
  126. anemoi/datasets/grids.py +270 -31
  127. anemoi/datasets/testing.py +28 -1
  128. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.18.dist-info}/METADATA +9 -6
  129. anemoi_datasets-0.5.18.dist-info/RECORD +137 -0
  130. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.18.dist-info}/WHEEL +1 -1
  131. anemoi/datasets/create/functions/__init__.py +0 -66
  132. anemoi/datasets/create/functions/filters/__init__.py +0 -9
  133. anemoi/datasets/create/functions/filters/empty.py +0 -17
  134. anemoi/datasets/create/functions/filters/orog_to_z.py +0 -58
  135. anemoi/datasets/create/functions/filters/rename.py +0 -79
  136. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +0 -78
  137. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +0 -56
  138. anemoi/datasets/create/functions/sources/empty.py +0 -15
  139. anemoi/datasets/create/functions/sources/grib.py +0 -150
  140. anemoi/datasets/create/functions/sources/netcdf.py +0 -15
  141. anemoi/datasets/create/functions/sources/opendap.py +0 -15
  142. anemoi/datasets/create/functions/sources/recentre.py +0 -60
  143. anemoi/datasets/create/functions/sources/xarray/coordinates.py +0 -255
  144. anemoi/datasets/create/functions/sources/xarray/flavour.py +0 -472
  145. anemoi/datasets/create/functions/sources/xarray/metadata.py +0 -148
  146. anemoi/datasets/create/functions/sources/xarray/patch.py +0 -44
  147. anemoi/datasets/create/functions/sources/xarray/time.py +0 -177
  148. anemoi/datasets/create/functions/sources/xarray/variable.py +0 -188
  149. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +0 -42
  150. anemoi/datasets/create/functions/sources/xarray_zarr.py +0 -15
  151. anemoi/datasets/utils/fields.py +0 -47
  152. anemoi_datasets-0.5.16.dist-info/RECORD +0 -129
  153. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.18.dist-info}/entry_points.txt +0 -0
  154. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.18.dist-info/licenses}/LICENSE +0 -0
  155. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.18.dist-info}/top_level.txt +0 -0
@@ -13,17 +13,59 @@ import json
13
13
  import logging
14
14
  import pprint
15
15
  import warnings
16
+ from abc import ABC
17
+ from abc import abstractmethod
16
18
  from functools import cached_property
17
19
 
20
+ try:
21
+ from types import EllipsisType
22
+ except ImportError:
23
+ # Python 3.9
24
+ EllipsisType = type(Ellipsis)
25
+ from typing import TYPE_CHECKING
26
+ from typing import Any
27
+ from typing import Dict
28
+ from typing import List
29
+ from typing import Optional
30
+ from typing import Sequence
31
+ from typing import Set
32
+ from typing import Sized
33
+ from typing import Tuple
34
+ from typing import Union
35
+
18
36
  import numpy as np
19
37
  from anemoi.utils.dates import frequency_to_seconds
20
38
  from anemoi.utils.dates import frequency_to_string
21
39
  from anemoi.utils.dates import frequency_to_timedelta
40
+ from numpy.typing import NDArray
41
+
42
+ from .debug import Node
43
+ from .debug import Source
44
+
45
+ if TYPE_CHECKING:
46
+ import matplotlib
22
47
 
23
48
  LOG = logging.getLogger(__name__)
24
49
 
25
50
 
26
- def _tidy(v):
51
+ Shape = Tuple[int, ...]
52
+ TupleIndex = Tuple[Union[int, slice, EllipsisType], ...]
53
+ FullIndex = Union[int, slice, TupleIndex]
54
+
55
+
56
+ def _tidy(v: Any) -> Any:
57
+ """Tidy up the input value.
58
+
59
+ Parameters
60
+ ----------
61
+ v : Any
62
+ The input value to tidy up.
63
+
64
+ Returns
65
+ -------
66
+ Any
67
+ The tidied value.
68
+ """
27
69
  if isinstance(v, (list, tuple, set)):
28
70
  return [_tidy(i) for i in v]
29
71
  if isinstance(v, dict):
@@ -49,26 +91,53 @@ def _tidy(v):
49
91
  return v
50
92
 
51
93
 
52
- class Dataset:
53
- arguments = {}
54
- _name = None
94
+ class Dataset(ABC, Sized):
95
+ arguments: Dict[str, Any] = {}
96
+ _name: Union[str, None] = None
55
97
 
56
98
  def mutate(self) -> "Dataset":
57
- """Give an opportunity to a subclass to return a new Dataset
58
- object of a different class, if needed.
59
- """
99
+ """Give an opportunity to a subclass to return a new Dataset object of a different class, if needed.
60
100
 
101
+ Returns
102
+ -------
103
+ Dataset
104
+ The mutated dataset.
105
+ """
61
106
  return self
62
107
 
63
- def swap_with_parent(self, parent):
108
+ def swap_with_parent(self, parent: "Dataset") -> "Dataset":
109
+ """Swap the current dataset with its parent dataset.
110
+
111
+ Parameters
112
+ ----------
113
+ parent : Dataset
114
+ The parent dataset.
115
+
116
+ Returns
117
+ -------
118
+ Dataset
119
+ The parent dataset.
120
+ """
64
121
  return parent
65
122
 
66
123
  @cached_property
67
- def _len(self):
124
+ def _len(self) -> int:
125
+ """Cache and return the length of the dataset."""
68
126
  return len(self)
69
127
 
70
- def _subset(self, **kwargs):
128
+ def _subset(self, **kwargs: Any) -> "Dataset":
129
+ """Create a subset of the dataset based on the provided keyword arguments.
130
+
131
+ Parameters
132
+ ----------
133
+ **kwargs : Any
134
+ Keyword arguments for creating the subset.
71
135
 
136
+ Returns
137
+ -------
138
+ Dataset
139
+ The subset of the dataset.
140
+ """
72
141
  if not kwargs:
73
142
  return self.mutate()
74
143
 
@@ -79,10 +148,23 @@ class Dataset:
79
148
  return result
80
149
 
81
150
  @property
82
- def name(self):
151
+ def name(self) -> Union[str, None]:
152
+ """Return the name of the dataset."""
83
153
  return self._name
84
154
 
85
- def __subset(self, **kwargs):
155
+ def __subset(self, **kwargs: Any) -> "Dataset":
156
+ """Internal method to create a subset of the dataset based on the provided keyword arguments.
157
+
158
+ Parameters
159
+ ----------
160
+ **kwargs : Any
161
+ Keyword arguments for creating the subset.
162
+
163
+ Returns
164
+ -------
165
+ Dataset
166
+ The subset of the dataset.
167
+ """
86
168
  if not kwargs:
87
169
  return self.mutate()
88
170
 
@@ -213,22 +295,61 @@ class Dataset:
213
295
 
214
296
  raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs))
215
297
 
216
- def _frequency_to_indices(self, frequency):
298
+ def _frequency_to_indices(self, frequency: str) -> list[int]:
299
+ """Convert a frequency string to a list of indices.
217
300
 
301
+ Parameters
302
+ ----------
303
+ frequency : str
304
+ The frequency string.
305
+
306
+ Returns
307
+ -------
308
+ list of int
309
+ The list of indices.
310
+ """
218
311
  requested_frequency = frequency_to_seconds(frequency)
219
312
  dataset_frequency = frequency_to_seconds(self.frequency)
220
- assert requested_frequency % dataset_frequency == 0
313
+
314
+ if requested_frequency % dataset_frequency != 0:
315
+ raise ValueError(
316
+ f"Requested frequency {frequency} is not a multiple of the dataset frequency {self.frequency}. Did you mean to use `interpolate_frequency`?"
317
+ )
318
+
221
319
  # Question: where do we start? first date, or first date that is a multiple of the frequency?
222
320
  step = requested_frequency // dataset_frequency
223
321
 
224
322
  return range(0, len(self), step)
225
323
 
226
- def _shuffle_indices(self):
227
- import numpy as np
324
+ def _shuffle_indices(self) -> NDArray[Any]:
325
+ """Return a shuffled array of indices.
228
326
 
327
+ Returns
328
+ -------
329
+ numpy.ndarray
330
+ The shuffled array of indices.
331
+ """
229
332
  return np.random.permutation(len(self))
230
333
 
231
- def _dates_to_indices(self, start, end):
334
+ def _dates_to_indices(
335
+ self,
336
+ start: Union[None, str, datetime.datetime],
337
+ end: Union[None, str, datetime.datetime],
338
+ ) -> List[int]:
339
+ """Convert date range to a list of indices.
340
+
341
+ Parameters
342
+ ----------
343
+ start : None, str, or datetime.datetime
344
+ The start date.
345
+ end : None, str, or datetime.datetime
346
+ The end date.
347
+
348
+ Returns
349
+ -------
350
+ list of int
351
+ The list of indices.
352
+ """
232
353
  from .misc import as_first_date
233
354
  from .misc import as_last_date
234
355
 
@@ -239,7 +360,19 @@ class Dataset:
239
360
 
240
361
  return [i for i, date in enumerate(self.dates) if start <= date <= end]
241
362
 
242
- def _select_to_columns(self, vars):
363
+ def _select_to_columns(self, vars: Union[str, List[str], Tuple[str], set]) -> List[int]:
364
+ """Convert variable names to a list of column indices.
365
+
366
+ Parameters
367
+ ----------
368
+ vars : str, list of str, tuple of str, or set
369
+ The variable names.
370
+
371
+ Returns
372
+ -------
373
+ list of int
374
+ The list of column indices.
375
+ """
243
376
  if isinstance(vars, set):
244
377
  # We keep the order of the variables as they are in the zarr file
245
378
  nvars = [v for v in self.name_to_index if v in vars]
@@ -251,7 +384,19 @@ class Dataset:
251
384
 
252
385
  return [self.name_to_index[v] for v in vars]
253
386
 
254
- def _drop_to_columns(self, vars):
387
+ def _drop_to_columns(self, vars: Union[str, Sequence[str]]) -> List[int]:
388
+ """Convert variable names to a list of column indices to drop.
389
+
390
+ Parameters
391
+ ----------
392
+ vars : str, list of str, tuple of str, or set
393
+ The variable names.
394
+
395
+ Returns
396
+ -------
397
+ list of int
398
+ The list of column indices to drop.
399
+ """
255
400
  if not isinstance(vars, (list, tuple, set)):
256
401
  vars = [vars]
257
402
 
@@ -260,7 +405,19 @@ class Dataset:
260
405
 
261
406
  return sorted([v for k, v in self.name_to_index.items() if k not in vars])
262
407
 
263
- def _reorder_to_columns(self, vars):
408
+ def _reorder_to_columns(self, vars: Union[str, List[str], Tuple[str], Dict[str, int]]) -> List[int]:
409
+ """Convert variable names to a list of reordered column indices.
410
+
411
+ Parameters
412
+ ----------
413
+ vars : str, list of str, tuple of str, or dict of str to int
414
+ The variable names.
415
+
416
+ Returns
417
+ -------
418
+ list of int
419
+ The list of reordered column indices.
420
+ """
264
421
  if isinstance(vars, str) and vars == "sort":
265
422
  # Sorting the variables alphabetically.
266
423
  # This is crucial for pre-training then transfer learning in combination with
@@ -280,20 +437,55 @@ class Dataset:
280
437
 
281
438
  return indices
282
439
 
283
- def dates_interval_to_indices(self, start, end):
440
+ def dates_interval_to_indices(
441
+ self, start: Union[None, str, datetime.datetime], end: Union[None, str, datetime.datetime]
442
+ ) -> List[int]:
443
+ """Convert date interval to a list of indices.
444
+
445
+ Parameters
446
+ ----------
447
+ start : None, str, or datetime.datetime
448
+ The start date.
449
+ end : None, str, or datetime.datetime
450
+ The end date.
451
+
452
+ Returns
453
+ -------
454
+ list of int
455
+ The list of indices.
456
+ """
284
457
  return self._dates_to_indices(start, end)
285
458
 
286
- def provenance(self):
459
+ def provenance(self) -> Dict[str, Any]:
460
+ """Return the provenance information of the dataset.
461
+
462
+ Returns
463
+ -------
464
+ dict
465
+ The provenance information.
466
+ """
287
467
  return {}
288
468
 
289
- def sub_shape(self, drop_axis):
290
- shape = self.shape
291
- shape = list(shape)
469
+ def sub_shape(self, drop_axis: int) -> TupleIndex:
470
+ """Return the shape of the dataset with one axis dropped.
471
+
472
+ Parameters
473
+ ----------
474
+ drop_axis : int
475
+ The axis to drop.
476
+
477
+ Returns
478
+ -------
479
+ tuple
480
+ The shape with one axis dropped.
481
+ """
482
+ shape = list(self.shape)
292
483
  shape.pop(drop_axis)
293
484
  return tuple(shape)
294
485
 
295
486
  @property
296
- def typed_variables(self):
487
+ def typed_variables(self) -> Dict[str, Any]:
488
+ """Return the variables with their types."""
297
489
  from anemoi.transform.variables import Variable
298
490
 
299
491
  constants = self.constant_fields
@@ -313,12 +505,26 @@ class Dataset:
313
505
 
314
506
  return result
315
507
 
316
- def _input_sources(self):
508
+ def _input_sources(self) -> List[Any]:
509
+ """Return the input sources of the dataset.
510
+
511
+ Returns
512
+ -------
513
+ list
514
+ The input sources.
515
+ """
317
516
  sources = []
318
517
  self.collect_input_sources(sources)
319
518
  return sources
320
519
 
321
- def metadata(self):
520
+ def metadata(self) -> Dict[str, Any]:
521
+ """Return the metadata of the dataset.
522
+
523
+ Returns
524
+ -------
525
+ dict
526
+ The metadata.
527
+ """
322
528
  import anemoi
323
529
 
324
530
  _, source_to_arrays = self._supporting_arrays_and_sources()
@@ -346,14 +552,23 @@ class Dataset:
346
552
  raise
347
553
 
348
554
  @property
349
- def start_date(self):
555
+ def start_date(self) -> np.datetime64:
556
+ """Return the start date of the dataset."""
350
557
  return self.dates[0]
351
558
 
352
559
  @property
353
- def end_date(self):
560
+ def end_date(self) -> np.datetime64:
561
+ """Return the end date of the dataset."""
354
562
  return self.dates[-1]
355
563
 
356
- def dataset_metadata(self):
564
+ def dataset_metadata(self) -> Dict[str, Any]:
565
+ """Return the metadata of the dataset.
566
+
567
+ Returns
568
+ -------
569
+ dict
570
+ The metadata.
571
+ """
357
572
  return dict(
358
573
  specific=self.metadata_specific(),
359
574
  frequency=self.frequency,
@@ -366,11 +581,21 @@ class Dataset:
366
581
  name=self.name,
367
582
  )
368
583
 
369
- def _supporting_arrays(self, *path):
584
+ def _supporting_arrays(self, *path: str) -> Dict[str, NDArray[Any]]:
585
+ """Return the supporting arrays of the dataset.
370
586
 
371
- import numpy as np
587
+ Parameters
588
+ ----------
589
+ *path : str
590
+ The path components.
591
+
592
+ Returns
593
+ -------
594
+ dict
595
+ The supporting arrays.
596
+ """
372
597
 
373
- def _path(path, name):
598
+ def _path(path, name: str) -> str:
374
599
  return "/".join(str(_) for _ in [*path, name])
375
600
 
376
601
  result = {
@@ -394,13 +619,25 @@ class Dataset:
394
619
 
395
620
  return result
396
621
 
397
- def supporting_arrays(self):
398
- """Arrays to be saved in the checkpoints"""
622
+ def supporting_arrays(self) -> Dict[str, NDArray[Any]]:
623
+ """Return the supporting arrays to be saved in the checkpoints.
624
+
625
+ Returns
626
+ -------
627
+ dict
628
+ The supporting arrays.
629
+ """
399
630
  arrays, _ = self._supporting_arrays_and_sources()
400
631
  return arrays
401
632
 
402
- def _supporting_arrays_and_sources(self):
633
+ def _supporting_arrays_and_sources(self) -> Tuple[Dict[str, NDArray], Dict[int, List[str]]]:
634
+ """Return the supporting arrays and their sources.
403
635
 
636
+ Returns
637
+ -------
638
+ tuple
639
+ The supporting arrays and their sources.
640
+ """
404
641
  source_to_arrays = {}
405
642
 
406
643
  # Top levels arrays
@@ -420,11 +657,32 @@ class Dataset:
420
657
 
421
658
  return result, source_to_arrays
422
659
 
423
- def collect_supporting_arrays(self, collected, *path):
660
+ def collect_supporting_arrays(self, collected: List[Tuple[Tuple[str, ...], str, NDArray[Any]]], *path: str) -> None:
661
+ """Collect supporting arrays.
662
+
663
+ Parameters
664
+ ----------
665
+ collected : list of tuple
666
+ The collected supporting arrays.
667
+ *path : str
668
+ The path components.
669
+ """
424
670
  # Override this method to add more arrays
425
671
  pass
426
672
 
427
- def metadata_specific(self, **kwargs):
673
+ def metadata_specific(self, **kwargs: Any) -> Dict[str, Any]:
674
+ """Return specific metadata of the dataset.
675
+
676
+ Parameters
677
+ ----------
678
+ **kwargs : Any
679
+ Additional keyword arguments.
680
+
681
+ Returns
682
+ -------
683
+ dict
684
+ The specific metadata.
685
+ """
428
686
  action = self.__class__.__name__.lower()
429
687
  # assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action)
430
688
  return dict(
@@ -437,33 +695,53 @@ class Dataset:
437
695
  **kwargs,
438
696
  )
439
697
 
440
- def __repr__(self):
698
+ def __repr__(self) -> str:
699
+ """Return the string representation of the dataset.
700
+
701
+ Returns
702
+ -------
703
+ str
704
+ The string representation.
705
+ """
441
706
  return self.__class__.__name__ + "()"
442
707
 
443
708
  @property
444
- def grids(self):
709
+ def grids(self) -> TupleIndex:
710
+ """Return the grid shape of the dataset."""
445
711
  return (self.shape[-1],)
446
712
 
447
- def _check(ds):
448
- common = Dataset.__dict__.keys() & ds.__class__.__dict__.keys()
449
- overriden = [m for m in common if Dataset.__dict__[m] is not ds.__class__.__dict__[m]]
713
+ def _check(self) -> None:
714
+ """Check for overridden private methods in the dataset."""
715
+ common = Dataset.__dict__.keys() & self.__class__.__dict__.keys()
716
+ overriden = [m for m in common if Dataset.__dict__[m] is not self.__class__.__dict__[m]]
450
717
 
451
718
  for n in overriden:
452
- if n.startswith("_") and not n.startswith("__"):
453
- warnings.warn(f"Private method {n} is overriden in {ds.__class__.__name__}")
719
+ if n.startswith("_") and not n.startswith("__") and n not in ("_abc_impl",):
720
+ warnings.warn(f"Private method {n} is overriden in {self.__class__.__name__}")
721
+
722
+ def _repr_html_(self) -> str:
723
+ """Return the HTML representation of the dataset.
454
724
 
455
- def _repr_html_(self):
725
+ Returns
726
+ -------
727
+ str
728
+ The HTML representation.
729
+ """
456
730
  return self.tree().html()
457
731
 
458
732
  @property
459
- def label(self):
733
+ def label(self) -> str:
734
+ """Return the label of the dataset."""
460
735
  return self.__class__.__name__.lower()
461
736
 
462
- def get_dataset_names(self, names):
463
- raise NotImplementedError(self)
737
+ def computed_constant_fields(self) -> List[str]:
738
+ """Return the computed constant fields of the dataset.
464
739
 
465
- def computed_constant_fields(self):
466
- # Call `constant_fields` instead of `computed_constant_fields`
740
+ Returns
741
+ -------
742
+ list of str
743
+ The computed constant fields.
744
+ """
467
745
  try:
468
746
  # If the tendencies are computed, we can use them
469
747
  return sorted(self._compute_constant_fields_from_statistics())
@@ -473,8 +751,14 @@ class Dataset:
473
751
 
474
752
  return sorted(self._compute_constant_fields_from_a_few_samples())
475
753
 
476
- def _compute_constant_fields_from_a_few_samples(self):
754
+ def _compute_constant_fields_from_a_few_samples(self) -> List[str]:
755
+ """Compute constant fields from a few samples.
477
756
 
757
+ Returns
758
+ -------
759
+ list of str
760
+ The computed constant fields.
761
+ """
478
762
  import numpy as np
479
763
 
480
764
  # Otherwise, we need to compute them
@@ -508,7 +792,14 @@ class Dataset:
508
792
 
509
793
  return [v for i, v in enumerate(self.variables) if constants[i]]
510
794
 
511
- def _compute_constant_fields_from_statistics(self):
795
+ def _compute_constant_fields_from_statistics(self) -> List[str]:
796
+ """Compute constant fields from statistics.
797
+
798
+ Returns
799
+ -------
800
+ list of str
801
+ The computed constant fields.
802
+ """
512
803
  result = []
513
804
 
514
805
  t = self.statistics_tendencies()
@@ -519,7 +810,13 @@ class Dataset:
519
810
 
520
811
  return result
521
812
 
522
- def plot(self, date, variable, member=0, **kwargs):
813
+ def plot(
814
+ self,
815
+ date: Union[int, datetime.datetime, np.datetime64, str],
816
+ variable: Union[int, str],
817
+ member: int = 0,
818
+ **kwargs: Any,
819
+ ) -> "matplotlib.pyplot.Axes":
523
820
  """For debugging purposes, plot a field.
524
821
 
525
822
  Parameters
@@ -530,17 +827,42 @@ class Dataset:
530
827
  The variable to plot.
531
828
  member : int, optional
532
829
  The ensemble member to plot.
830
+ **kwargs : Any
831
+ Additional arguments to pass to matplotlib.pyplot.tricontourf.
832
+
833
+ Returns
834
+ -------
835
+ matplotlib.pyplot.Axes
836
+ The plot axes.
837
+ """
838
+ from anemoi.utils.devtools import plot_values
533
839
 
534
- **kwargs:
535
- Additional arguments to pass to matplotlib.pyplot.tricontourf
840
+ values = self[self.to_index(date, variable, member)]
536
841
 
842
+ return plot_values(values, self.latitudes, self.longitudes, **kwargs)
843
+
844
+ def to_index(
845
+ self,
846
+ date: Union[int, datetime.datetime, np.datetime64, str],
847
+ variable: Union[int, str],
848
+ member: int = 0,
849
+ ) -> Tuple[int, int, int]:
850
+ """Convert date, variable, and member to indices.
851
+
852
+ Parameters
853
+ ----------
854
+ date : int or datetime.datetime or numpy.datetime64 or str
855
+ The date.
856
+ variable : int or str
857
+ The variable.
858
+ member : int, optional
859
+ The ensemble member.
537
860
 
538
861
  Returns
539
862
  -------
540
- matplotlib.pyplot.Axes
863
+ tuple of int
864
+ The indices.
541
865
  """
542
-
543
- from anemoi.utils.devtools import plot_values
544
866
  from earthkit.data.utils.dates import to_datetime
545
867
 
546
868
  if not isinstance(date, int):
@@ -554,6 +876,8 @@ class Dataset:
554
876
  else:
555
877
  date_index = date
556
878
 
879
+ date_index = int(date_index) # because np.int64 is not instance of int
880
+
557
881
  if isinstance(variable, int):
558
882
  variable_index = variable
559
883
  else:
@@ -562,6 +886,178 @@ class Dataset:
562
886
 
563
887
  variable_index = self.name_to_index[variable]
564
888
 
565
- values = self[date_index, variable_index, member]
889
+ return (date_index, variable_index, member)
566
890
 
567
- return plot_values(values, self.latitudes, self.longitudes, **kwargs)
891
+ @abstractmethod
892
+ def __getitem__(self, n: FullIndex) -> NDArray[Any]:
893
+ """Get the item at the specified index.
894
+
895
+ Parameters
896
+ ----------
897
+ n : FullIndex
898
+ Index to retrieve.
899
+
900
+ Returns
901
+ -------
902
+ NDArray[Any]
903
+ Retrieved item.
904
+ """
905
+
906
+ @abstractmethod
907
+ def __len__(self) -> int:
908
+ """Return the length of the dataset.
909
+
910
+ Returns
911
+ -------
912
+ int
913
+ The length of the dataset.
914
+ """
915
+
916
+ @property
917
+ @abstractmethod
918
+ def variables(self) -> List[str]:
919
+ """Return the list of variables in the dataset."""
920
+ pass
921
+
922
+ @property
923
+ @abstractmethod
924
+ def frequency(self) -> datetime.timedelta:
925
+ """Return the frequency of the dataset."""
926
+ pass
927
+
928
+ @property
929
+ @abstractmethod
930
+ def dates(self) -> NDArray[np.datetime64]:
931
+ """Return the dates in the dataset."""
932
+ pass
933
+
934
+ @property
935
+ @abstractmethod
936
+ def resolution(self) -> str:
937
+ """Return the resolution of the dataset."""
938
+ pass
939
+
940
+ @property
941
+ @abstractmethod
942
+ def name_to_index(self) -> Dict[str, int]:
943
+ """Return the mapping of variable names to indices."""
944
+ pass
945
+
946
+ @property
947
+ @abstractmethod
948
+ def shape(self) -> Shape:
949
+ """Return the shape of the dataset."""
950
+ pass
951
+
952
+ @property
953
+ @abstractmethod
954
+ def field_shape(self) -> Shape:
955
+ """Return the shape of the fields in the dataset."""
956
+ pass
957
+
958
+ @property
959
+ @abstractmethod
960
+ def dtype(self) -> np.dtype:
961
+ """Return the data type of the dataset."""
962
+ pass
963
+
964
+ @property
965
+ @abstractmethod
966
+ def latitudes(self) -> NDArray[Any]:
967
+ """Return the latitudes in the dataset."""
968
+ pass
969
+
970
+ @property
971
+ @abstractmethod
972
+ def longitudes(self) -> NDArray[Any]:
973
+ """Return the longitudes in the dataset."""
974
+ pass
975
+
976
+ @property
977
+ @abstractmethod
978
+ def variables_metadata(self) -> Dict[str, Any]:
979
+ """Return the metadata of the variables in the dataset."""
980
+ pass
981
+
982
+ @abstractmethod
983
+ @cached_property
984
+ def missing(self) -> Set[int]:
985
+ """Return the set of missing indices in the dataset."""
986
+ pass
987
+
988
+ @abstractmethod
989
+ @cached_property
990
+ def constant_fields(self) -> List[str]:
991
+ """Return the list of constant fields in the dataset."""
992
+ pass
993
+
994
+ @abstractmethod
995
+ @cached_property
996
+ def statistics(self) -> Dict[str, NDArray[Any]]:
997
+ """Return the statistics of the dataset."""
998
+ pass
999
+
1000
+ @abstractmethod
1001
+ def statistics_tendencies(self, delta: Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]:
1002
+ """Return the tendencies of the statistics in the dataset.
1003
+
1004
+ Parameters
1005
+ ----------
1006
+ delta : datetime.timedelta, optional
1007
+ The time delta for computing tendencies.
1008
+
1009
+ Returns
1010
+ -------
1011
+ dict
1012
+ The tendencies.
1013
+ """
1014
+ pass
1015
+
1016
+ @abstractmethod
1017
+ def source(self, index: int) -> Source:
1018
+ """Return the source of the dataset at the specified index.
1019
+
1020
+ Parameters
1021
+ ----------
1022
+ index : int
1023
+ The index.
1024
+
1025
+ Returns
1026
+ -------
1027
+ Source
1028
+ The source.
1029
+ """
1030
+ pass
1031
+
1032
+ @abstractmethod
1033
+ def tree(self) -> Node:
1034
+ """Return the tree representation of the dataset.
1035
+
1036
+ Returns
1037
+ -------
1038
+ Node
1039
+ The tree representation.
1040
+ """
1041
+ pass
1042
+
1043
+ @abstractmethod
1044
+ def collect_input_sources(self, sources: List[Any]) -> None:
1045
+ """Collect the input sources of the dataset.
1046
+
1047
+ Parameters
1048
+ ----------
1049
+ sources : list
1050
+ The input sources.
1051
+ """
1052
+ pass
1053
+
1054
+ @abstractmethod
1055
+ def get_dataset_names(self, names: Set[str]) -> None:
1056
+ """Get the names of the datasets.
1057
+
1058
+ Parameters
1059
+ ----------
1060
+ names : set of str
1061
+ The dataset names.
1062
+ """
1063
+ pass