anemoi-datasets 0.5.15__py3-none-any.whl → 0.5.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only, and reflects the changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. anemoi/datasets/__init__.py +4 -1
  2. anemoi/datasets/__main__.py +12 -2
  3. anemoi/datasets/_version.py +9 -4
  4. anemoi/datasets/commands/cleanup.py +17 -2
  5. anemoi/datasets/commands/compare.py +18 -2
  6. anemoi/datasets/commands/copy.py +196 -14
  7. anemoi/datasets/commands/create.py +50 -7
  8. anemoi/datasets/commands/finalise-additions.py +17 -2
  9. anemoi/datasets/commands/finalise.py +17 -2
  10. anemoi/datasets/commands/init-additions.py +17 -2
  11. anemoi/datasets/commands/init.py +16 -2
  12. anemoi/datasets/commands/inspect.py +283 -62
  13. anemoi/datasets/commands/load-additions.py +16 -2
  14. anemoi/datasets/commands/load.py +16 -2
  15. anemoi/datasets/commands/patch.py +17 -2
  16. anemoi/datasets/commands/publish.py +17 -2
  17. anemoi/datasets/commands/scan.py +31 -3
  18. anemoi/datasets/compute/recentre.py +47 -11
  19. anemoi/datasets/create/__init__.py +612 -85
  20. anemoi/datasets/create/check.py +142 -20
  21. anemoi/datasets/create/chunks.py +64 -4
  22. anemoi/datasets/create/config.py +185 -21
  23. anemoi/datasets/create/filter.py +50 -0
  24. anemoi/datasets/create/filters/__init__.py +33 -0
  25. anemoi/datasets/create/filters/empty.py +37 -0
  26. anemoi/datasets/create/filters/legacy.py +93 -0
  27. anemoi/datasets/create/filters/noop.py +37 -0
  28. anemoi/datasets/create/filters/orog_to_z.py +58 -0
  29. anemoi/datasets/create/{functions/filters → filters}/pressure_level_relative_humidity_to_specific_humidity.py +33 -10
  30. anemoi/datasets/create/{functions/filters → filters}/pressure_level_specific_humidity_to_relative_humidity.py +32 -8
  31. anemoi/datasets/create/filters/rename.py +205 -0
  32. anemoi/datasets/create/{functions/filters → filters}/rotate_winds.py +43 -28
  33. anemoi/datasets/create/{functions/filters → filters}/single_level_dewpoint_to_relative_humidity.py +32 -9
  34. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_dewpoint.py +33 -9
  35. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_specific_humidity.py +55 -7
  36. anemoi/datasets/create/{functions/filters → filters}/single_level_specific_humidity_to_relative_humidity.py +98 -37
  37. anemoi/datasets/create/filters/speeddir_to_uv.py +95 -0
  38. anemoi/datasets/create/{functions/filters → filters}/sum.py +24 -27
  39. anemoi/datasets/create/filters/transform.py +53 -0
  40. anemoi/datasets/create/{functions/filters → filters}/unrotate_winds.py +27 -18
  41. anemoi/datasets/create/filters/uv_to_speeddir.py +94 -0
  42. anemoi/datasets/create/{functions/filters → filters}/wz_to_w.py +51 -33
  43. anemoi/datasets/create/input/__init__.py +76 -5
  44. anemoi/datasets/create/input/action.py +149 -13
  45. anemoi/datasets/create/input/concat.py +81 -10
  46. anemoi/datasets/create/input/context.py +39 -4
  47. anemoi/datasets/create/input/data_sources.py +72 -6
  48. anemoi/datasets/create/input/empty.py +21 -3
  49. anemoi/datasets/create/input/filter.py +60 -12
  50. anemoi/datasets/create/input/function.py +154 -37
  51. anemoi/datasets/create/input/join.py +86 -14
  52. anemoi/datasets/create/input/misc.py +67 -17
  53. anemoi/datasets/create/input/pipe.py +33 -6
  54. anemoi/datasets/create/input/repeated_dates.py +189 -41
  55. anemoi/datasets/create/input/result.py +202 -87
  56. anemoi/datasets/create/input/step.py +119 -22
  57. anemoi/datasets/create/input/template.py +100 -13
  58. anemoi/datasets/create/input/trace.py +62 -7
  59. anemoi/datasets/create/patch.py +52 -4
  60. anemoi/datasets/create/persistent.py +134 -17
  61. anemoi/datasets/create/size.py +15 -1
  62. anemoi/datasets/create/source.py +51 -0
  63. anemoi/datasets/create/sources/__init__.py +36 -0
  64. anemoi/datasets/create/{functions/sources → sources}/accumulations.py +296 -30
  65. anemoi/datasets/create/{functions/sources → sources}/constants.py +27 -2
  66. anemoi/datasets/create/{functions/sources → sources}/eccc_fstd.py +7 -3
  67. anemoi/datasets/create/sources/empty.py +37 -0
  68. anemoi/datasets/create/{functions/sources → sources}/forcings.py +25 -1
  69. anemoi/datasets/create/sources/grib.py +297 -0
  70. anemoi/datasets/create/{functions/sources → sources}/hindcasts.py +38 -4
  71. anemoi/datasets/create/sources/legacy.py +93 -0
  72. anemoi/datasets/create/{functions/sources → sources}/mars.py +168 -20
  73. anemoi/datasets/create/sources/netcdf.py +42 -0
  74. anemoi/datasets/create/sources/opendap.py +43 -0
  75. anemoi/datasets/create/{functions/sources/__init__.py → sources/patterns.py} +35 -4
  76. anemoi/datasets/create/sources/recentre.py +150 -0
  77. anemoi/datasets/create/{functions/sources → sources}/source.py +27 -5
  78. anemoi/datasets/create/{functions/sources → sources}/tendencies.py +64 -7
  79. anemoi/datasets/create/sources/xarray.py +92 -0
  80. anemoi/datasets/create/sources/xarray_kerchunk.py +36 -0
  81. anemoi/datasets/create/sources/xarray_support/README.md +1 -0
  82. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/__init__.py +109 -8
  83. anemoi/datasets/create/sources/xarray_support/coordinates.py +442 -0
  84. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/field.py +94 -16
  85. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/fieldlist.py +90 -25
  86. anemoi/datasets/create/sources/xarray_support/flavour.py +1036 -0
  87. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/grid.py +92 -31
  88. anemoi/datasets/create/sources/xarray_support/metadata.py +395 -0
  89. anemoi/datasets/create/sources/xarray_support/patch.py +91 -0
  90. anemoi/datasets/create/sources/xarray_support/time.py +391 -0
  91. anemoi/datasets/create/sources/xarray_support/variable.py +331 -0
  92. anemoi/datasets/create/sources/xarray_zarr.py +41 -0
  93. anemoi/datasets/create/{functions/sources → sources}/zenodo.py +34 -5
  94. anemoi/datasets/create/statistics/__init__.py +233 -44
  95. anemoi/datasets/create/statistics/summary.py +52 -6
  96. anemoi/datasets/create/testing.py +76 -0
  97. anemoi/datasets/create/{functions/filters/noop.py → typing.py} +6 -3
  98. anemoi/datasets/create/utils.py +97 -6
  99. anemoi/datasets/create/writer.py +26 -4
  100. anemoi/datasets/create/zarr.py +170 -23
  101. anemoi/datasets/data/__init__.py +51 -4
  102. anemoi/datasets/data/complement.py +191 -40
  103. anemoi/datasets/data/concat.py +141 -16
  104. anemoi/datasets/data/dataset.py +552 -61
  105. anemoi/datasets/data/debug.py +197 -26
  106. anemoi/datasets/data/ensemble.py +93 -8
  107. anemoi/datasets/data/fill_missing.py +165 -18
  108. anemoi/datasets/data/forwards.py +428 -56
  109. anemoi/datasets/data/grids.py +323 -97
  110. anemoi/datasets/data/indexing.py +112 -19
  111. anemoi/datasets/data/interpolate.py +92 -12
  112. anemoi/datasets/data/join.py +158 -19
  113. anemoi/datasets/data/masked.py +129 -15
  114. anemoi/datasets/data/merge.py +137 -23
  115. anemoi/datasets/data/misc.py +172 -16
  116. anemoi/datasets/data/missing.py +233 -29
  117. anemoi/datasets/data/rescale.py +111 -10
  118. anemoi/datasets/data/select.py +168 -26
  119. anemoi/datasets/data/statistics.py +67 -6
  120. anemoi/datasets/data/stores.py +149 -64
  121. anemoi/datasets/data/subset.py +159 -25
  122. anemoi/datasets/data/unchecked.py +168 -57
  123. anemoi/datasets/data/xy.py +168 -25
  124. anemoi/datasets/dates/__init__.py +191 -16
  125. anemoi/datasets/dates/groups.py +189 -47
  126. anemoi/datasets/grids.py +270 -31
  127. anemoi/datasets/testing.py +28 -1
  128. {anemoi_datasets-0.5.15.dist-info → anemoi_datasets-0.5.17.dist-info}/METADATA +10 -7
  129. anemoi_datasets-0.5.17.dist-info/RECORD +137 -0
  130. {anemoi_datasets-0.5.15.dist-info → anemoi_datasets-0.5.17.dist-info}/WHEEL +1 -1
  131. {anemoi_datasets-0.5.15.dist-info → anemoi_datasets-0.5.17.dist-info/licenses}/LICENSE +1 -1
  132. anemoi/datasets/create/functions/__init__.py +0 -66
  133. anemoi/datasets/create/functions/filters/__init__.py +0 -9
  134. anemoi/datasets/create/functions/filters/empty.py +0 -17
  135. anemoi/datasets/create/functions/filters/orog_to_z.py +0 -58
  136. anemoi/datasets/create/functions/filters/rename.py +0 -79
  137. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +0 -78
  138. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +0 -56
  139. anemoi/datasets/create/functions/sources/empty.py +0 -15
  140. anemoi/datasets/create/functions/sources/grib.py +0 -150
  141. anemoi/datasets/create/functions/sources/netcdf.py +0 -15
  142. anemoi/datasets/create/functions/sources/opendap.py +0 -15
  143. anemoi/datasets/create/functions/sources/recentre.py +0 -60
  144. anemoi/datasets/create/functions/sources/xarray/coordinates.py +0 -255
  145. anemoi/datasets/create/functions/sources/xarray/flavour.py +0 -472
  146. anemoi/datasets/create/functions/sources/xarray/metadata.py +0 -148
  147. anemoi/datasets/create/functions/sources/xarray/patch.py +0 -44
  148. anemoi/datasets/create/functions/sources/xarray/time.py +0 -177
  149. anemoi/datasets/create/functions/sources/xarray/variable.py +0 -188
  150. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +0 -42
  151. anemoi/datasets/create/functions/sources/xarray_zarr.py +0 -15
  152. anemoi/datasets/utils/fields.py +0 -47
  153. anemoi_datasets-0.5.15.dist-info/RECORD +0 -129
  154. {anemoi_datasets-0.5.15.dist-info → anemoi_datasets-0.5.17.dist-info}/entry_points.txt +0 -0
  155. {anemoi_datasets-0.5.15.dist-info → anemoi_datasets-0.5.17.dist-info}/top_level.txt +0 -0
@@ -9,13 +9,30 @@
9
9
 
10
10
 
11
11
  from functools import wraps
12
+ from typing import Any
13
+ from typing import Callable
14
+ from typing import List
15
+ from typing import Tuple
16
+ from typing import Union
12
17
 
13
18
  import numpy as np
19
+ from numpy.typing import NDArray
14
20
 
21
+ from .dataset import FullIndex
22
+ from .dataset import Shape
23
+ from .dataset import TupleIndex
15
24
 
16
- def _tuple_with_slices(t, shape):
17
- """Replace all integers in a tuple with slices, so we preserve the dimensionality."""
18
25
 
26
+ def _tuple_with_slices(t: TupleIndex, shape: Shape) -> Tuple[TupleIndex, Tuple[int, ...]]:
27
+ """Replace all integers in a tuple with slices, so we preserve the dimensionality.
28
+
29
+ Parameters:
30
+ t (TupleIndex): The tuple index to process.
31
+ shape (Shape): The shape of the array.
32
+
33
+ Returns:
34
+ Tuple[TupleIndex, Tuple[int, ...]]: A tuple containing the modified index and the changes.
35
+ """
19
36
  result = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in t)
20
37
  changes = tuple(j for (j, i) in enumerate(t) if isinstance(i, int))
21
38
  result = tuple(slice(*s.indices(shape[i])) for (i, s) in enumerate(result))
@@ -23,7 +40,16 @@ def _tuple_with_slices(t, shape):
23
40
  return result, changes
24
41
 
25
42
 
26
- def _extend_shape(index, shape):
43
+ def _extend_shape(index: TupleIndex, shape: Shape) -> TupleIndex:
44
+ """Extend the shape of the index to match the shape of the array.
45
+
46
+ Parameters:
47
+ index (TupleIndex): The index to extend.
48
+ shape (Shape): The shape of the array.
49
+
50
+ Returns:
51
+ TupleIndex: The extended index.
52
+ """
27
53
  if Ellipsis in index:
28
54
  if index.count(Ellipsis) > 1:
29
55
  raise IndexError("Only one Ellipsis is allowed")
@@ -40,7 +66,16 @@ def _extend_shape(index, shape):
40
66
  return index
41
67
 
42
68
 
43
- def _index_to_tuple(index, shape):
69
+ def _index_to_tuple(index: FullIndex, shape: Shape) -> TupleIndex:
70
+ """Convert an index to a tuple index.
71
+
72
+ Parameters:
73
+ index (FullIndex): The index to convert.
74
+ shape (Shape): The shape of the array.
75
+
76
+ Returns:
77
+ TupleIndex: The converted tuple index.
78
+ """
44
79
  if isinstance(index, int):
45
80
  return _extend_shape((index,), shape)
46
81
  if isinstance(index, slice):
@@ -52,12 +87,29 @@ def _index_to_tuple(index, shape):
52
87
  raise ValueError(f"Invalid index: {index}")
53
88
 
54
89
 
55
- def index_to_slices(index, shape):
56
- """Convert an index to a tuple of slices, with the same dimensionality as the shape."""
90
+ def index_to_slices(index: Union[int, slice, Tuple], shape: Shape) -> Tuple[TupleIndex, Tuple[int, ...]]:
91
+ """Convert an index to a tuple of slices, with the same dimensionality as the shape.
92
+
93
+ Parameters:
94
+ index (Union[int, slice, Tuple]): The index to convert.
95
+ shape (Shape): The shape of the array.
96
+
97
+ Returns:
98
+ Tuple[TupleIndex, Tuple[int, ...]]: A tuple containing the slices and the changes.
99
+ """
57
100
  return _tuple_with_slices(_index_to_tuple(index, shape), shape)
58
101
 
59
102
 
60
- def apply_index_to_slices_changes(result, changes):
103
+ def apply_index_to_slices_changes(result: NDArray[Any], changes: Tuple[int, ...]) -> NDArray[Any]:
104
+ """Apply changes to the result array based on the slices.
105
+
106
+ Parameters:
107
+ result (NDArray[Any]): The result array.
108
+ changes (Tuple[int, ...]): The changes to apply.
109
+
110
+ Returns:
111
+ NDArray[Any]: The modified result array.
112
+ """
61
113
  if changes:
62
114
  shape = result.shape
63
115
  for i in changes:
@@ -66,16 +118,33 @@ def apply_index_to_slices_changes(result, changes):
66
118
  return result
67
119
 
68
120
 
69
- def update_tuple(t, index, value):
70
- """Replace the elements of a tuple at the given index with a new value."""
121
+ def update_tuple(t: Tuple, index: int, value: Any) -> Tuple[Tuple, Any]:
122
+ """Replace the elements of a tuple at the given index with a new value.
123
+
124
+ Parameters:
125
+ tp (Tuple): The original tuple.
126
+ index (int): The index to update.
127
+ value (Any): The new value.
128
+
129
+ Returns:
130
+ Tuple[Tuple, Any]: The updated tuple and the previous value.
131
+ """
71
132
  t = list(t)
72
133
  prev = t[index]
73
134
  t[index] = value
74
135
  return tuple(t), prev
75
136
 
76
137
 
77
- def length_to_slices(index, lengths):
78
- """Convert an index to a list of slices, given the lengths of the dimensions."""
138
+ def length_to_slices(index: slice, lengths: List[int]) -> List[Union[slice, None]]:
139
+ """Convert an index to a list of slices, given the lengths of the dimensions.
140
+
141
+ Parameters:
142
+ index (slice): The index to convert.
143
+ lengths (List[int]): The lengths of the dimensions.
144
+
145
+ Returns:
146
+ List[Union[slice, None]]: A list of slices.
147
+ """
79
148
  total = sum(lengths)
80
149
  start, stop, step = index.indices(total)
81
150
 
@@ -105,8 +174,17 @@ def length_to_slices(index, lengths):
105
174
  return result
106
175
 
107
176
 
108
- def _as_tuples(index):
109
- def _(i):
177
+ def _as_tuples(index: Tuple) -> Tuple:
178
+ """Convert elements of the index to tuples if they are lists or arrays.
179
+
180
+ Parameters:
181
+ index (Tuple): The index to convert.
182
+
183
+ Returns:
184
+ Tuple: The converted index.
185
+ """
186
+
187
+ def _(i: Any) -> Any:
110
188
  if hasattr(i, "tolist"):
111
189
  # NumPy arrays, TensorFlow tensors, etc.
112
190
  i = i.tolist()
@@ -121,18 +199,27 @@ def _as_tuples(index):
121
199
  return tuple(_(i) for i in index)
122
200
 
123
201
 
124
- def expand_list_indexing(method):
125
- """Allows to use slices, lists, and tuples to select data from the dataset. Zarr does not support indexing with lists/arrays directly, so we need to implement it ourselves."""
202
+ def expand_list_indexing(method: Callable[..., NDArray[Any]]) -> Callable[..., NDArray[Any]]:
203
+ """Allows to use slices, lists, and tuples to select data from the dataset.
204
+ Zarr does not support indexing with lists/arrays directly,
205
+ so we need to implement it ourselves.
206
+
207
+ Parameters:
208
+ method (Callable[..., NDArray[Any]]): The method to wrap.
209
+
210
+ Returns:
211
+ Callable[..., NDArray[Any]]: The wrapped method.
212
+ """
126
213
 
127
214
  @wraps(method)
128
- def wrapper(self, index):
215
+ def wrapper(self: Any, index: FullIndex) -> NDArray[Any]:
129
216
  if not isinstance(index, tuple):
130
217
  return method(self, index)
131
218
 
132
219
  if not any(isinstance(i, (list, tuple)) for i in index):
133
220
  return method(self, index)
134
221
 
135
- which = []
222
+ which: List[int] = []
136
223
  for i, idx in enumerate(index):
137
224
  if isinstance(idx, (list, tuple)):
138
225
  which.append(i)
@@ -154,9 +241,15 @@ def expand_list_indexing(method):
154
241
  return wrapper
155
242
 
156
243
 
157
- def make_slice_or_index_from_list_or_tuple(indices):
158
- """Convert a list or tuple of indices to a slice or an index, if possible."""
244
+ def make_slice_or_index_from_list_or_tuple(indices: List[int]) -> Union[List[int], slice]:
245
+ """Convert a list or tuple of indices to a slice or an index, if possible.
246
+
247
+ Parameters:
248
+ indices (List[int]): The list or tuple of indices.
159
249
 
250
+ Returns:
251
+ Union[List[int], slice]: The slice or index.
252
+ """
160
253
  if len(indices) < 2:
161
254
  return indices
162
255
 
@@ -8,12 +8,21 @@
8
8
  # nor does it submit to any jurisdiction.
9
9
 
10
10
 
11
+ import datetime
11
12
  import logging
12
13
  from functools import cached_property
14
+ from typing import Any
15
+ from typing import Dict
16
+ from typing import Set
13
17
 
14
18
  import numpy as np
15
19
  from anemoi.utils.dates import frequency_to_timedelta
20
+ from numpy.typing import NDArray
16
21
 
22
+ from .dataset import Dataset
23
+ from .dataset import FullIndex
24
+ from .dataset import Shape
25
+ from .dataset import TupleIndex
17
26
  from .debug import Node
18
27
  from .debug import debug_indexing
19
28
  from .forwards import Forwards
@@ -26,8 +35,18 @@ LOG = logging.getLogger(__name__)
26
35
 
27
36
 
28
37
  class InterpolateFrequency(Forwards):
29
-
30
- def __init__(self, dataset, frequency):
38
+ """A class to represent a dataset with interpolated frequency."""
39
+
40
+ def __init__(self, dataset: Dataset, frequency: str) -> None:
41
+ """Initialize the InterpolateFrequency class.
42
+
43
+ Parameters
44
+ ----------
45
+ dataset : Dataset
46
+ The dataset to be interpolated.
47
+ frequency : str
48
+ The interpolation frequency.
49
+ """
31
50
  super().__init__(dataset)
32
51
  self._frequency = frequency_to_timedelta(frequency)
33
52
 
@@ -56,17 +75,53 @@ class InterpolateFrequency(Forwards):
56
75
 
57
76
  @debug_indexing
58
77
  @expand_list_indexing
59
- def _get_tuple(self, index):
78
+ def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
79
+ """Get the interpolated data for a tuple index.
80
+
81
+ Parameters
82
+ ----------
83
+ index : TupleIndex
84
+ The tuple index to retrieve data from.
85
+
86
+ Returns
87
+ -------
88
+ NDArray[Any]
89
+ The interpolated data for the tuple index.
90
+ """
60
91
  index, changes = index_to_slices(index, self.shape)
61
92
  index, previous = update_tuple(index, 0, slice(None))
62
93
  result = self._get_slice(previous)
63
94
  return apply_index_to_slices_changes(result[index], changes)
64
95
 
65
- def _get_slice(self, s):
96
+ def _get_slice(self, s: slice) -> NDArray[Any]:
97
+ """Get the interpolated data for a slice.
98
+
99
+ Parameters
100
+ ----------
101
+ s : slice
102
+ The slice to retrieve data from.
103
+
104
+ Returns
105
+ -------
106
+ NDArray[Any]
107
+ The interpolated data for the slice.
108
+ """
66
109
  return np.stack([self[i] for i in range(*s.indices(self._len))])
67
110
 
68
111
  @debug_indexing
69
- def __getitem__(self, n):
112
+ def __getitem__(self, n: FullIndex) -> NDArray[Any]:
113
+ """Get the interpolated data at the specified index.
114
+
115
+ Parameters
116
+ ----------
117
+ n : FullIndex
118
+ The index to retrieve data from.
119
+
120
+ Returns
121
+ -------
122
+ NDArray[Any]
123
+ The interpolated data at the specified index.
124
+ """
70
125
  if isinstance(n, tuple):
71
126
  return self._get_tuple(n)
72
127
 
@@ -92,15 +147,24 @@ class InterpolateFrequency(Forwards):
92
147
  assert 0 < alpha < 1, alpha
93
148
  return self.forward[i] * (1 - alpha) + self.forward[i + 1] * alpha
94
149
 
95
- def __len__(self):
150
+ def __len__(self) -> int:
151
+ """Get the length of the interpolated dataset.
152
+
153
+ Returns
154
+ -------
155
+ int
156
+ The length of the interpolated dataset.
157
+ """
96
158
  return (self.other_len - 1) * self.ratio + 1
97
159
 
98
160
  @property
99
- def frequency(self):
161
+ def frequency(self) -> datetime.timedelta:
162
+ """Get the interpolation frequency."""
100
163
  return self._frequency
101
164
 
102
165
  @cached_property
103
- def dates(self):
166
+ def dates(self) -> NDArray[np.datetime64]:
167
+ """Get the interpolated dates."""
104
168
  result = []
105
169
  deltas = [np.timedelta64(self.seconds * i, "s") for i in range(self.ratio)]
106
170
  for d in self.forward.dates[:-1]:
@@ -110,14 +174,23 @@ class InterpolateFrequency(Forwards):
110
174
  return np.array(result)
111
175
 
112
176
  @property
113
- def shape(self):
177
+ def shape(self) -> Shape:
178
+ """Get the shape of the interpolated dataset."""
114
179
  return (self._len,) + self.forward.shape[1:]
115
180
 
116
- def tree(self):
181
+ def tree(self) -> Node:
182
+ """Get the tree representation of the dataset.
183
+
184
+ Returns
185
+ -------
186
+ Node
187
+ The tree representation of the dataset.
188
+ """
117
189
  return Node(self, [self.forward.tree()], frequency=self.frequency)
118
190
 
119
191
  @cached_property
120
- def missing(self):
192
+ def missing(self) -> Set[int]:
193
+ """Get the missing data indices."""
121
194
  result = []
122
195
  j = 0
123
196
  for i in range(self.other_len):
@@ -130,7 +203,14 @@ class InterpolateFrequency(Forwards):
130
203
  result = set(x for x in result if x < self._len)
131
204
  return result
132
205
 
133
- def subclass_metadata_specific(self):
206
+ def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
207
+ """Get the metadata specific to the InterpolateFrequency subclass.
208
+
209
+ Returns
210
+ -------
211
+ Dict[str, Any]
212
+ The metadata specific to the InterpolateFrequency subclass.
213
+ """
134
214
  return {
135
215
  # "frequency": frequency_to_string(self._frequency),
136
216
  }
@@ -8,11 +8,22 @@
8
8
  # nor does it submit to any jurisdiction.
9
9
 
10
10
 
11
+ import datetime
11
12
  import logging
12
13
  from functools import cached_property
14
+ from typing import Any
15
+ from typing import Dict
16
+ from typing import List
17
+ from typing import Optional
18
+ from typing import Set
13
19
 
14
20
  import numpy as np
21
+ from numpy.typing import NDArray
15
22
 
23
+ from .dataset import Dataset
24
+ from .dataset import FullIndex
25
+ from .dataset import Shape
26
+ from .dataset import TupleIndex
16
27
  from .debug import Node
17
28
  from .debug import Source
18
29
  from .debug import debug_indexing
@@ -30,20 +41,57 @@ LOG = logging.getLogger(__name__)
30
41
  class Join(Combined):
31
42
  """Join the datasets along the variables axis."""
32
43
 
33
- def check_compatibility(self, d1, d2):
44
+ def check_compatibility(self, d1: Dataset, d2: Dataset) -> None:
45
+ """Check the compatibility of two datasets.
46
+
47
+ Parameters
48
+ ----------
49
+ d1 : Dataset
50
+ The first dataset.
51
+ d2 : Dataset
52
+ The second dataset.
53
+ """
34
54
  super().check_compatibility(d1, d2)
35
55
  self.check_same_sub_shapes(d1, d2, drop_axis=1)
36
56
 
37
- def check_same_variables(self, d1, d2):
57
+ def check_same_variables(self, d1: Dataset, d2: Dataset) -> None:
58
+ """Check if the datasets have the same variables.
59
+
60
+ Parameters
61
+ ----------
62
+ d1 : Dataset
63
+ The first dataset.
64
+ d2 : Dataset
65
+ The second dataset.
66
+ """
38
67
  # Turned off because we are joining along the variables axis
39
68
  pass
40
69
 
41
- def __len__(self):
70
+ def __len__(self) -> int:
71
+ """Get the length of the joined dataset.
72
+
73
+ Returns
74
+ -------
75
+ int
76
+ The length of the joined dataset.
77
+ """
42
78
  return len(self.datasets[0])
43
79
 
44
80
  @debug_indexing
45
81
  @expand_list_indexing
46
- def _get_tuple(self, index):
82
+ def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
83
+ """Get the data for a tuple index.
84
+
85
+ Parameters
86
+ ----------
87
+ index : TupleIndex
88
+ The tuple index to retrieve data from.
89
+
90
+ Returns
91
+ -------
92
+ NDArray[Any]
93
+ The data for the tuple index.
94
+ """
47
95
  index, changes = index_to_slices(index, self.shape)
48
96
  index, previous = update_tuple(index, 1, slice(None))
49
97
 
@@ -54,11 +102,35 @@ class Join(Combined):
54
102
  return apply_index_to_slices_changes(result[:, previous], changes)
55
103
 
56
104
  @debug_indexing
57
- def _get_slice(self, s):
105
+ def _get_slice(self, s: slice) -> NDArray[Any]:
106
+ """Get the data for a slice.
107
+
108
+ Parameters
109
+ ----------
110
+ s : slice
111
+ The slice to retrieve data from.
112
+
113
+ Returns
114
+ -------
115
+ NDArray[Any]
116
+ The data for the slice.
117
+ """
58
118
  return np.stack([self[i] for i in range(*s.indices(self._len))])
59
119
 
60
120
  @debug_indexing
61
- def __getitem__(self, n):
121
+ def __getitem__(self, n: FullIndex) -> NDArray[Any]:
122
+ """Get the data at the specified index.
123
+
124
+ Parameters
125
+ ----------
126
+ n : FullIndex
127
+ The index to retrieve data from.
128
+
129
+ Returns
130
+ -------
131
+ NDArray[Any]
132
+ The data at the specified index.
133
+ """
62
134
  if isinstance(n, tuple):
63
135
  return self._get_tuple(n)
64
136
 
@@ -68,11 +140,19 @@ class Join(Combined):
68
140
  return np.concatenate([d[n] for d in self.datasets])
69
141
 
70
142
  @cached_property
71
- def shape(self):
143
+ def shape(self) -> Shape:
144
+ """Get the shape of the joined dataset."""
72
145
  cols = sum(d.shape[1] for d in self.datasets)
73
146
  return (len(self), cols) + self.datasets[0].shape[2:]
74
147
 
75
- def _overlay(self):
148
+ def _overlay(self) -> Dataset:
149
+ """Overlay the datasets.
150
+
151
+ Returns
152
+ -------
153
+ Dataset
154
+ The overlaid dataset.
155
+ """
76
156
  indices = {}
77
157
  i = 0
78
158
  for d in self.datasets:
@@ -102,9 +182,10 @@ class Join(Combined):
102
182
  return Select(self, indices, {"overlay": variables})
103
183
 
104
184
  @cached_property
105
- def variables(self):
185
+ def variables(self) -> List[str]:
186
+ """Get the variables of the joined dataset."""
106
187
  seen = set()
107
- result = []
188
+ result: List[str] = []
108
189
  for d in reversed(self.datasets):
109
190
  for v in reversed(d.variables):
110
191
  while v in seen:
@@ -115,7 +196,8 @@ class Join(Combined):
115
196
  return result
116
197
 
117
198
  @property
118
- def variables_metadata(self):
199
+ def variables_metadata(self) -> Dict[str, Any]:
200
+ """Get the metadata of the variables."""
119
201
  result = {}
120
202
  variables = [v for v in self.variables if not (v.startswith("(") and v.endswith(")"))]
121
203
 
@@ -134,16 +216,30 @@ class Join(Combined):
134
216
  return result
135
217
 
136
218
  @cached_property
137
- def name_to_index(self):
219
+ def name_to_index(self) -> Dict[str, int]:
220
+ """Get the mapping of variable names to indices."""
138
221
  return {k: i for i, k in enumerate(self.variables)}
139
222
 
140
223
  @property
141
- def statistics(self):
224
+ def statistics(self) -> Dict[str, NDArray[Any]]:
225
+ """Get the statistics of the joined dataset."""
142
226
  return {
143
227
  k: np.concatenate([d.statistics[k] for d in self.datasets], axis=0) for k in self.datasets[0].statistics
144
228
  }
145
229
 
146
- def statistics_tendencies(self, delta=None):
230
+ def statistics_tendencies(self, delta: Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]:
231
+ """Get the statistics tendencies of the joined dataset.
232
+
233
+ Parameters
234
+ ----------
235
+ delta : Optional[datetime.timedelta]
236
+ The time delta for the tendencies.
237
+
238
+ Returns
239
+ -------
240
+ Dict[str, NDArray[Any]]
241
+ The statistics tendencies of the joined dataset.
242
+ """
147
243
  if delta is None:
148
244
  delta = self.frequency
149
245
  return {
@@ -151,7 +247,19 @@ class Join(Combined):
151
247
  for k in self.datasets[0].statistics_tendencies(delta)
152
248
  }
153
249
 
154
- def source(self, index):
250
+ def source(self, index: int) -> Source:
251
+ """Get the source of the data at the specified index.
252
+
253
+ Parameters
254
+ ----------
255
+ index : int
256
+ The index to retrieve the source from.
257
+
258
+ Returns
259
+ -------
260
+ Source
261
+ The source of the data at the specified index.
262
+ """
155
263
  i = index
156
264
  for dataset in self.datasets:
157
265
  if i < dataset.shape[1]:
@@ -160,18 +268,49 @@ class Join(Combined):
160
268
  assert False
161
269
 
162
270
  @cached_property
163
- def missing(self):
164
- result = set()
271
+ def missing(self) -> Set[int]:
272
+ """Get the missing data indices."""
273
+ result: Set[int] = set()
165
274
  for d in self.datasets:
166
275
  result = result | d.missing
167
276
  return result
168
277
 
169
- def tree(self):
278
+ def tree(self) -> Node:
279
+ """Get the tree representation of the dataset.
280
+
281
+ Returns
282
+ -------
283
+ Node
284
+ The tree representation of the dataset.
285
+ """
170
286
  return Node(self, [d.tree() for d in self.datasets])
171
287
 
288
+ def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
289
+ """Get the metadata specific to the forwards subclass.
290
+
291
+ Returns
292
+ -------
293
+ dict[str, Any]
294
+ The metadata specific to the forwards subclass.
295
+ """
296
+ return {}
297
+
298
+
299
+ def join_factory(args: tuple, kwargs: dict) -> Dataset:
300
+ """Create a joined dataset.
172
301
 
173
- def join_factory(args, kwargs):
302
+ Parameters
303
+ ----------
304
+ args : tuple
305
+ The positional arguments.
306
+ kwargs : dict
307
+ The keyword arguments.
174
308
 
309
+ Returns
310
+ -------
311
+ Dataset
312
+ The joined dataset.
313
+ """
175
314
  datasets = kwargs.pop("join")
176
315
  assert isinstance(datasets, (list, tuple))
177
316
  assert len(args) == 0