anemoi_datasets-0.5.16-py3-none-any.whl → anemoi_datasets-0.5.17-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- anemoi/datasets/__init__.py +4 -1
- anemoi/datasets/__main__.py +12 -2
- anemoi/datasets/_version.py +9 -4
- anemoi/datasets/commands/cleanup.py +17 -2
- anemoi/datasets/commands/compare.py +18 -2
- anemoi/datasets/commands/copy.py +196 -14
- anemoi/datasets/commands/create.py +50 -7
- anemoi/datasets/commands/finalise-additions.py +17 -2
- anemoi/datasets/commands/finalise.py +17 -2
- anemoi/datasets/commands/init-additions.py +17 -2
- anemoi/datasets/commands/init.py +16 -2
- anemoi/datasets/commands/inspect.py +283 -62
- anemoi/datasets/commands/load-additions.py +16 -2
- anemoi/datasets/commands/load.py +16 -2
- anemoi/datasets/commands/patch.py +17 -2
- anemoi/datasets/commands/publish.py +17 -2
- anemoi/datasets/commands/scan.py +31 -3
- anemoi/datasets/compute/recentre.py +47 -11
- anemoi/datasets/create/__init__.py +612 -85
- anemoi/datasets/create/check.py +142 -20
- anemoi/datasets/create/chunks.py +64 -4
- anemoi/datasets/create/config.py +185 -21
- anemoi/datasets/create/filter.py +50 -0
- anemoi/datasets/create/filters/__init__.py +33 -0
- anemoi/datasets/create/filters/empty.py +37 -0
- anemoi/datasets/create/filters/legacy.py +93 -0
- anemoi/datasets/create/filters/noop.py +37 -0
- anemoi/datasets/create/filters/orog_to_z.py +58 -0
- anemoi/datasets/create/{functions/filters → filters}/pressure_level_relative_humidity_to_specific_humidity.py +33 -10
- anemoi/datasets/create/{functions/filters → filters}/pressure_level_specific_humidity_to_relative_humidity.py +32 -8
- anemoi/datasets/create/filters/rename.py +205 -0
- anemoi/datasets/create/{functions/filters → filters}/rotate_winds.py +43 -28
- anemoi/datasets/create/{functions/filters → filters}/single_level_dewpoint_to_relative_humidity.py +32 -9
- anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_dewpoint.py +33 -9
- anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_specific_humidity.py +55 -7
- anemoi/datasets/create/{functions/filters → filters}/single_level_specific_humidity_to_relative_humidity.py +98 -37
- anemoi/datasets/create/filters/speeddir_to_uv.py +95 -0
- anemoi/datasets/create/{functions/filters → filters}/sum.py +24 -27
- anemoi/datasets/create/filters/transform.py +53 -0
- anemoi/datasets/create/{functions/filters → filters}/unrotate_winds.py +27 -18
- anemoi/datasets/create/filters/uv_to_speeddir.py +94 -0
- anemoi/datasets/create/{functions/filters → filters}/wz_to_w.py +51 -33
- anemoi/datasets/create/input/__init__.py +76 -5
- anemoi/datasets/create/input/action.py +149 -13
- anemoi/datasets/create/input/concat.py +81 -10
- anemoi/datasets/create/input/context.py +39 -4
- anemoi/datasets/create/input/data_sources.py +72 -6
- anemoi/datasets/create/input/empty.py +21 -3
- anemoi/datasets/create/input/filter.py +60 -12
- anemoi/datasets/create/input/function.py +154 -37
- anemoi/datasets/create/input/join.py +86 -14
- anemoi/datasets/create/input/misc.py +67 -17
- anemoi/datasets/create/input/pipe.py +33 -6
- anemoi/datasets/create/input/repeated_dates.py +189 -41
- anemoi/datasets/create/input/result.py +202 -87
- anemoi/datasets/create/input/step.py +119 -22
- anemoi/datasets/create/input/template.py +100 -13
- anemoi/datasets/create/input/trace.py +62 -7
- anemoi/datasets/create/patch.py +52 -4
- anemoi/datasets/create/persistent.py +134 -17
- anemoi/datasets/create/size.py +15 -1
- anemoi/datasets/create/source.py +51 -0
- anemoi/datasets/create/sources/__init__.py +36 -0
- anemoi/datasets/create/{functions/sources → sources}/accumulations.py +296 -30
- anemoi/datasets/create/{functions/sources → sources}/constants.py +27 -2
- anemoi/datasets/create/{functions/sources → sources}/eccc_fstd.py +7 -3
- anemoi/datasets/create/sources/empty.py +37 -0
- anemoi/datasets/create/{functions/sources → sources}/forcings.py +25 -1
- anemoi/datasets/create/sources/grib.py +297 -0
- anemoi/datasets/create/{functions/sources → sources}/hindcasts.py +38 -4
- anemoi/datasets/create/sources/legacy.py +93 -0
- anemoi/datasets/create/{functions/sources → sources}/mars.py +168 -20
- anemoi/datasets/create/sources/netcdf.py +42 -0
- anemoi/datasets/create/sources/opendap.py +43 -0
- anemoi/datasets/create/{functions/sources/__init__.py → sources/patterns.py} +35 -4
- anemoi/datasets/create/sources/recentre.py +150 -0
- anemoi/datasets/create/{functions/sources → sources}/source.py +27 -5
- anemoi/datasets/create/{functions/sources → sources}/tendencies.py +64 -7
- anemoi/datasets/create/sources/xarray.py +92 -0
- anemoi/datasets/create/sources/xarray_kerchunk.py +36 -0
- anemoi/datasets/create/sources/xarray_support/README.md +1 -0
- anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/__init__.py +109 -8
- anemoi/datasets/create/sources/xarray_support/coordinates.py +442 -0
- anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/field.py +94 -16
- anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/fieldlist.py +90 -25
- anemoi/datasets/create/sources/xarray_support/flavour.py +1036 -0
- anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/grid.py +92 -31
- anemoi/datasets/create/sources/xarray_support/metadata.py +395 -0
- anemoi/datasets/create/sources/xarray_support/patch.py +91 -0
- anemoi/datasets/create/sources/xarray_support/time.py +391 -0
- anemoi/datasets/create/sources/xarray_support/variable.py +331 -0
- anemoi/datasets/create/sources/xarray_zarr.py +41 -0
- anemoi/datasets/create/{functions/sources → sources}/zenodo.py +34 -5
- anemoi/datasets/create/statistics/__init__.py +233 -44
- anemoi/datasets/create/statistics/summary.py +52 -6
- anemoi/datasets/create/testing.py +76 -0
- anemoi/datasets/create/{functions/filters/noop.py → typing.py} +6 -3
- anemoi/datasets/create/utils.py +97 -6
- anemoi/datasets/create/writer.py +26 -4
- anemoi/datasets/create/zarr.py +170 -23
- anemoi/datasets/data/__init__.py +51 -4
- anemoi/datasets/data/complement.py +191 -40
- anemoi/datasets/data/concat.py +141 -16
- anemoi/datasets/data/dataset.py +552 -61
- anemoi/datasets/data/debug.py +197 -26
- anemoi/datasets/data/ensemble.py +93 -8
- anemoi/datasets/data/fill_missing.py +165 -18
- anemoi/datasets/data/forwards.py +428 -56
- anemoi/datasets/data/grids.py +323 -97
- anemoi/datasets/data/indexing.py +112 -19
- anemoi/datasets/data/interpolate.py +92 -12
- anemoi/datasets/data/join.py +158 -19
- anemoi/datasets/data/masked.py +129 -15
- anemoi/datasets/data/merge.py +137 -23
- anemoi/datasets/data/misc.py +172 -16
- anemoi/datasets/data/missing.py +233 -29
- anemoi/datasets/data/rescale.py +111 -10
- anemoi/datasets/data/select.py +168 -26
- anemoi/datasets/data/statistics.py +67 -6
- anemoi/datasets/data/stores.py +149 -64
- anemoi/datasets/data/subset.py +159 -25
- anemoi/datasets/data/unchecked.py +168 -57
- anemoi/datasets/data/xy.py +168 -25
- anemoi/datasets/dates/__init__.py +191 -16
- anemoi/datasets/dates/groups.py +189 -47
- anemoi/datasets/grids.py +270 -31
- anemoi/datasets/testing.py +28 -1
- {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/METADATA +9 -6
- anemoi_datasets-0.5.17.dist-info/RECORD +137 -0
- {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/functions/__init__.py +0 -66
- anemoi/datasets/create/functions/filters/__init__.py +0 -9
- anemoi/datasets/create/functions/filters/empty.py +0 -17
- anemoi/datasets/create/functions/filters/orog_to_z.py +0 -58
- anemoi/datasets/create/functions/filters/rename.py +0 -79
- anemoi/datasets/create/functions/filters/speeddir_to_uv.py +0 -78
- anemoi/datasets/create/functions/filters/uv_to_speeddir.py +0 -56
- anemoi/datasets/create/functions/sources/empty.py +0 -15
- anemoi/datasets/create/functions/sources/grib.py +0 -150
- anemoi/datasets/create/functions/sources/netcdf.py +0 -15
- anemoi/datasets/create/functions/sources/opendap.py +0 -15
- anemoi/datasets/create/functions/sources/recentre.py +0 -60
- anemoi/datasets/create/functions/sources/xarray/coordinates.py +0 -255
- anemoi/datasets/create/functions/sources/xarray/flavour.py +0 -472
- anemoi/datasets/create/functions/sources/xarray/metadata.py +0 -148
- anemoi/datasets/create/functions/sources/xarray/patch.py +0 -44
- anemoi/datasets/create/functions/sources/xarray/time.py +0 -177
- anemoi/datasets/create/functions/sources/xarray/variable.py +0 -188
- anemoi/datasets/create/functions/sources/xarray_kerchunk.py +0 -42
- anemoi/datasets/create/functions/sources/xarray_zarr.py +0 -15
- anemoi/datasets/utils/fields.py +0 -47
- anemoi_datasets-0.5.16.dist-info/RECORD +0 -129
- {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info/licenses}/LICENSE +0 -0
- {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/top_level.txt +0 -0
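
The most visible change in this release is the flattening of the create package: modules under anemoi/datasets/create/functions/filters/ move to anemoi/datasets/create/filters/, and functions/sources/ likewise moves to sources/ (with the xarray helpers gathered under sources/xarray_support/). The new filters/legacy.py and sources/legacy.py modules suggest a compatibility layer for the old layout. As a purely illustrative aid, a hypothetical helper (not part of anemoi-datasets) that maps old module paths to new ones, using two renames taken from the list above:

# Hypothetical helper, not part of the package: map pre-0.5.17 module
# paths to their post-0.5.17 equivalents, following the renames listed above.
MOVED = {
    "anemoi.datasets.create.functions.filters.rotate_winds": "anemoi.datasets.create.filters.rotate_winds",
    "anemoi.datasets.create.functions.sources.mars": "anemoi.datasets.create.sources.mars",
}

def new_module_path(old: str) -> str:
    """Return the 0.5.17 location of a module, or the input unchanged if it did not move."""
    return MOVED.get(old, old)

assert new_module_path("anemoi.datasets.create.functions.sources.mars").endswith(".sources.mars")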
anemoi/datasets/data/grids.py
CHANGED
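The diff below retrofits type hints and numpydoc docstrings onto the dataset-combining classes. The annotations lean on NDArray from numpy.typing and on aliases imported from .dataset. As a rough sketch, those aliases plausibly amount to the following; the real definitions live in anemoi/datasets/data/dataset.py and may differ:

# Assumed definitions, for orientation only; see anemoi/datasets/data/dataset.py.
from typing import List, Tuple, Union

Shape = Tuple[int, ...]                                 # e.g. (dates, variables, ensembles, values)
TupleIndex = Tuple[Union[int, slice, List[int]], ...]   # one entry per axis
FullIndex = Union[int, slice, TupleIndex]               # anything __getitem__ accepts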
--- a/anemoi/datasets/data/grids.py
+++ b/anemoi/datasets/data/grids.py
@@ -10,10 +10,20 @@
 
 import logging
 from functools import cached_property
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
 
 import numpy as np
+from numpy.typing import NDArray
 from scipy.spatial import cKDTree
 
+from .dataset import Dataset
+from .dataset import FullIndex
+from .dataset import Shape
+from .dataset import TupleIndex
 from .debug import Node
 from .debug import debug_indexing
 from .forwards import Combined
@@ -30,12 +40,33 @@ LOG = logging.getLogger(__name__)
 
 
 class Concat(Combined):
-    def __len__(self):
+    """A class to represent concatenated datasets."""
+
+    def __len__(self) -> int:
+        """Returns the total length of the concatenated datasets.
+
+        Returns
+        -------
+        int
+            Total length of the concatenated datasets.
+        """
         return sum(len(i) for i in self.datasets)
 
     @debug_indexing
     @expand_list_indexing
-    def _get_tuple(self, index):
+    def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
+        """Retrieves a tuple of data from the concatenated datasets based on the given index.
+
+        Parameters
+        ----------
+        index : TupleIndex
+            Index specifying the data to retrieve.
+
+        Returns
+        -------
+        NDArray[Any]
+            Concatenated data array from the specified index.
+        """
         index, changes = index_to_slices(index, self.shape)
         # print(index, changes)
         lengths = [d.shape[0] for d in self.datasets]
@@ -46,7 +77,19 @@ class Concat(Combined):
         return apply_index_to_slices_changes(result, changes)
 
     @debug_indexing
-    def __getitem__(self, n):
+    def __getitem__(self, n: FullIndex) -> NDArray[Any]:
+        """Retrieves data from the concatenated datasets based on the given index.
+
+        Parameters
+        ----------
+        n : FullIndex
+            Index specifying the data to retrieve.
+
+        Returns
+        -------
+        NDArray[Any]
+            Data array from the concatenated datasets based on the index.
+        """
         if isinstance(n, tuple):
             return self._get_tuple(n)
 
@@ -61,7 +104,19 @@ class Concat(Combined):
         return self.datasets[k][n]
 
     @debug_indexing
-    def _get_slice(self, s):
+    def _get_slice(self, s: slice) -> NDArray[Any]:
+        """Retrieves a slice of data from the concatenated datasets.
+
+        Parameters
+        ----------
+        s : slice
+            Slice object specifying the range of data to retrieve.
+
+        Returns
+        -------
+        NDArray[Any]
+            Concatenated data array from the specified slice.
+        """
         result = []
 
         lengths = [d.shape[0] for d in self.datasets]
@@ -71,50 +126,134 @@ class Concat(Combined):
 
         return np.concatenate(result)
 
-    def check_compatibility(self, d1, d2):
+    def check_compatibility(self, d1: Dataset, d2: Dataset) -> None:
+        """Check the compatibility of two datasets for concatenation.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         super().check_compatibility(d1, d2)
         self.check_same_sub_shapes(d1, d2, drop_axis=0)
 
-    def check_same_lengths(self, d1, d2):
+    def check_same_lengths(self, d1: Dataset, d2: Dataset) -> None:
+        """Check if the lengths of two datasets are the same.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         # Turned off because we are concatenating along the first axis
         pass
 
-    def check_same_dates(self, d1, d2):
+    def check_same_dates(self, d1: Dataset, d2: Dataset) -> None:
+        """Check if the dates of two datasets are the same.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         # Turned off because we are concatenating along the dates axis
         pass
 
     @property
-    def dates(self):
+    def dates(self) -> NDArray[np.datetime64]:
+        """Returns the concatenated dates of all datasets."""
         return np.concatenate([d.dates for d in self.datasets])
 
     @property
-    def shape(self):
+    def shape(self) -> Shape:
+        """Returns the shape of the concatenated datasets."""
         return (len(self),) + self.datasets[0].shape[1:]
 
-    def tree(self):
+    def tree(self) -> Node:
+        """Generates a hierarchical tree structure for the concatenated datasets.
+
+        Returns
+        -------
+        Node
+            A Node object representing the concatenated datasets.
+        """
         return Node(self, [d.tree() for d in self.datasets])
 
 
 class GridsBase(GivenAxis):
-    def __init__(self, datasets, axis):
+    """A base class for handling grids in datasets."""
+
+    def __init__(self, datasets: List[Any], axis: int) -> None:
+        """Initializes a GridsBase object.
+
+        Parameters
+        ----------
+        datasets : List[Any]
+            List of datasets.
+        axis : int
+            Axis along which to combine the datasets.
+        """
         super().__init__(datasets, axis)
         # Shape: (dates, variables, ensemble, 1d-values)
         assert len(datasets[0].shape) == 4, "Grids must be 1D for now"
 
-    def check_same_grid(self, d1, d2):
+    def check_same_grid(self, d1: Dataset, d2: Dataset) -> None:
+        """Check if the grids of two datasets are the same.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         # We don't check the grid, because we want to be able to combine
         pass
 
-    def check_same_resolution(self, d1, d2):
+    def check_same_resolution(self, d1: Dataset, d2: Dataset) -> None:
+        """Check if the resolutions of two datasets are the same.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         # We don't check the resolution, because we want to be able to combine
         pass
 
-    def metadata_specific(self):
+    def metadata_specific(self, **kwargs: Any) -> Dict[str, Any]:
+        """Returns metadata specific to the GridsBase object.
+
+        Parameters
+        ----------
+        kwargs : Any
+            Additional keyword arguments.
+
+        Returns
+        -------
+        Dict[str, Any]
+            Metadata specific to the GridsBase object.
+        """
         return super().metadata_specific(
             multi_grids=True,
         )
 
-    def collect_input_sources(self, collected):
+    def collect_input_sources(self, collected: List[Any]) -> None:
+        """Collects input sources from the datasets.
+
+        Parameters
+        ----------
+        collected : List[Any]
+            List to which the input sources are appended.
+        """
         # We assume that,because they have different grids, they have different input sources
         for d in self.datasets:
             collected.append(d)
@@ -122,42 +261,75 @@ class GridsBase(GivenAxis):
 
 
 class Grids(GridsBase):
+    """A class to represent combined grids from multiple datasets."""
+
     # TODO: select the statistics of the most global grid?
     @property
-    def latitudes(self):
+    def latitudes(self) -> NDArray[Any]:
+        """Returns the concatenated latitudes of all datasets."""
         return np.concatenate([d.latitudes for d in self.datasets])
 
     @property
-    def longitudes(self):
+    def longitudes(self) -> NDArray[Any]:
+        """Returns the concatenated longitudes of all datasets."""
         return np.concatenate([d.longitudes for d in self.datasets])
 
     @property
-    def grids(self):
+    def grids(self) -> Tuple[Any, ...]:
+        """Returns the grids of all datasets."""
         result = []
         for d in self.datasets:
             result.extend(d.grids)
         return tuple(result)
 
-    def tree(self):
+    def tree(self) -> Node:
+        """Generates a hierarchical tree structure for the Grids object.
+
+        Returns
+        -------
+        Node
+            A Node object representing the Grids object.
+        """
         return Node(self, [d.tree() for d in self.datasets], mode="concat")
 
+    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+        """Get the metadata specific to the forwards subclass.
+
+        Returns:
+            Dict[str, Any]: The metadata specific to the forwards subclass.
+        """
+        return {}
+
 
 class Cutout(GridsBase):
-    def __init__(self, datasets, axis=3, cropping_distance=2.0, neighbours=5, min_distance_km=None, plot=None):
+    """A class to handle hierarchical management of Limited Area Models (LAMs) and a global dataset."""
+
+    def __init__(
+        self,
+        datasets: List[Any],
+        axis: int = 3,
+        cropping_distance: float = 2.0,
+        neighbours: int = 5,
+        min_distance_km: Optional[float] = None,
+        plot: Optional[bool] = None,
+    ) -> None:
         """Initializes a Cutout object for hierarchical management of Limited Area
         Models (LAMs) and a global dataset, handling overlapping regions.
 
-        [11 lines of the old parameter description not preserved in this extract]
+        Parameters
+        ----------
+        datasets : list
+            List of LAM and global datasets.
+        axis : int
+            Concatenation axis, must be set to 3.
+        cropping_distance : float
+            Distance threshold in degrees for cropping cutouts.
+        neighbours : int
+            Number of neighboring points to consider when constructing masks.
+        min_distance_km : float, optional
+            Minimum distance threshold in km between grid points.
+        plot : bool, optional
+            Flag to enable or disable visualization plots.
         """
         super().__init__(datasets, axis)
        assert len(datasets) >= 2, "CutoutGrids requires at least two datasets"
@@ -179,14 +351,13 @@ class Cutout(GridsBase):
         # Initialize cumulative masks
         self._initialize_masks()
 
-    def _initialize_masks(self):
-        """
-        overlapping regions with previous LAMs and creating a global mask for
-        the global dataset.
+    def _initialize_masks(self) -> None:
+        """Generate hierarchical masks for each LAM dataset by excluding overlapping regions with previous LAMs and creating a global mask for the global dataset.
 
-        Raises
-        [2 lines of the old docstring not preserved in this extract]
+        Raises
+        ------
+        ValueError
+            If the global mask dimension does not match the global dataset grid points.
         """
         from anemoi.datasets.grids import cutout_mask
 
@@ -236,21 +407,33 @@ class Cutout(GridsBase):
         lam_current_mask[~lam_overlap_mask] = False
         self.masks.append(lam_current_mask)
 
-    def has_overlap(
-        [14 lines of the old signature and docstring not preserved in this extract]
+    def has_overlap(
+        self,
+        lats1: NDArray[Any],
+        lons1: NDArray[Any],
+        lats2: NDArray[Any],
+        lons2: NDArray[Any],
+        distance_threshold: float = 1.0,
+    ) -> bool:
+        """Check for overlapping points between two sets of latitudes and longitudes within a specified distance threshold.
+
+        Parameters
+        ----------
+        lats1 : NDArray[Any]
+            Latitude array for the first dataset.
+        lons1 : NDArray[Any]
+            Longitude array for the first dataset.
+        lats2 : NDArray[Any]
+            Latitude array for the second dataset.
+        lons2 : NDArray[Any]
+            Longitude array for the second dataset.
+        distance_threshold : float
+            Distance in degrees to consider as overlapping.
+
+        Returns
+        -------
+        bool
+            True if any points overlap within the distance threshold, otherwise False.
         """
         # Create KDTree for the first set of points
         tree = cKDTree(np.vstack((lats1, lons1)).T)
@@ -261,31 +444,35 @@ class Cutout(GridsBase):
         # Check if any distance is less than the specified threshold
         return np.any(distances < distance_threshold)
 
-    def __getitem__(self, index):
-        """
-        given index.
+    def __getitem__(self, index: FullIndex) -> NDArray[Any]:
+        """Retrieve data from the masked LAMs and global dataset based on the given index.
 
-        [3 lines of the old docstring not preserved in this extract]
+        Parameters
+        ----------
+        index : FullIndex
+            Index specifying the data to retrieve.
 
-        Returns
-        [1 line of the old docstring not preserved in this extract]
+        Returns
+        -------
+        NDArray[Any]
+            Data array from the masked datasets based on the index.
         """
         if isinstance(index, (int, slice)):
             index = (index, slice(None), slice(None), slice(None))
         return self._get_tuple(index)
 
-    def _get_tuple(self, index):
-        """Helper method that applies masks and retrieves data from each dataset
-        according to the specified index.
+    def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
+        """Helper method that applies masks and retrieves data from each dataset according to the specified index.
 
-        [2 lines of the old docstring not preserved in this extract]
+        Parameters
+        ----------
+        index : TupleIndex
+            Index specifying slices to retrieve data.
 
-        Returns
-        [2 lines of the old docstring not preserved in this extract]
+        Returns
+        -------
+        NDArray[Any]
+            Concatenated data array from all datasets based on the index.
         """
         index, changes = index_to_slices(index, self.shape)
         # Select data from each LAM
@@ -300,13 +487,15 @@ class Cutout(GridsBase):
 
         return apply_index_to_slices_changes(result, changes)
 
-    def collect_supporting_arrays(self, collected, *path):
-        """
-        dataset.
+    def collect_supporting_arrays(self, collected: List[Any], *path: Any) -> None:
+        """Collect supporting arrays, including masks for each LAM and the global dataset.
 
-        [3 lines of the old docstring not preserved in this extract]
+        Parameters
+        ----------
+        collected : List[Any]
+            List to which the supporting arrays are appended.
+        *path : Any
+            Variable length argument list specifying the paths for the masks.
         """
         # Append masks for each LAM
         for i, (lam, mask) in enumerate(zip(self.lams, self.masks)):
@@ -316,41 +505,41 @@ class Cutout(GridsBase):
         collected.append((path + ("global",), "cutout_mask", self.global_mask))
 
     @cached_property
-    def shape(self):
+    def shape(self) -> Shape:
         """Returns the shape of the Cutout, accounting for retained grid points
         across all LAMs and the global dataset.
-
-        Returns:
-            tuple: Shape of the concatenated masked datasets.
         """
         shapes = [np.sum(mask) for mask in self.masks]
         global_shape = np.sum(self.global_mask)
         total_shape = sum(shapes) + global_shape
         return tuple(self.lams[0].shape[:-1] + (int(total_shape),))
 
-    def check_same_resolution(self, d1, d2):
+    def check_same_resolution(self, d1: Dataset, d2: Dataset) -> None:
+        """Checks if the resolutions of two datasets are the same.
+
+        Parameters
+        ----------
+        d1 : Dataset
+            The first dataset.
+        d2 : Dataset
+            The second dataset.
+        """
         # Turned off because we are combining different resolutions
         pass
 
     @property
-    def grids(self):
+    def grids(self) -> TupleIndex:
         """Returns the number of grid points for each LAM and the global dataset
         after applying masks.
-
-        Returns:
-            tuple: Count of retained grid points for each dataset.
         """
         grids = [np.sum(mask) for mask in self.masks]
         grids.append(np.sum(self.global_mask))
         return tuple(grids)
 
     @property
-    def latitudes(self):
+    def latitudes(self) -> NDArray[Any]:
         """Returns the concatenated latitudes of each LAM and the global dataset
         after applying masks.
-
-        Returns:
-            np.ndarray: Concatenated latitude array for the masked datasets.
         """
         lam_latitudes = np.concatenate([lam.latitudes[mask] for lam, mask in zip(self.lams, self.masks)])
 
@@ -362,12 +551,9 @@ class Cutout(GridsBase):
         return latitudes
 
     @property
-    def longitudes(self):
+    def longitudes(self) -> NDArray[Any]:
         """Returns the concatenated longitudes of each LAM and the global dataset
         after applying masks.
-
-        Returns:
-            np.ndarray: Concatenated longitude array for the masked datasets.
         """
         lam_longitudes = np.concatenate([lam.longitudes[mask] for lam, mask in zip(self.lams, self.masks)])
 
@@ -378,19 +564,45 @@ class Cutout(GridsBase):
         longitudes = np.concatenate([lam_longitudes, self.globe.longitudes[self.global_mask]])
         return longitudes
 
-    def tree(self):
+    def tree(self) -> Node:
         """Generates a hierarchical tree structure for the `Cutout` instance and
         its associated datasets.
 
-        Returns
-        [1 line of the old docstring not preserved in this extract]
+        Returns
+        -------
+        Node
+            A `Node` object representing the `Cutout` instance as the root
             node, with each dataset in `self.datasets` represented as a child
             node.
         """
         return Node(self, [d.tree() for d in self.datasets])
 
+    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+        """Returns metadata specific to the Cutout object.
 
-def grids_factory(args, kwargs):
+        Returns
+        -------
+        Dict[str, Any]
+            Metadata specific to the Cutout object.
+        """
+        return {}
+
+
+def grids_factory(args: Tuple[Any, ...], kwargs: dict) -> Dataset:
+    """Factory function to create a Grids object.
+
+    Parameters
+    ----------
+    args : Tuple[Any, ...]
+        Positional arguments.
+    kwargs : dict
+        Keyword arguments.
+
+    Returns
+    -------
+    Dataset
+        A Grids object.
+    """
     if "ensemble" in kwargs:
         raise NotImplementedError("Cannot use both 'ensemble' and 'grids'")
 
@@ -406,7 +618,21 @@ def grids_factory(args, kwargs):
     return Grids(datasets, axis=axis)._subset(**kwargs)
 
 
-def cutout_factory(args, kwargs):
+def cutout_factory(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dataset:
+    """Factory function to create a Cutout object.
+
+    Parameters
+    ----------
+    args : Tuple[Any, ...]
+        Positional arguments.
+    kwargs : Dict[str, Any]
+        Keyword arguments.
+
+    Returns
+    -------
+    Dataset
+        A Cutout object.
+    """
     if "ensemble" in kwargs:
         raise NotImplementedError("Cannot use both 'ensemble' and 'cutout'")
 
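
Beyond the annotations and docstrings, the notable additions in this file are the forwards_subclass_metadata_specific hooks (both returning empty dicts) and the fully spelled-out has_overlap check. The latter is a plain nearest-neighbour query: build a cKDTree over one set of (lat, lon) points and test whether any point of the other set falls within distance_threshold degrees. A self-contained sketch of that logic, with made-up demo grids; note that, as in the source, the distance is Euclidean in degree space, not a great-circle distance:

import numpy as np
from scipy.spatial import cKDTree

def has_overlap(lats1, lons1, lats2, lons2, distance_threshold=1.0):
    # KD-tree over the first set of (lat, lon) points
    tree = cKDTree(np.vstack((lats1, lons1)).T)
    # Distance from every point of the second set to its nearest neighbour
    distances, _ = tree.query(np.vstack((lats2, lons2)).T, k=1)
    # Overlap if any point is closer than the threshold
    return bool(np.any(distances < distance_threshold))

# A small LAM patch against a coarse 5-degree global grid: they overlap
lam_lat, lam_lon = np.meshgrid(np.linspace(40, 50, 11), np.linspace(0, 10, 11))
glo_lat, glo_lon = np.meshgrid(np.arange(-90, 91, 5), np.arange(0, 360, 5))
print(has_overlap(lam_lat.ravel(), lam_lon.ravel(), glo_lat.ravel(), glo_lon.ravel()))  # True

Judging by the kwargs checks in grids_factory and cutout_factory, these classes back the grids and cutout options of open_dataset, neither of which can be combined with ensemble.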