anemoi-datasets 0.5.16__py3-none-any.whl → 0.5.17__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (155)
  1. anemoi/datasets/__init__.py +4 -1
  2. anemoi/datasets/__main__.py +12 -2
  3. anemoi/datasets/_version.py +9 -4
  4. anemoi/datasets/commands/cleanup.py +17 -2
  5. anemoi/datasets/commands/compare.py +18 -2
  6. anemoi/datasets/commands/copy.py +196 -14
  7. anemoi/datasets/commands/create.py +50 -7
  8. anemoi/datasets/commands/finalise-additions.py +17 -2
  9. anemoi/datasets/commands/finalise.py +17 -2
  10. anemoi/datasets/commands/init-additions.py +17 -2
  11. anemoi/datasets/commands/init.py +16 -2
  12. anemoi/datasets/commands/inspect.py +283 -62
  13. anemoi/datasets/commands/load-additions.py +16 -2
  14. anemoi/datasets/commands/load.py +16 -2
  15. anemoi/datasets/commands/patch.py +17 -2
  16. anemoi/datasets/commands/publish.py +17 -2
  17. anemoi/datasets/commands/scan.py +31 -3
  18. anemoi/datasets/compute/recentre.py +47 -11
  19. anemoi/datasets/create/__init__.py +612 -85
  20. anemoi/datasets/create/check.py +142 -20
  21. anemoi/datasets/create/chunks.py +64 -4
  22. anemoi/datasets/create/config.py +185 -21
  23. anemoi/datasets/create/filter.py +50 -0
  24. anemoi/datasets/create/filters/__init__.py +33 -0
  25. anemoi/datasets/create/filters/empty.py +37 -0
  26. anemoi/datasets/create/filters/legacy.py +93 -0
  27. anemoi/datasets/create/filters/noop.py +37 -0
  28. anemoi/datasets/create/filters/orog_to_z.py +58 -0
  29. anemoi/datasets/create/{functions/filters → filters}/pressure_level_relative_humidity_to_specific_humidity.py +33 -10
  30. anemoi/datasets/create/{functions/filters → filters}/pressure_level_specific_humidity_to_relative_humidity.py +32 -8
  31. anemoi/datasets/create/filters/rename.py +205 -0
  32. anemoi/datasets/create/{functions/filters → filters}/rotate_winds.py +43 -28
  33. anemoi/datasets/create/{functions/filters → filters}/single_level_dewpoint_to_relative_humidity.py +32 -9
  34. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_dewpoint.py +33 -9
  35. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_specific_humidity.py +55 -7
  36. anemoi/datasets/create/{functions/filters → filters}/single_level_specific_humidity_to_relative_humidity.py +98 -37
  37. anemoi/datasets/create/filters/speeddir_to_uv.py +95 -0
  38. anemoi/datasets/create/{functions/filters → filters}/sum.py +24 -27
  39. anemoi/datasets/create/filters/transform.py +53 -0
  40. anemoi/datasets/create/{functions/filters → filters}/unrotate_winds.py +27 -18
  41. anemoi/datasets/create/filters/uv_to_speeddir.py +94 -0
  42. anemoi/datasets/create/{functions/filters → filters}/wz_to_w.py +51 -33
  43. anemoi/datasets/create/input/__init__.py +76 -5
  44. anemoi/datasets/create/input/action.py +149 -13
  45. anemoi/datasets/create/input/concat.py +81 -10
  46. anemoi/datasets/create/input/context.py +39 -4
  47. anemoi/datasets/create/input/data_sources.py +72 -6
  48. anemoi/datasets/create/input/empty.py +21 -3
  49. anemoi/datasets/create/input/filter.py +60 -12
  50. anemoi/datasets/create/input/function.py +154 -37
  51. anemoi/datasets/create/input/join.py +86 -14
  52. anemoi/datasets/create/input/misc.py +67 -17
  53. anemoi/datasets/create/input/pipe.py +33 -6
  54. anemoi/datasets/create/input/repeated_dates.py +189 -41
  55. anemoi/datasets/create/input/result.py +202 -87
  56. anemoi/datasets/create/input/step.py +119 -22
  57. anemoi/datasets/create/input/template.py +100 -13
  58. anemoi/datasets/create/input/trace.py +62 -7
  59. anemoi/datasets/create/patch.py +52 -4
  60. anemoi/datasets/create/persistent.py +134 -17
  61. anemoi/datasets/create/size.py +15 -1
  62. anemoi/datasets/create/source.py +51 -0
  63. anemoi/datasets/create/sources/__init__.py +36 -0
  64. anemoi/datasets/create/{functions/sources → sources}/accumulations.py +296 -30
  65. anemoi/datasets/create/{functions/sources → sources}/constants.py +27 -2
  66. anemoi/datasets/create/{functions/sources → sources}/eccc_fstd.py +7 -3
  67. anemoi/datasets/create/sources/empty.py +37 -0
  68. anemoi/datasets/create/{functions/sources → sources}/forcings.py +25 -1
  69. anemoi/datasets/create/sources/grib.py +297 -0
  70. anemoi/datasets/create/{functions/sources → sources}/hindcasts.py +38 -4
  71. anemoi/datasets/create/sources/legacy.py +93 -0
  72. anemoi/datasets/create/{functions/sources → sources}/mars.py +168 -20
  73. anemoi/datasets/create/sources/netcdf.py +42 -0
  74. anemoi/datasets/create/sources/opendap.py +43 -0
  75. anemoi/datasets/create/{functions/sources/__init__.py → sources/patterns.py} +35 -4
  76. anemoi/datasets/create/sources/recentre.py +150 -0
  77. anemoi/datasets/create/{functions/sources → sources}/source.py +27 -5
  78. anemoi/datasets/create/{functions/sources → sources}/tendencies.py +64 -7
  79. anemoi/datasets/create/sources/xarray.py +92 -0
  80. anemoi/datasets/create/sources/xarray_kerchunk.py +36 -0
  81. anemoi/datasets/create/sources/xarray_support/README.md +1 -0
  82. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/__init__.py +109 -8
  83. anemoi/datasets/create/sources/xarray_support/coordinates.py +442 -0
  84. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/field.py +94 -16
  85. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/fieldlist.py +90 -25
  86. anemoi/datasets/create/sources/xarray_support/flavour.py +1036 -0
  87. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/grid.py +92 -31
  88. anemoi/datasets/create/sources/xarray_support/metadata.py +395 -0
  89. anemoi/datasets/create/sources/xarray_support/patch.py +91 -0
  90. anemoi/datasets/create/sources/xarray_support/time.py +391 -0
  91. anemoi/datasets/create/sources/xarray_support/variable.py +331 -0
  92. anemoi/datasets/create/sources/xarray_zarr.py +41 -0
  93. anemoi/datasets/create/{functions/sources → sources}/zenodo.py +34 -5
  94. anemoi/datasets/create/statistics/__init__.py +233 -44
  95. anemoi/datasets/create/statistics/summary.py +52 -6
  96. anemoi/datasets/create/testing.py +76 -0
  97. anemoi/datasets/create/{functions/filters/noop.py → typing.py} +6 -3
  98. anemoi/datasets/create/utils.py +97 -6
  99. anemoi/datasets/create/writer.py +26 -4
  100. anemoi/datasets/create/zarr.py +170 -23
  101. anemoi/datasets/data/__init__.py +51 -4
  102. anemoi/datasets/data/complement.py +191 -40
  103. anemoi/datasets/data/concat.py +141 -16
  104. anemoi/datasets/data/dataset.py +552 -61
  105. anemoi/datasets/data/debug.py +197 -26
  106. anemoi/datasets/data/ensemble.py +93 -8
  107. anemoi/datasets/data/fill_missing.py +165 -18
  108. anemoi/datasets/data/forwards.py +428 -56
  109. anemoi/datasets/data/grids.py +323 -97
  110. anemoi/datasets/data/indexing.py +112 -19
  111. anemoi/datasets/data/interpolate.py +92 -12
  112. anemoi/datasets/data/join.py +158 -19
  113. anemoi/datasets/data/masked.py +129 -15
  114. anemoi/datasets/data/merge.py +137 -23
  115. anemoi/datasets/data/misc.py +172 -16
  116. anemoi/datasets/data/missing.py +233 -29
  117. anemoi/datasets/data/rescale.py +111 -10
  118. anemoi/datasets/data/select.py +168 -26
  119. anemoi/datasets/data/statistics.py +67 -6
  120. anemoi/datasets/data/stores.py +149 -64
  121. anemoi/datasets/data/subset.py +159 -25
  122. anemoi/datasets/data/unchecked.py +168 -57
  123. anemoi/datasets/data/xy.py +168 -25
  124. anemoi/datasets/dates/__init__.py +191 -16
  125. anemoi/datasets/dates/groups.py +189 -47
  126. anemoi/datasets/grids.py +270 -31
  127. anemoi/datasets/testing.py +28 -1
  128. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/METADATA +9 -6
  129. anemoi_datasets-0.5.17.dist-info/RECORD +137 -0
  130. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/WHEEL +1 -1
  131. anemoi/datasets/create/functions/__init__.py +0 -66
  132. anemoi/datasets/create/functions/filters/__init__.py +0 -9
  133. anemoi/datasets/create/functions/filters/empty.py +0 -17
  134. anemoi/datasets/create/functions/filters/orog_to_z.py +0 -58
  135. anemoi/datasets/create/functions/filters/rename.py +0 -79
  136. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +0 -78
  137. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +0 -56
  138. anemoi/datasets/create/functions/sources/empty.py +0 -15
  139. anemoi/datasets/create/functions/sources/grib.py +0 -150
  140. anemoi/datasets/create/functions/sources/netcdf.py +0 -15
  141. anemoi/datasets/create/functions/sources/opendap.py +0 -15
  142. anemoi/datasets/create/functions/sources/recentre.py +0 -60
  143. anemoi/datasets/create/functions/sources/xarray/coordinates.py +0 -255
  144. anemoi/datasets/create/functions/sources/xarray/flavour.py +0 -472
  145. anemoi/datasets/create/functions/sources/xarray/metadata.py +0 -148
  146. anemoi/datasets/create/functions/sources/xarray/patch.py +0 -44
  147. anemoi/datasets/create/functions/sources/xarray/time.py +0 -177
  148. anemoi/datasets/create/functions/sources/xarray/variable.py +0 -188
  149. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +0 -42
  150. anemoi/datasets/create/functions/sources/xarray_zarr.py +0 -15
  151. anemoi/datasets/utils/fields.py +0 -47
  152. anemoi_datasets-0.5.16.dist-info/RECORD +0 -129
  153. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/entry_points.txt +0 -0
  154. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info/licenses}/LICENSE +0 -0
  155. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/top_level.txt +0 -0
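
Most of the files listed above are moves rather than new code: the anemoi/datasets/create/functions/filters and anemoi/datasets/create/functions/sources packages are relocated to anemoi/datasets/create/filters and anemoi/datasets/create/sources, with the xarray helpers gathered under sources/xarray_support. Below is a hedged sketch of what that move can mean for code that imports these modules directly; only the module paths come from the file list, and the version-fallback pattern is an illustrative assumption, not something this release provides.

# Hedged sketch: adapting an import to the module moves listed above.
# The module paths come from the file list; the try/except fallback is
# an illustrative assumption, not part of the package itself.
try:
    # anemoi-datasets 0.5.17: filters and sources live directly under anemoi.datasets.create
    from anemoi.datasets.create.filters import rotate_winds
    from anemoi.datasets.create.sources import mars
except ImportError:
    # anemoi-datasets 0.5.16 and earlier: the old "functions" package layout
    from anemoi.datasets.create.functions.filters import rotate_winds
    from anemoi.datasets.create.functions.sources import mars
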
anemoi/datasets/data/dataset.py

@@ -13,17 +13,59 @@ import json
 import logging
 import pprint
 import warnings
+from abc import ABC
+from abc import abstractmethod
 from functools import cached_property
 
+try:
+    from types import EllipsisType
+except ImportError:
+    # Python 3.9
+    EllipsisType = type(Ellipsis)
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Sized
+from typing import Tuple
+from typing import Union
+
 import numpy as np
 from anemoi.utils.dates import frequency_to_seconds
 from anemoi.utils.dates import frequency_to_string
 from anemoi.utils.dates import frequency_to_timedelta
+from numpy.typing import NDArray
+
+from .debug import Node
+from .debug import Source
+
+if TYPE_CHECKING:
+    import matplotlib
 
 LOG = logging.getLogger(__name__)
 
 
-def _tidy(v):
+Shape = Tuple[int, ...]
+TupleIndex = Tuple[Union[int, slice, EllipsisType], ...]
+FullIndex = Union[int, slice, TupleIndex]
+
+
+def _tidy(v: Any) -> Any:
+    """Tidy up the input value.
+
+    Parameters
+    ----------
+    v : Any
+        The input value to tidy up.
+
+    Returns
+    -------
+    Any
+        The tidied value.
+    """
     if isinstance(v, (list, tuple, set)):
         return [_tidy(i) for i in v]
     if isinstance(v, dict):
@@ -49,26 +91,53 @@ def _tidy(v):
     return v
 
 
-class Dataset:
-    arguments = {}
-    _name = None
+class Dataset(ABC, Sized):
+    arguments: Dict[str, Any] = {}
+    _name: Union[str, None] = None
 
     def mutate(self) -> "Dataset":
-        """Give an opportunity to a subclass to return a new Dataset
-        object of a different class, if needed.
-        """
+        """Give an opportunity to a subclass to return a new Dataset object of a different class, if needed.
 
+        Returns
+        -------
+        Dataset
+            The mutated dataset.
+        """
         return self
 
-    def swap_with_parent(self, parent):
+    def swap_with_parent(self, parent: "Dataset") -> "Dataset":
+        """Swap the current dataset with its parent dataset.
+
+        Parameters
+        ----------
+        parent : Dataset
+            The parent dataset.
+
+        Returns
+        -------
+        Dataset
+            The parent dataset.
+        """
         return parent
 
     @cached_property
-    def _len(self):
+    def _len(self) -> int:
+        """Cache and return the length of the dataset."""
         return len(self)
 
-    def _subset(self, **kwargs):
+    def _subset(self, **kwargs: Any) -> "Dataset":
+        """Create a subset of the dataset based on the provided keyword arguments.
+
+        Parameters
+        ----------
+        **kwargs : Any
+            Keyword arguments for creating the subset.
 
+        Returns
+        -------
+        Dataset
+            The subset of the dataset.
+        """
         if not kwargs:
             return self.mutate()
 
@@ -79,10 +148,23 @@ class Dataset:
         return result
 
     @property
-    def name(self):
+    def name(self) -> Union[str, None]:
+        """Return the name of the dataset."""
         return self._name
 
-    def __subset(self, **kwargs):
+    def __subset(self, **kwargs: Any) -> "Dataset":
+        """Internal method to create a subset of the dataset based on the provided keyword arguments.
+
+        Parameters
+        ----------
+        **kwargs : Any
+            Keyword arguments for creating the subset.
+
+        Returns
+        -------
+        Dataset
+            The subset of the dataset.
+        """
         if not kwargs:
             return self.mutate()
 
@@ -213,8 +295,19 @@ class Dataset:
 
         raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs))
 
-    def _frequency_to_indices(self, frequency):
+    def _frequency_to_indices(self, frequency: str) -> list[int]:
+        """Convert a frequency string to a list of indices.
 
+        Parameters
+        ----------
+        frequency : str
+            The frequency string.
+
+        Returns
+        -------
+        list of int
+            The list of indices.
+        """
         requested_frequency = frequency_to_seconds(frequency)
         dataset_frequency = frequency_to_seconds(self.frequency)
         assert requested_frequency % dataset_frequency == 0
@@ -223,12 +316,35 @@ class Dataset:
 
         return range(0, len(self), step)
 
-    def _shuffle_indices(self):
-        import numpy as np
+    def _shuffle_indices(self) -> NDArray[Any]:
+        """Return a shuffled array of indices.
 
+        Returns
+        -------
+        numpy.ndarray
+            The shuffled array of indices.
+        """
         return np.random.permutation(len(self))
 
-    def _dates_to_indices(self, start, end):
+    def _dates_to_indices(
+        self,
+        start: Union[None, str, datetime.datetime],
+        end: Union[None, str, datetime.datetime],
+    ) -> List[int]:
+        """Convert date range to a list of indices.
+
+        Parameters
+        ----------
+        start : None, str, or datetime.datetime
+            The start date.
+        end : None, str, or datetime.datetime
+            The end date.
+
+        Returns
+        -------
+        list of int
+            The list of indices.
+        """
         from .misc import as_first_date
         from .misc import as_last_date
 
@@ -239,7 +355,19 @@ class Dataset:
 
         return [i for i, date in enumerate(self.dates) if start <= date <= end]
 
-    def _select_to_columns(self, vars):
+    def _select_to_columns(self, vars: Union[str, List[str], Tuple[str], set]) -> List[int]:
+        """Convert variable names to a list of column indices.
+
+        Parameters
+        ----------
+        vars : str, list of str, tuple of str, or set
+            The variable names.
+
+        Returns
+        -------
+        list of int
+            The list of column indices.
+        """
         if isinstance(vars, set):
             # We keep the order of the variables as they are in the zarr file
             nvars = [v for v in self.name_to_index if v in vars]
@@ -251,7 +379,19 @@ class Dataset:
 
         return [self.name_to_index[v] for v in vars]
 
-    def _drop_to_columns(self, vars):
+    def _drop_to_columns(self, vars: Union[str, Sequence[str]]) -> List[int]:
+        """Convert variable names to a list of column indices to drop.
+
+        Parameters
+        ----------
+        vars : str, list of str, tuple of str, or set
+            The variable names.
+
+        Returns
+        -------
+        list of int
+            The list of column indices to drop.
+        """
         if not isinstance(vars, (list, tuple, set)):
             vars = [vars]
 
@@ -260,7 +400,19 @@ class Dataset:
 
         return sorted([v for k, v in self.name_to_index.items() if k not in vars])
 
-    def _reorder_to_columns(self, vars):
+    def _reorder_to_columns(self, vars: Union[str, List[str], Tuple[str], Dict[str, int]]) -> List[int]:
+        """Convert variable names to a list of reordered column indices.
+
+        Parameters
+        ----------
+        vars : str, list of str, tuple of str, or dict of str to int
+            The variable names.
+
+        Returns
+        -------
+        list of int
+            The list of reordered column indices.
+        """
         if isinstance(vars, str) and vars == "sort":
             # Sorting the variables alphabetically.
             # This is cruical for pre-training then transfer learning in combination with
@@ -280,20 +432,55 @@ class Dataset:
 
         return indices
 
-    def dates_interval_to_indices(self, start, end):
+    def dates_interval_to_indices(
+        self, start: Union[None, str, datetime.datetime], end: Union[None, str, datetime.datetime]
+    ) -> List[int]:
+        """Convert date interval to a list of indices.
+
+        Parameters
+        ----------
+        start : None, str, or datetime.datetime
+            The start date.
+        end : None, str, or datetime.datetime
+            The end date.
+
+        Returns
+        -------
+        list of int
+            The list of indices.
+        """
         return self._dates_to_indices(start, end)
 
-    def provenance(self):
+    def provenance(self) -> Dict[str, Any]:
+        """Return the provenance information of the dataset.
+
+        Returns
+        -------
+        dict
+            The provenance information.
+        """
         return {}
 
-    def sub_shape(self, drop_axis):
-        shape = self.shape
-        shape = list(shape)
+    def sub_shape(self, drop_axis: int) -> TupleIndex:
+        """Return the shape of the dataset with one axis dropped.
+
+        Parameters
+        ----------
+        drop_axis : int
+            The axis to drop.
+
+        Returns
+        -------
+        tuple
+            The shape with one axis dropped.
+        """
+        shape = list(self.shape)
         shape.pop(drop_axis)
         return tuple(shape)
 
     @property
-    def typed_variables(self):
+    def typed_variables(self) -> Dict[str, Any]:
+        """Return the variables with their types."""
         from anemoi.transform.variables import Variable
 
         constants = self.constant_fields
@@ -313,12 +500,26 @@ class Dataset:
 
         return result
 
-    def _input_sources(self):
+    def _input_sources(self) -> List[Any]:
+        """Return the input sources of the dataset.
+
+        Returns
+        -------
+        list
+            The input sources.
+        """
         sources = []
         self.collect_input_sources(sources)
         return sources
 
-    def metadata(self):
+    def metadata(self) -> Dict[str, Any]:
+        """Return the metadata of the dataset.
+
+        Returns
+        -------
+        dict
+            The metadata.
+        """
         import anemoi
 
         _, source_to_arrays = self._supporting_arrays_and_sources()
@@ -346,14 +547,23 @@ class Dataset:
             raise
 
     @property
-    def start_date(self):
+    def start_date(self) -> np.datetime64:
+        """Return the start date of the dataset."""
         return self.dates[0]
 
     @property
-    def end_date(self):
+    def end_date(self) -> np.datetime64:
+        """Return the end date of the dataset."""
         return self.dates[-1]
 
-    def dataset_metadata(self):
+    def dataset_metadata(self) -> Dict[str, Any]:
+        """Return the metadata of the dataset.
+
+        Returns
+        -------
+        dict
+            The metadata.
+        """
         return dict(
             specific=self.metadata_specific(),
             frequency=self.frequency,
@@ -366,11 +576,21 @@ class Dataset:
             name=self.name,
         )
 
-    def _supporting_arrays(self, *path):
+    def _supporting_arrays(self, *path: str) -> Dict[str, NDArray[Any]]:
+        """Return the supporting arrays of the dataset.
 
-        import numpy as np
+        Parameters
+        ----------
+        *path : str
+            The path components.
 
-        def _path(path, name):
+        Returns
+        -------
+        dict
+            The supporting arrays.
+        """
+
+        def _path(path, name: str) -> str:
             return "/".join(str(_) for _ in [*path, name])
 
         result = {
@@ -394,13 +614,25 @@ class Dataset:
 
         return result
 
-    def supporting_arrays(self):
-        """Arrays to be saved in the checkpoints"""
+    def supporting_arrays(self) -> Dict[str, NDArray[Any]]:
+        """Return the supporting arrays to be saved in the checkpoints.
+
+        Returns
+        -------
+        dict
+            The supporting arrays.
+        """
        arrays, _ = self._supporting_arrays_and_sources()
         return arrays
 
-    def _supporting_arrays_and_sources(self):
+    def _supporting_arrays_and_sources(self) -> Tuple[Dict[str, NDArray], Dict[int, List[str]]]:
+        """Return the supporting arrays and their sources.
 
+        Returns
+        -------
+        tuple
+            The supporting arrays and their sources.
+        """
         source_to_arrays = {}
 
         # Top levels arrays
@@ -420,11 +652,32 @@ class Dataset:
 
         return result, source_to_arrays
 
-    def collect_supporting_arrays(self, collected, *path):
+    def collect_supporting_arrays(self, collected: List[Tuple[Tuple[str, ...], str, NDArray[Any]]], *path: str) -> None:
+        """Collect supporting arrays.
+
+        Parameters
+        ----------
+        collected : list of tuple
+            The collected supporting arrays.
+        *path : str
+            The path components.
+        """
         # Override this method to add more arrays
         pass
 
-    def metadata_specific(self, **kwargs):
+    def metadata_specific(self, **kwargs: Any) -> Dict[str, Any]:
+        """Return specific metadata of the dataset.
+
+        Parameters
+        ----------
+        **kwargs : Any
+            Additional keyword arguments.
+
+        Returns
+        -------
+        dict
+            The specific metadata.
+        """
         action = self.__class__.__name__.lower()
         # assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action)
         return dict(
@@ -437,33 +690,53 @@ class Dataset:
             **kwargs,
         )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
+        """Return the string representation of the dataset.
+
+        Returns
+        -------
+        str
+            The string representation.
+        """
         return self.__class__.__name__ + "()"
 
     @property
-    def grids(self):
+    def grids(self) -> TupleIndex:
+        """Return the grid shape of the dataset."""
         return (self.shape[-1],)
 
-    def _check(ds):
-        common = Dataset.__dict__.keys() & ds.__class__.__dict__.keys()
-        overriden = [m for m in common if Dataset.__dict__[m] is not ds.__class__.__dict__[m]]
+    def _check(self) -> None:
+        """Check for overridden private methods in the dataset."""
+        common = Dataset.__dict__.keys() & self.__class__.__dict__.keys()
+        overriden = [m for m in common if Dataset.__dict__[m] is not self.__class__.__dict__[m]]
 
         for n in overriden:
-            if n.startswith("_") and not n.startswith("__"):
-                warnings.warn(f"Private method {n} is overriden in {ds.__class__.__name__}")
+            if n.startswith("_") and not n.startswith("__") and n not in ("_abc_impl",):
+                warnings.warn(f"Private method {n} is overriden in {self.__class__.__name__}")
+
+    def _repr_html_(self) -> str:
+        """Return the HTML representation of the dataset.
 
-    def _repr_html_(self):
+        Returns
+        -------
+        str
+            The HTML representation.
+        """
         return self.tree().html()
 
     @property
-    def label(self):
+    def label(self) -> str:
+        """Return the label of the dataset."""
         return self.__class__.__name__.lower()
 
-    def get_dataset_names(self, names):
-        raise NotImplementedError(self)
+    def computed_constant_fields(self) -> List[str]:
+        """Return the computed constant fields of the dataset.
 
-    def computed_constant_fields(self):
-        # Call `constant_fields` instead of `computed_constant_fields`
+        Returns
+        -------
+        list of str
+            The computed constant fields.
+        """
         try:
             # If the tendencies are computed, we can use them
             return sorted(self._compute_constant_fields_from_statistics())
@@ -473,8 +746,14 @@ class Dataset:
 
         return sorted(self._compute_constant_fields_from_a_few_samples())
 
-    def _compute_constant_fields_from_a_few_samples(self):
+    def _compute_constant_fields_from_a_few_samples(self) -> List[str]:
+        """Compute constant fields from a few samples.
 
+        Returns
+        -------
+        list of str
+            The computed constant fields.
+        """
         import numpy as np
 
         # Otherwise, we need to compute them
@@ -508,7 +787,14 @@ class Dataset:
 
         return [v for i, v in enumerate(self.variables) if constants[i]]
 
-    def _compute_constant_fields_from_statistics(self):
+    def _compute_constant_fields_from_statistics(self) -> List[str]:
+        """Compute constant fields from statistics.
+
+        Returns
+        -------
+        list of str
+            The computed constant fields.
+        """
         result = []
 
         t = self.statistics_tendencies()
@@ -519,7 +805,13 @@ class Dataset:
 
         return result
 
-    def plot(self, date, variable, member=0, **kwargs):
+    def plot(
+        self,
+        date: Union[int, datetime.datetime, np.datetime64, str],
+        variable: Union[int, str],
+        member: int = 0,
+        **kwargs: Any,
+    ) -> "matplotlib.pyplot.Axes":
         """For debugging purposes, plot a field.
 
         Parameters
@@ -530,17 +822,42 @@ class Dataset:
             The variable to plot.
         member : int, optional
             The ensemble member to plot.
+        **kwargs : Any
+            Additional arguments to pass to matplotlib.pyplot.tricontourf.
+
+        Returns
+        -------
+        matplotlib.pyplot.Axes
+            The plot axes.
+        """
+        from anemoi.utils.devtools import plot_values
 
-        **kwargs:
-            Additional arguments to pass to matplotlib.pyplot.tricontourf
+        values = self[self.to_index(date, variable, member)]
+
+        return plot_values(values, self.latitudes, self.longitudes, **kwargs)
 
+    def to_index(
+        self,
+        date: Union[int, datetime.datetime, np.datetime64, str],
+        variable: Union[int, str],
+        member: int = 0,
+    ) -> Tuple[int, int, int]:
+        """Convert date, variable, and member to indices.
+
+        Parameters
+        ----------
+        date : int or datetime.datetime or numpy.datetime64 or str
+            The date.
+        variable : int or str
+            The variable.
+        member : int, optional
+            The ensemble member.
 
         Returns
        -------
-        matplotlib.pyplot.Axes
+        tuple of int
+            The indices.
         """
-
-        from anemoi.utils.devtools import plot_values
         from earthkit.data.utils.dates import to_datetime
 
         if not isinstance(date, int):
@@ -554,6 +871,8 @@ class Dataset:
         else:
             date_index = date
 
+        date_index = int(date_index)  # because np.int64 is not instance of int
+
         if isinstance(variable, int):
             variable_index = variable
         else:
@@ -562,6 +881,178 @@ class Dataset:
 
             variable_index = self.name_to_index[variable]
 
-        values = self[date_index, variable_index, member]
+        return (date_index, variable_index, member)
 
-        return plot_values(values, self.latitudes, self.longitudes, **kwargs)
+    @abstractmethod
+    def __getitem__(self, n: FullIndex) -> NDArray[Any]:
+        """Get the item at the specified index.
+
+        Parameters
+        ----------
+        n : FullIndex
+            Index to retrieve.
+
+        Returns
+        -------
+        NDArray[Any]
+            Retrieved item.
+        """
+
+    @abstractmethod
+    def __len__(self) -> int:
+        """Return the length of the dataset.
+
+        Returns
+        -------
+        int
+            The length of the dataset.
+        """
+
+    @property
+    @abstractmethod
+    def variables(self) -> List[str]:
+        """Return the list of variables in the dataset."""
+        pass
+
+    @property
+    @abstractmethod
+    def frequency(self) -> datetime.timedelta:
+        """Return the frequency of the dataset."""
+        pass
+
+    @property
+    @abstractmethod
+    def dates(self) -> NDArray[np.datetime64]:
+        """Return the dates in the dataset."""
+        pass
+
+    @property
+    @abstractmethod
+    def resolution(self) -> str:
+        """Return the resolution of the dataset."""
+        pass
+
+    @property
+    @abstractmethod
+    def name_to_index(self) -> Dict[str, int]:
+        """Return the mapping of variable names to indices."""
+        pass
+
+    @property
+    @abstractmethod
+    def shape(self) -> Shape:
+        """Return the shape of the dataset."""
+        pass
+
+    @property
+    @abstractmethod
+    def field_shape(self) -> Shape:
+        """Return the shape of the fields in the dataset."""
+        pass
+
+    @property
+    @abstractmethod
+    def dtype(self) -> np.dtype:
+        """Return the data type of the dataset."""
+        pass
+
+    @property
+    @abstractmethod
+    def latitudes(self) -> NDArray[Any]:
+        """Return the latitudes in the dataset."""
+        pass
+
+    @property
+    @abstractmethod
+    def longitudes(self) -> NDArray[Any]:
+        """Return the longitudes in the dataset."""
+        pass
+
+    @property
+    @abstractmethod
+    def variables_metadata(self) -> Dict[str, Any]:
+        """Return the metadata of the variables in the dataset."""
+        pass
+
+    @abstractmethod
+    @cached_property
+    def missing(self) -> Set[int]:
+        """Return the set of missing indices in the dataset."""
+        pass
+
+    @abstractmethod
+    @cached_property
+    def constant_fields(self) -> List[str]:
+        """Return the list of constant fields in the dataset."""
+        pass
+
+    @abstractmethod
+    @cached_property
+    def statistics(self) -> Dict[str, NDArray[Any]]:
+        """Return the statistics of the dataset."""
+        pass
+
+    @abstractmethod
+    def statistics_tendencies(self, delta: Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]:
+        """Return the tendencies of the statistics in the dataset.
+
+        Parameters
+        ----------
+        delta : datetime.timedelta, optional
+            The time delta for computing tendencies.
+
+        Returns
+        -------
+        dict
+            The tendencies.
+        """
+        pass
+
+    @abstractmethod
+    def source(self, index: int) -> Source:
+        """Return the source of the dataset at the specified index.
+
+        Parameters
+        ----------
+        index : int
+            The index.
+
+        Returns
+        -------
+        Source
+            The source.
+        """
+        pass
+
+    @abstractmethod
+    def tree(self) -> Node:
+        """Return the tree representation of the dataset.
+
+        Returns
+        -------
+        Node
+            The tree representation.
+        """
+        pass
+
+    @abstractmethod
+    def collect_input_sources(self, sources: List[Any]) -> None:
+        """Collect the input sources of the dataset.
+
+        Parameters
+        ----------
+        sources : list
+            The input sources.
+        """
+        pass
+
+    @abstractmethod
+    def get_dataset_names(self, names: Set[str]) -> None:
+        """Get the names of the datasets.
+
+        Parameters
+        ----------
+        names : set of str
+            The dataset names.
+        """
+        pass
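
The diff above (anemoi/datasets/data/dataset.py) makes Dataset an abstract base class with typed abstract properties, and plot() now delegates index resolution to the new public to_index() helper. A minimal usage sketch under stated assumptions: the dataset path and the variable name "2t" are placeholders, and only the to_index()/plot() signatures are taken from the diff.

# Minimal usage sketch of the new Dataset.to_index() helper in 0.5.17.
# The dataset path and the variable name "2t" are placeholders.
from anemoi.datasets import open_dataset

ds = open_dataset("/path/to/dataset.zarr")

# Resolve (date, variable, member) to an integer index tuple,
# then index the dataset directly with it.
idx = ds.to_index("2020-01-01", "2t")  # -> (date_index, variable_index, 0)
values = ds[idx]

# plot() now simply wraps to_index() and plot_values() internally.
ax = ds.plot("2020-01-01", "2t", member=0)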