anemoi-datasets 0.5.16__py3-none-any.whl → 0.5.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. anemoi/datasets/__init__.py +4 -1
  2. anemoi/datasets/__main__.py +12 -2
  3. anemoi/datasets/_version.py +9 -4
  4. anemoi/datasets/commands/cleanup.py +17 -2
  5. anemoi/datasets/commands/compare.py +18 -2
  6. anemoi/datasets/commands/copy.py +196 -14
  7. anemoi/datasets/commands/create.py +50 -7
  8. anemoi/datasets/commands/finalise-additions.py +17 -2
  9. anemoi/datasets/commands/finalise.py +17 -2
  10. anemoi/datasets/commands/init-additions.py +17 -2
  11. anemoi/datasets/commands/init.py +16 -2
  12. anemoi/datasets/commands/inspect.py +283 -62
  13. anemoi/datasets/commands/load-additions.py +16 -2
  14. anemoi/datasets/commands/load.py +16 -2
  15. anemoi/datasets/commands/patch.py +17 -2
  16. anemoi/datasets/commands/publish.py +17 -2
  17. anemoi/datasets/commands/scan.py +31 -3
  18. anemoi/datasets/compute/recentre.py +47 -11
  19. anemoi/datasets/create/__init__.py +612 -85
  20. anemoi/datasets/create/check.py +142 -20
  21. anemoi/datasets/create/chunks.py +64 -4
  22. anemoi/datasets/create/config.py +185 -21
  23. anemoi/datasets/create/filter.py +50 -0
  24. anemoi/datasets/create/filters/__init__.py +33 -0
  25. anemoi/datasets/create/filters/empty.py +37 -0
  26. anemoi/datasets/create/filters/legacy.py +93 -0
  27. anemoi/datasets/create/filters/noop.py +37 -0
  28. anemoi/datasets/create/filters/orog_to_z.py +58 -0
  29. anemoi/datasets/create/{functions/filters → filters}/pressure_level_relative_humidity_to_specific_humidity.py +33 -10
  30. anemoi/datasets/create/{functions/filters → filters}/pressure_level_specific_humidity_to_relative_humidity.py +32 -8
  31. anemoi/datasets/create/filters/rename.py +205 -0
  32. anemoi/datasets/create/{functions/filters → filters}/rotate_winds.py +43 -28
  33. anemoi/datasets/create/{functions/filters → filters}/single_level_dewpoint_to_relative_humidity.py +32 -9
  34. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_dewpoint.py +33 -9
  35. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_specific_humidity.py +55 -7
  36. anemoi/datasets/create/{functions/filters → filters}/single_level_specific_humidity_to_relative_humidity.py +98 -37
  37. anemoi/datasets/create/filters/speeddir_to_uv.py +95 -0
  38. anemoi/datasets/create/{functions/filters → filters}/sum.py +24 -27
  39. anemoi/datasets/create/filters/transform.py +53 -0
  40. anemoi/datasets/create/{functions/filters → filters}/unrotate_winds.py +27 -18
  41. anemoi/datasets/create/filters/uv_to_speeddir.py +94 -0
  42. anemoi/datasets/create/{functions/filters → filters}/wz_to_w.py +51 -33
  43. anemoi/datasets/create/input/__init__.py +76 -5
  44. anemoi/datasets/create/input/action.py +149 -13
  45. anemoi/datasets/create/input/concat.py +81 -10
  46. anemoi/datasets/create/input/context.py +39 -4
  47. anemoi/datasets/create/input/data_sources.py +72 -6
  48. anemoi/datasets/create/input/empty.py +21 -3
  49. anemoi/datasets/create/input/filter.py +60 -12
  50. anemoi/datasets/create/input/function.py +154 -37
  51. anemoi/datasets/create/input/join.py +86 -14
  52. anemoi/datasets/create/input/misc.py +67 -17
  53. anemoi/datasets/create/input/pipe.py +33 -6
  54. anemoi/datasets/create/input/repeated_dates.py +189 -41
  55. anemoi/datasets/create/input/result.py +202 -87
  56. anemoi/datasets/create/input/step.py +119 -22
  57. anemoi/datasets/create/input/template.py +100 -13
  58. anemoi/datasets/create/input/trace.py +62 -7
  59. anemoi/datasets/create/patch.py +52 -4
  60. anemoi/datasets/create/persistent.py +134 -17
  61. anemoi/datasets/create/size.py +15 -1
  62. anemoi/datasets/create/source.py +51 -0
  63. anemoi/datasets/create/sources/__init__.py +36 -0
  64. anemoi/datasets/create/{functions/sources → sources}/accumulations.py +296 -30
  65. anemoi/datasets/create/{functions/sources → sources}/constants.py +27 -2
  66. anemoi/datasets/create/{functions/sources → sources}/eccc_fstd.py +7 -3
  67. anemoi/datasets/create/sources/empty.py +37 -0
  68. anemoi/datasets/create/{functions/sources → sources}/forcings.py +25 -1
  69. anemoi/datasets/create/sources/grib.py +297 -0
  70. anemoi/datasets/create/{functions/sources → sources}/hindcasts.py +38 -4
  71. anemoi/datasets/create/sources/legacy.py +93 -0
  72. anemoi/datasets/create/{functions/sources → sources}/mars.py +168 -20
  73. anemoi/datasets/create/sources/netcdf.py +42 -0
  74. anemoi/datasets/create/sources/opendap.py +43 -0
  75. anemoi/datasets/create/{functions/sources/__init__.py → sources/patterns.py} +35 -4
  76. anemoi/datasets/create/sources/recentre.py +150 -0
  77. anemoi/datasets/create/{functions/sources → sources}/source.py +27 -5
  78. anemoi/datasets/create/{functions/sources → sources}/tendencies.py +64 -7
  79. anemoi/datasets/create/sources/xarray.py +92 -0
  80. anemoi/datasets/create/sources/xarray_kerchunk.py +36 -0
  81. anemoi/datasets/create/sources/xarray_support/README.md +1 -0
  82. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/__init__.py +109 -8
  83. anemoi/datasets/create/sources/xarray_support/coordinates.py +442 -0
  84. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/field.py +94 -16
  85. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/fieldlist.py +90 -25
  86. anemoi/datasets/create/sources/xarray_support/flavour.py +1036 -0
  87. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/grid.py +92 -31
  88. anemoi/datasets/create/sources/xarray_support/metadata.py +395 -0
  89. anemoi/datasets/create/sources/xarray_support/patch.py +91 -0
  90. anemoi/datasets/create/sources/xarray_support/time.py +391 -0
  91. anemoi/datasets/create/sources/xarray_support/variable.py +331 -0
  92. anemoi/datasets/create/sources/xarray_zarr.py +41 -0
  93. anemoi/datasets/create/{functions/sources → sources}/zenodo.py +34 -5
  94. anemoi/datasets/create/statistics/__init__.py +233 -44
  95. anemoi/datasets/create/statistics/summary.py +52 -6
  96. anemoi/datasets/create/testing.py +76 -0
  97. anemoi/datasets/create/{functions/filters/noop.py → typing.py} +6 -3
  98. anemoi/datasets/create/utils.py +97 -6
  99. anemoi/datasets/create/writer.py +26 -4
  100. anemoi/datasets/create/zarr.py +170 -23
  101. anemoi/datasets/data/__init__.py +51 -4
  102. anemoi/datasets/data/complement.py +191 -40
  103. anemoi/datasets/data/concat.py +141 -16
  104. anemoi/datasets/data/dataset.py +552 -61
  105. anemoi/datasets/data/debug.py +197 -26
  106. anemoi/datasets/data/ensemble.py +93 -8
  107. anemoi/datasets/data/fill_missing.py +165 -18
  108. anemoi/datasets/data/forwards.py +428 -56
  109. anemoi/datasets/data/grids.py +323 -97
  110. anemoi/datasets/data/indexing.py +112 -19
  111. anemoi/datasets/data/interpolate.py +92 -12
  112. anemoi/datasets/data/join.py +158 -19
  113. anemoi/datasets/data/masked.py +129 -15
  114. anemoi/datasets/data/merge.py +137 -23
  115. anemoi/datasets/data/misc.py +172 -16
  116. anemoi/datasets/data/missing.py +233 -29
  117. anemoi/datasets/data/rescale.py +111 -10
  118. anemoi/datasets/data/select.py +168 -26
  119. anemoi/datasets/data/statistics.py +67 -6
  120. anemoi/datasets/data/stores.py +149 -64
  121. anemoi/datasets/data/subset.py +159 -25
  122. anemoi/datasets/data/unchecked.py +168 -57
  123. anemoi/datasets/data/xy.py +168 -25
  124. anemoi/datasets/dates/__init__.py +191 -16
  125. anemoi/datasets/dates/groups.py +189 -47
  126. anemoi/datasets/grids.py +270 -31
  127. anemoi/datasets/testing.py +28 -1
  128. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/METADATA +9 -6
  129. anemoi_datasets-0.5.17.dist-info/RECORD +137 -0
  130. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/WHEEL +1 -1
  131. anemoi/datasets/create/functions/__init__.py +0 -66
  132. anemoi/datasets/create/functions/filters/__init__.py +0 -9
  133. anemoi/datasets/create/functions/filters/empty.py +0 -17
  134. anemoi/datasets/create/functions/filters/orog_to_z.py +0 -58
  135. anemoi/datasets/create/functions/filters/rename.py +0 -79
  136. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +0 -78
  137. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +0 -56
  138. anemoi/datasets/create/functions/sources/empty.py +0 -15
  139. anemoi/datasets/create/functions/sources/grib.py +0 -150
  140. anemoi/datasets/create/functions/sources/netcdf.py +0 -15
  141. anemoi/datasets/create/functions/sources/opendap.py +0 -15
  142. anemoi/datasets/create/functions/sources/recentre.py +0 -60
  143. anemoi/datasets/create/functions/sources/xarray/coordinates.py +0 -255
  144. anemoi/datasets/create/functions/sources/xarray/flavour.py +0 -472
  145. anemoi/datasets/create/functions/sources/xarray/metadata.py +0 -148
  146. anemoi/datasets/create/functions/sources/xarray/patch.py +0 -44
  147. anemoi/datasets/create/functions/sources/xarray/time.py +0 -177
  148. anemoi/datasets/create/functions/sources/xarray/variable.py +0 -188
  149. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +0 -42
  150. anemoi/datasets/create/functions/sources/xarray_zarr.py +0 -15
  151. anemoi/datasets/utils/fields.py +0 -47
  152. anemoi_datasets-0.5.16.dist-info/RECORD +0 -129
  153. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/entry_points.txt +0 -0
  154. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info/licenses}/LICENSE +0 -0
  155. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/top_level.txt +0 -0
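Most of this release is a reorganisation of the create package: the filter and source modules move out of anemoi/datasets/create/functions/ into anemoi/datasets/create/filters/ and anemoi/datasets/create/sources/ (items 23-93 above), and the old locations are deleted (items 131-150). Downstream code that imported these modules by their old dotted paths will need updating; a minimal compatibility sketch, assuming only that the modules are importable at the paths implied by the file list (the rename filter is used purely as an illustration):

    # Import the rename filter module across both layouts. The new path follows
    # the 0.5.17 moves listed above; the fallback is the 0.5.16 layout. The
    # names each module exports are not visible in this diff, so only the
    # module import is attempted here.
    try:
        from anemoi.datasets.create.filters import rename  # 0.5.17
    except ImportError:
        from anemoi.datasets.create.functions.filters import rename  # 0.5.16

Only one file's hunks are reproduced below: the old-side line numbers end near 412 and the new side near 638, which is consistent with the +323 -97 recorded for anemoi/datasets/data/grids.py (item 109).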
@@ -10,10 +10,20 @@
 
 import logging
 from functools import cached_property
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
 
 import numpy as np
+from numpy.typing import NDArray
 from scipy.spatial import cKDTree
 
+from .dataset import Dataset
+from .dataset import FullIndex
+from .dataset import Shape
+from .dataset import TupleIndex
 from .debug import Node
 from .debug import debug_indexing
 from .forwards import Combined
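These imports back the type annotations added throughout the file: NDArray from numpy.typing for array-returning methods, plus the Dataset, FullIndex, Shape and TupleIndex aliases from the sibling dataset module. A generic illustration of the NDArray[Any] style (not code from this package):

    from typing import Any

    import numpy as np
    from numpy.typing import NDArray

    # NDArray[Any] means "a numpy array of any dtype"; it is checked by static
    # type checkers such as mypy and has no runtime cost.
    def first_row(data: NDArray[Any]) -> NDArray[Any]:
        return data[0]

    print(first_row(np.arange(6).reshape(2, 3)))  # [0 1 2]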
@@ -30,12 +40,33 @@ LOG = logging.getLogger(__name__)
 
 
 class Concat(Combined):
-    def __len__(self):
+    """A class to represent concatenated datasets."""
+
+    def __len__(self) -> int:
+        """Returns the total length of the concatenated datasets.
+
+        Returns
+        -------
+        int
+            Total length of the concatenated datasets.
+        """
         return sum(len(i) for i in self.datasets)
 
     @debug_indexing
     @expand_list_indexing
-    def _get_tuple(self, index):
+    def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
+        """Retrieves a tuple of data from the concatenated datasets based on the given index.
+
+        Parameters
+        ----------
+        index : TupleIndex
+            Index specifying the data to retrieve.
+
+        Returns
+        -------
+        NDArray[Any]
+            Concatenated data array from the specified index.
+        """
         index, changes = index_to_slices(index, self.shape)
         # print(index, changes)
         lengths = [d.shape[0] for d in self.datasets]
@@ -46,7 +77,19 @@ class Concat(Combined):
         return apply_index_to_slices_changes(result, changes)
 
     @debug_indexing
-    def __getitem__(self, n):
+    def __getitem__(self, n: FullIndex) -> NDArray[Any]:
+        """Retrieves data from the concatenated datasets based on the given index.
+
+        Parameters
+        ----------
+        n : FullIndex
+            Index specifying the data to retrieve.
+
+        Returns
+        -------
+        NDArray[Any]
+            Data array from the concatenated datasets based on the index.
+        """
         if isinstance(n, tuple):
             return self._get_tuple(n)
 
@@ -61,7 +104,19 @@ class Concat(Combined):
         return self.datasets[k][n]
 
     @debug_indexing
-    def _get_slice(self, s):
+    def _get_slice(self, s: slice) -> NDArray[Any]:
+        """Retrieves a slice of data from the concatenated datasets.
+
+        Parameters
+        ----------
+        s : slice
+            Slice object specifying the range of data to retrieve.
+
+        Returns
+        -------
+        NDArray[Any]
+            Concatenated data array from the specified slice.
+        """
         result = []
 
         lengths = [d.shape[0] for d in self.datasets]
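The Concat methods above follow a single dispatch pattern: __getitem__ sends tuple indexes to _get_tuple and slices to _get_slice, and resolves a plain integer by walking the members' cumulative lengths (the `return self.datasets[k][n]` context line in the hunk above is the tail of that walk). A self-contained sketch of the integer case (illustrative, not the library's code):

    import numpy as np

    class TinyConcat:
        """Present a list of arrays as one sequence along axis 0, without copying."""

        def __init__(self, arrays):
            self.arrays = arrays

        def __len__(self):
            return sum(len(a) for a in self.arrays)

        def __getitem__(self, n):
            if n < 0:
                n += len(self)  # normalise negative indexes
            for a in self.arrays:
                if n < len(a):
                    return a[n]
                n -= len(a)  # not in this member; skip past it
            raise IndexError(n)

    parts = [np.arange(3), np.arange(3, 7)]
    tc = TinyConcat(parts)
    assert len(tc) == 7 and tc[4] == 4 and tc[-1] == 6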
@@ -71,50 +126,134 @@ class Concat(Combined):
 
         return np.concatenate(result)
 
-    def check_compatibility(self, d1, d2):
+    def check_compatibility(self, d1: Dataset, d2: Dataset) -> None:
+        """Check the compatibility of two datasets for concatenation.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         super().check_compatibility(d1, d2)
         self.check_same_sub_shapes(d1, d2, drop_axis=0)
 
-    def check_same_lengths(self, d1, d2):
+    def check_same_lengths(self, d1: Dataset, d2: Dataset) -> None:
+        """Check if the lengths of two datasets are the same.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         # Turned off because we are concatenating along the first axis
         pass
 
-    def check_same_dates(self, d1, d2):
+    def check_same_dates(self, d1: Dataset, d2: Dataset) -> None:
+        """Check if the dates of two datasets are the same.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         # Turned off because we are concatenating along the dates axis
         pass
 
     @property
-    def dates(self):
+    def dates(self) -> NDArray[np.datetime64]:
+        """Returns the concatenated dates of all datasets."""
         return np.concatenate([d.dates for d in self.datasets])
 
     @property
-    def shape(self):
+    def shape(self) -> Shape:
+        """Returns the shape of the concatenated datasets."""
         return (len(self),) + self.datasets[0].shape[1:]
 
-    def tree(self):
+    def tree(self) -> Node:
+        """Generates a hierarchical tree structure for the concatenated datasets.
+
+        Returns
+        -------
+        Node
+            A Node object representing the concatenated datasets.
+        """
         return Node(self, [d.tree() for d in self.datasets])
 
 
 class GridsBase(GivenAxis):
-    def __init__(self, datasets, axis):
+    """A base class for handling grids in datasets."""
+
+    def __init__(self, datasets: List[Any], axis: int) -> None:
+        """Initializes a GridsBase object.
+
+        Parameters
+        ----------
+        datasets : List[Any]
+            List of datasets.
+        axis : int
+            Axis along which to combine the datasets.
+        """
         super().__init__(datasets, axis)
         # Shape: (dates, variables, ensemble, 1d-values)
         assert len(datasets[0].shape) == 4, "Grids must be 1D for now"
 
-    def check_same_grid(self, d1, d2):
+    def check_same_grid(self, d1: Dataset, d2: Dataset) -> None:
+        """Check if the grids of two datasets are the same.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         # We don't check the grid, because we want to be able to combine
         pass
 
-    def check_same_resolution(self, d1, d2):
+    def check_same_resolution(self, d1: Dataset, d2: Dataset) -> None:
+        """Check if the resolutions of two datasets are the same.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         # We don't check the resolution, because we want to be able to combine
         pass
 
-    def metadata_specific(self):
+    def metadata_specific(self, **kwargs: Any) -> Dict[str, Any]:
+        """Returns metadata specific to the GridsBase object.
+
+        Parameters
+        ----------
+        kwargs : Any
+            Additional keyword arguments.
+
+        Returns
+        -------
+        Dict[str, Any]
+            Metadata specific to the GridsBase object.
+        """
         return super().metadata_specific(
             multi_grids=True,
         )
 
-    def collect_input_sources(self, collected):
+    def collect_input_sources(self, collected: List[Any]) -> None:
+        """Collects input sources from the datasets.
+
+        Parameters
+        ----------
+        collected : List[Any]
+            List to which the input sources are appended.
+        """
         # We assume that,because they have different grids, they have different input sources
         for d in self.datasets:
             collected.append(d)
@@ -122,42 +261,75 @@ class GridsBase(GivenAxis):
 
 
 class Grids(GridsBase):
+    """A class to represent combined grids from multiple datasets."""
+
     # TODO: select the statistics of the most global grid?
     @property
-    def latitudes(self):
+    def latitudes(self) -> NDArray[Any]:
+        """Returns the concatenated latitudes of all datasets."""
         return np.concatenate([d.latitudes for d in self.datasets])
 
     @property
-    def longitudes(self):
+    def longitudes(self) -> NDArray[Any]:
+        """Returns the concatenated longitudes of all datasets."""
         return np.concatenate([d.longitudes for d in self.datasets])
 
     @property
-    def grids(self):
+    def grids(self) -> Tuple[Any, ...]:
+        """Returns the grids of all datasets."""
         result = []
         for d in self.datasets:
             result.extend(d.grids)
         return tuple(result)
 
-    def tree(self):
+    def tree(self) -> Node:
+        """Generates a hierarchical tree structure for the Grids object.
+
+        Returns
+        -------
+        Node
+            A Node object representing the Grids object.
+        """
         return Node(self, [d.tree() for d in self.datasets], mode="concat")
 
+    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+        """Get the metadata specific to the forwards subclass.
+
+        Returns:
+            Dict[str, Any]: The metadata specific to the forwards subclass.
+        """
+        return {}
+
 
 class Cutout(GridsBase):
-    def __init__(self, datasets, axis=3, cropping_distance=2.0, neighbours=5, min_distance_km=None, plot=None):
+    """A class to handle hierarchical management of Limited Area Models (LAMs) and a global dataset."""
+
+    def __init__(
+        self,
+        datasets: List[Any],
+        axis: int = 3,
+        cropping_distance: float = 2.0,
+        neighbours: int = 5,
+        min_distance_km: Optional[float] = None,
+        plot: Optional[bool] = None,
+    ) -> None:
         """Initializes a Cutout object for hierarchical management of Limited Area
         Models (LAMs) and a global dataset, handling overlapping regions.
 
-        Args:
-            datasets (list): List of LAM and global datasets.
-            axis (int): Concatenation axis, must be set to 3.
-            cropping_distance (float): Distance threshold in degrees for
-                cropping cutouts.
-            neighbours (int): Number of neighboring points to consider when
-                constructing masks.
-            min_distance_km (float, optional): Minimum distance threshold in km
-                between grid points.
-            plot (bool, optional): Flag to enable or disable visualization
-                plots.
+        Parameters
+        ----------
+        datasets : list
+            List of LAM and global datasets.
+        axis : int
+            Concatenation axis, must be set to 3.
+        cropping_distance : float
+            Distance threshold in degrees for cropping cutouts.
+        neighbours : int
+            Number of neighboring points to consider when constructing masks.
+        min_distance_km : float, optional
+            Minimum distance threshold in km between grid points.
+        plot : bool, optional
+            Flag to enable or disable visualization plots.
         """
         super().__init__(datasets, axis)
         assert len(datasets) >= 2, "CutoutGrids requires at least two datasets"
@@ -179,14 +351,13 @@ class Cutout(GridsBase):
         # Initialize cumulative masks
         self._initialize_masks()
 
-    def _initialize_masks(self):
-        """Generates hierarchical masks for each LAM dataset by excluding
-        overlapping regions with previous LAMs and creating a global mask for
-        the global dataset.
+    def _initialize_masks(self) -> None:
+        """Generate hierarchical masks for each LAM dataset by excluding overlapping regions with previous LAMs and creating a global mask for the global dataset.
 
-        Raises:
-            ValueError: If the global mask dimension does not match the global
-                dataset grid points.
+        Raises
+        ------
+        ValueError
+            If the global mask dimension does not match the global dataset grid points.
         """
         from anemoi.datasets.grids import cutout_mask
 
@@ -236,21 +407,33 @@
         lam_current_mask[~lam_overlap_mask] = False
         self.masks.append(lam_current_mask)
 
-    def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0):
-        """Checks for overlapping points between two sets of latitudes and
-        longitudes within a specified distance threshold.
-
-        Args:
-            lats1, lons1 (np.ndarray): Latitude and longitude arrays for the
-                first dataset.
-            lats2, lons2 (np.ndarray): Latitude and longitude arrays for the
-                second dataset.
-            distance_threshold (float): Distance in degrees to consider as
-                overlapping.
-
-        Returns:
-            bool: True if any points overlap within the distance threshold,
-                otherwise False.
+    def has_overlap(
+        self,
+        lats1: NDArray[Any],
+        lons1: NDArray[Any],
+        lats2: NDArray[Any],
+        lons2: NDArray[Any],
+        distance_threshold: float = 1.0,
+    ) -> bool:
+        """Check for overlapping points between two sets of latitudes and longitudes within a specified distance threshold.
+
+        Parameters
+        ----------
+        lats1 : NDArray[Any]
+            Latitude array for the first dataset.
+        lons1 : NDArray[Any]
+            Longitude array for the first dataset.
+        lats2 : NDArray[Any]
+            Latitude array for the second dataset.
+        lons2 : NDArray[Any]
+            Longitude array for the second dataset.
+        distance_threshold : float
+            Distance in degrees to consider as overlapping.
+
+        Returns
+        -------
+        bool
+            True if any points overlap within the distance threshold, otherwise False.
         """
         # Create KDTree for the first set of points
         tree = cKDTree(np.vstack((lats1, lons1)).T)
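has_overlap treats (latitude, longitude) pairs as points in a flat degree space: it builds a cKDTree over the first set and queries nearest-neighbour distances for the second, declaring overlap when any distance falls under the threshold. The same test, reduced to a standalone sketch with made-up coordinates:

    import numpy as np
    from scipy.spatial import cKDTree

    lats1, lons1 = np.array([10.0, 11.0]), np.array([20.0, 21.0])
    lats2, lons2 = np.array([10.4]), np.array([20.3])

    # Rows are (lat, lon) points; distances are Euclidean in degrees, as in
    # the method above, not great-circle distances.
    tree = cKDTree(np.vstack((lats1, lons1)).T)
    distances, _ = tree.query(np.vstack((lats2, lons2)).T, k=1)
    print(bool(np.any(distances < 1.0)))  # True: the second set overlaps the first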
@@ -261,31 +444,35 @@
         # Check if any distance is less than the specified threshold
         return np.any(distances < distance_threshold)
 
-    def __getitem__(self, index):
-        """Retrieves data from the masked LAMs and global dataset based on the
-        given index.
+    def __getitem__(self, index: FullIndex) -> NDArray[Any]:
+        """Retrieve data from the masked LAMs and global dataset based on the given index.
 
-        Args:
-            index (int or slice or tuple): Index specifying the data to
-                retrieve.
+        Parameters
+        ----------
+        index : FullIndex
+            Index specifying the data to retrieve.
 
-        Returns:
-            np.ndarray: Data array from the masked datasets based on the index.
+        Returns
+        -------
+        NDArray[Any]
+            Data array from the masked datasets based on the index.
         """
         if isinstance(index, (int, slice)):
             index = (index, slice(None), slice(None), slice(None))
         return self._get_tuple(index)
 
-    def _get_tuple(self, index):
-        """Helper method that applies masks and retrieves data from each dataset
-        according to the specified index.
+    def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
+        """Helper method that applies masks and retrieves data from each dataset according to the specified index.
 
-        Args:
-            index (tuple): Index specifying slices to retrieve data.
+        Parameters
+        ----------
+        index : TupleIndex
+            Index specifying slices to retrieve data.
 
-        Returns:
-            np.ndarray: Concatenated data array from all datasets based on the
-                index.
+        Returns
+        -------
+        NDArray[Any]
+            Concatenated data array from all datasets based on the index.
         """
         index, changes = index_to_slices(index, self.shape)
         # Select data from each LAM
@@ -300,13 +487,15 @@
 
         return apply_index_to_slices_changes(result, changes)
 
-    def collect_supporting_arrays(self, collected, *path):
-        """Collects supporting arrays, including masks for each LAM and the global
-        dataset.
+    def collect_supporting_arrays(self, collected: List[Any], *path: Any) -> None:
+        """Collect supporting arrays, including masks for each LAM and the global dataset.
 
-        Args:
-            collected (list): List to which the supporting arrays are appended.
-            *path: Variable length argument list specifying the paths for the masks.
+        Parameters
+        ----------
+        collected : List[Any]
+            List to which the supporting arrays are appended.
+        *path : Any
+            Variable length argument list specifying the paths for the masks.
         """
         # Append masks for each LAM
         for i, (lam, mask) in enumerate(zip(self.lams, self.masks)):
@@ -316,41 +505,41 @@
         collected.append((path + ("global",), "cutout_mask", self.global_mask))
 
     @cached_property
-    def shape(self):
+    def shape(self) -> Shape:
         """Returns the shape of the Cutout, accounting for retained grid points
         across all LAMs and the global dataset.
-
-        Returns:
-            tuple: Shape of the concatenated masked datasets.
         """
         shapes = [np.sum(mask) for mask in self.masks]
         global_shape = np.sum(self.global_mask)
         total_shape = sum(shapes) + global_shape
         return tuple(self.lams[0].shape[:-1] + (int(total_shape),))
 
-    def check_same_resolution(self, d1, d2):
+    def check_same_resolution(self, d1: Dataset, d2: Dataset) -> None:
+        """Checks if the resolutions of two datasets are the same.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         # Turned off because we are combining different resolutions
         pass
 
     @property
-    def grids(self):
+    def grids(self) -> TupleIndex:
         """Returns the number of grid points for each LAM and the global dataset
         after applying masks.
-
-        Returns:
-            tuple: Count of retained grid points for each dataset.
         """
         grids = [np.sum(mask) for mask in self.masks]
         grids.append(np.sum(self.global_mask))
         return tuple(grids)
 
     @property
-    def latitudes(self):
+    def latitudes(self) -> NDArray[Any]:
         """Returns the concatenated latitudes of each LAM and the global dataset
         after applying masks.
-
-        Returns:
-            np.ndarray: Concatenated latitude array for the masked datasets.
         """
         lam_latitudes = np.concatenate([lam.latitudes[mask] for lam, mask in zip(self.lams, self.masks)])
 
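Both shape and grids above reduce to counting the True entries of each boolean mask. A numeric sketch of that bookkeeping with made-up masks (not package data):

    import numpy as np

    lam_masks = [np.array([True, True, False]), np.array([True, False])]
    global_mask = np.array([False, True, True])

    grids = tuple(int(np.sum(m)) for m in lam_masks) + (int(np.sum(global_mask)),)
    print(grids)       # (2, 1, 2): points kept per LAM, then the global remainder
    print(sum(grids))  # 5: the size of the grid axis reported by Cutout.shape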
@@ -362,12 +551,9 @@ class Cutout(GridsBase):
         return latitudes
 
     @property
-    def longitudes(self):
+    def longitudes(self) -> NDArray[Any]:
         """Returns the concatenated longitudes of each LAM and the global dataset
         after applying masks.
-
-        Returns:
-            np.ndarray: Concatenated longitude array for the masked datasets.
         """
         lam_longitudes = np.concatenate([lam.longitudes[mask] for lam, mask in zip(self.lams, self.masks)])
 
@@ -378,19 +564,45 @@ class Cutout(GridsBase):
         longitudes = np.concatenate([lam_longitudes, self.globe.longitudes[self.global_mask]])
         return longitudes
 
-    def tree(self):
+    def tree(self) -> Node:
         """Generates a hierarchical tree structure for the `Cutout` instance and
         its associated datasets.
 
-        Returns:
-            Node: A `Node` object representing the `Cutout` instance as the root
+        Returns
+        -------
+        Node
+            A `Node` object representing the `Cutout` instance as the root
             node, with each dataset in `self.datasets` represented as a child
             node.
         """
         return Node(self, [d.tree() for d in self.datasets])
 
+    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+        """Returns metadata specific to the Cutout object.
 
-def grids_factory(args, kwargs):
+        Returns
+        -------
+        Dict[str, Any]
+            Metadata specific to the Cutout object.
+        """
+        return {}
+
+
+def grids_factory(args: Tuple[Any, ...], kwargs: dict) -> Dataset:
+    """Factory function to create a Grids object.
+
+    Parameters
+    ----------
+    args : Tuple[Any, ...]
+        Positional arguments.
+    kwargs : dict
+        Keyword arguments.
+
+    Returns
+    -------
+    Dataset
+        A Grids object.
+    """
     if "ensemble" in kwargs:
         raise NotImplementedError("Cannot use both 'ensemble' and 'grids'")
 
@@ -406,7 +618,21 @@ def grids_factory(args, kwargs):
     return Grids(datasets, axis=axis)._subset(**kwargs)
 
 
-def cutout_factory(args, kwargs):
+def cutout_factory(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dataset:
+    """Factory function to create a Cutout object.
+
+    Parameters
+    ----------
+    args : Tuple[Any, ...]
+        Positional arguments.
+    kwargs : Dict[str, Any]
+        Keyword arguments.
+
+    Returns
+    -------
+    Dataset
+        A Cutout object.
+    """
     if "ensemble" in kwargs:
         raise NotImplementedError("Cannot use both 'ensemble' and 'cutout'")
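Both factories are normally reached through the package's open_dataset entry point rather than called directly. A hedged usage sketch: the cutout keyword and the Cutout parameters come from the code above, the dataset paths are placeholders, and exactly which keyword arguments open_dataset forwards is not visible in this diff.

    from anemoi.datasets import open_dataset

    # Combine a limited-area dataset with a global one; global points that
    # fall inside the LAM are removed by the Cutout masks.
    ds = open_dataset(
        cutout=["lam-dataset.zarr", "global-dataset.zarr"],  # placeholder paths
        min_distance_km=1.0,
        neighbours=5,
    )
    print(ds.shape)  # (dates, variables, ensemble, LAM points + remaining global points)
    print(ds.grids)  # points kept per member, as computed by Cutout.grids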