anemoi-datasets 0.5.26__py3-none-any.whl → 0.5.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/__init__.py +1 -2
- anemoi/datasets/_version.py +16 -3
- anemoi/datasets/commands/check.py +1 -1
- anemoi/datasets/commands/copy.py +1 -2
- anemoi/datasets/commands/create.py +1 -1
- anemoi/datasets/commands/inspect.py +27 -35
- anemoi/datasets/commands/recipe/__init__.py +93 -0
- anemoi/datasets/commands/recipe/format.py +55 -0
- anemoi/datasets/commands/recipe/migrate.py +555 -0
- anemoi/datasets/commands/validate.py +59 -0
- anemoi/datasets/compute/recentre.py +3 -6
- anemoi/datasets/create/__init__.py +64 -26
- anemoi/datasets/create/check.py +10 -12
- anemoi/datasets/create/chunks.py +1 -2
- anemoi/datasets/create/config.py +5 -6
- anemoi/datasets/create/input/__init__.py +44 -65
- anemoi/datasets/create/input/action.py +296 -238
- anemoi/datasets/create/input/context/__init__.py +71 -0
- anemoi/datasets/create/input/context/field.py +54 -0
- anemoi/datasets/create/input/data_sources.py +7 -9
- anemoi/datasets/create/input/misc.py +2 -75
- anemoi/datasets/create/input/repeated_dates.py +11 -130
- anemoi/datasets/{utils → create/input/result}/__init__.py +10 -1
- anemoi/datasets/create/input/{result.py → result/field.py} +36 -120
- anemoi/datasets/create/input/trace.py +1 -1
- anemoi/datasets/create/patch.py +1 -2
- anemoi/datasets/create/persistent.py +3 -5
- anemoi/datasets/create/size.py +1 -3
- anemoi/datasets/create/sources/accumulations.py +120 -145
- anemoi/datasets/create/sources/accumulations2.py +20 -53
- anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
- anemoi/datasets/create/sources/constants.py +39 -40
- anemoi/datasets/create/sources/empty.py +22 -19
- anemoi/datasets/create/sources/fdb.py +133 -0
- anemoi/datasets/create/sources/forcings.py +29 -29
- anemoi/datasets/create/sources/grib.py +94 -78
- anemoi/datasets/create/sources/grib_index.py +57 -55
- anemoi/datasets/create/sources/hindcasts.py +57 -59
- anemoi/datasets/create/sources/legacy.py +10 -62
- anemoi/datasets/create/sources/mars.py +121 -149
- anemoi/datasets/create/sources/netcdf.py +28 -25
- anemoi/datasets/create/sources/opendap.py +28 -26
- anemoi/datasets/create/sources/patterns.py +4 -6
- anemoi/datasets/create/sources/recentre.py +46 -48
- anemoi/datasets/create/sources/repeated_dates.py +44 -0
- anemoi/datasets/create/sources/source.py +26 -51
- anemoi/datasets/create/sources/tendencies.py +68 -98
- anemoi/datasets/create/sources/xarray.py +4 -6
- anemoi/datasets/create/sources/xarray_support/__init__.py +40 -36
- anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -12
- anemoi/datasets/create/sources/xarray_support/field.py +20 -16
- anemoi/datasets/create/sources/xarray_support/fieldlist.py +11 -15
- anemoi/datasets/create/sources/xarray_support/flavour.py +42 -42
- anemoi/datasets/create/sources/xarray_support/grid.py +15 -9
- anemoi/datasets/create/sources/xarray_support/metadata.py +19 -128
- anemoi/datasets/create/sources/xarray_support/patch.py +4 -6
- anemoi/datasets/create/sources/xarray_support/time.py +10 -13
- anemoi/datasets/create/sources/xarray_support/variable.py +21 -21
- anemoi/datasets/create/sources/xarray_zarr.py +28 -25
- anemoi/datasets/create/sources/zenodo.py +43 -41
- anemoi/datasets/create/statistics/__init__.py +3 -6
- anemoi/datasets/create/testing.py +4 -0
- anemoi/datasets/create/typing.py +1 -2
- anemoi/datasets/create/utils.py +0 -43
- anemoi/datasets/create/zarr.py +7 -2
- anemoi/datasets/data/__init__.py +15 -6
- anemoi/datasets/data/complement.py +7 -12
- anemoi/datasets/data/concat.py +5 -8
- anemoi/datasets/data/dataset.py +48 -47
- anemoi/datasets/data/debug.py +7 -9
- anemoi/datasets/data/ensemble.py +4 -6
- anemoi/datasets/data/fill_missing.py +7 -10
- anemoi/datasets/data/forwards.py +22 -26
- anemoi/datasets/data/grids.py +12 -168
- anemoi/datasets/data/indexing.py +9 -12
- anemoi/datasets/data/interpolate.py +7 -15
- anemoi/datasets/data/join.py +8 -12
- anemoi/datasets/data/masked.py +6 -11
- anemoi/datasets/data/merge.py +5 -9
- anemoi/datasets/data/misc.py +41 -45
- anemoi/datasets/data/missing.py +11 -16
- anemoi/datasets/data/observations/__init__.py +8 -14
- anemoi/datasets/data/padded.py +3 -5
- anemoi/datasets/data/records/backends/__init__.py +2 -2
- anemoi/datasets/data/rescale.py +5 -12
- anemoi/datasets/data/rolling_average.py +141 -0
- anemoi/datasets/data/select.py +13 -16
- anemoi/datasets/data/statistics.py +4 -7
- anemoi/datasets/data/stores.py +22 -29
- anemoi/datasets/data/subset.py +8 -11
- anemoi/datasets/data/unchecked.py +7 -11
- anemoi/datasets/data/xy.py +25 -21
- anemoi/datasets/dates/__init__.py +15 -18
- anemoi/datasets/dates/groups.py +7 -10
- anemoi/datasets/dumper.py +76 -0
- anemoi/datasets/grids.py +4 -185
- anemoi/datasets/schemas/recipe.json +131 -0
- anemoi/datasets/testing.py +93 -7
- anemoi/datasets/validate.py +598 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/METADATA +7 -4
- anemoi_datasets-0.5.28.dist-info/RECORD +134 -0
- anemoi/datasets/create/filter.py +0 -48
- anemoi/datasets/create/input/concat.py +0 -164
- anemoi/datasets/create/input/context.py +0 -89
- anemoi/datasets/create/input/empty.py +0 -54
- anemoi/datasets/create/input/filter.py +0 -118
- anemoi/datasets/create/input/function.py +0 -233
- anemoi/datasets/create/input/join.py +0 -130
- anemoi/datasets/create/input/pipe.py +0 -66
- anemoi/datasets/create/input/step.py +0 -177
- anemoi/datasets/create/input/template.py +0 -162
- anemoi_datasets-0.5.26.dist-info/RECORD +0 -131
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/WHEEL +0 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/licenses/LICENSE +0 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/top_level.txt +0 -0
anemoi/datasets/data/debug.py
CHANGED
|
@@ -11,12 +11,10 @@
|
|
|
11
11
|
import logging
|
|
12
12
|
import os
|
|
13
13
|
import textwrap
|
|
14
|
+
from collections.abc import Callable
|
|
14
15
|
from functools import wraps
|
|
15
16
|
from typing import TYPE_CHECKING
|
|
16
17
|
from typing import Any
|
|
17
|
-
from typing import Callable
|
|
18
|
-
from typing import List
|
|
19
|
-
from typing import Optional
|
|
20
18
|
|
|
21
19
|
from anemoi.utils.text import Tree
|
|
22
20
|
from numpy.typing import NDArray
|
|
@@ -56,7 +54,7 @@ def css(name: str) -> str:
|
|
|
56
54
|
class Node:
|
|
57
55
|
"""A class to represent a node in a dataset tree."""
|
|
58
56
|
|
|
59
|
-
def __init__(self, dataset: "Dataset", kids:
|
|
57
|
+
def __init__(self, dataset: "Dataset", kids: list[Any], **kwargs: Any) -> None:
|
|
60
58
|
"""Initializes a Node object.
|
|
61
59
|
|
|
62
60
|
Parameters
|
|
@@ -72,7 +70,7 @@ class Node:
|
|
|
72
70
|
self.kids = kids
|
|
73
71
|
self.kwargs = kwargs
|
|
74
72
|
|
|
75
|
-
def _put(self, indent: int, result:
|
|
73
|
+
def _put(self, indent: int, result: list[str]) -> None:
|
|
76
74
|
"""Helper method to add the node representation to the result list.
|
|
77
75
|
|
|
78
76
|
Parameters
|
|
@@ -103,11 +101,11 @@ class Node:
|
|
|
103
101
|
str
|
|
104
102
|
String representation of the node.
|
|
105
103
|
"""
|
|
106
|
-
result:
|
|
104
|
+
result: list[str] = []
|
|
107
105
|
self._put(0, result)
|
|
108
106
|
return "\n".join(result)
|
|
109
107
|
|
|
110
|
-
def graph(self, digraph:
|
|
108
|
+
def graph(self, digraph: list[str], nodes: dict) -> None:
|
|
111
109
|
"""Generates a graph representation of the node.
|
|
112
110
|
|
|
113
111
|
Parameters
|
|
@@ -170,7 +168,7 @@ class Node:
|
|
|
170
168
|
digraph.append("}")
|
|
171
169
|
return "\n".join(digraph)
|
|
172
170
|
|
|
173
|
-
def _html(self, indent: str, rows:
|
|
171
|
+
def _html(self, indent: str, rows: list[list[str]]) -> None:
|
|
174
172
|
"""Helper method to add the node representation to the HTML rows.
|
|
175
173
|
|
|
176
174
|
Parameters
|
|
@@ -273,7 +271,7 @@ class Node:
|
|
|
273
271
|
class Source:
|
|
274
272
|
"""A class used to follow the provenance of a data point."""
|
|
275
273
|
|
|
276
|
-
def __init__(self, dataset: Any, index: int, source:
|
|
274
|
+
def __init__(self, dataset: Any, index: int, source: Any | None = None, info: Any | None = None) -> None:
|
|
277
275
|
"""Initializes a Source object.
|
|
278
276
|
|
|
279
277
|
Parameters
|
anemoi/datasets/data/ensemble.py
CHANGED
|
@@ -10,8 +10,6 @@
|
|
|
10
10
|
|
|
11
11
|
import logging
|
|
12
12
|
from typing import Any
|
|
13
|
-
from typing import Dict
|
|
14
|
-
from typing import Tuple
|
|
15
13
|
|
|
16
14
|
import numpy as np
|
|
17
15
|
from numpy.typing import NDArray
|
|
@@ -105,7 +103,7 @@ class Number(Forwards):
|
|
|
105
103
|
"""
|
|
106
104
|
return Node(self, [self.forward.tree()], numbers=[n + 1 for n in self.members])
|
|
107
105
|
|
|
108
|
-
def metadata_specific(self, **kwargs: Any) ->
|
|
106
|
+
def metadata_specific(self, **kwargs: Any) -> dict[str, Any]:
|
|
109
107
|
"""Returns metadata specific to the Number object.
|
|
110
108
|
|
|
111
109
|
Parameters
|
|
@@ -122,7 +120,7 @@ class Number(Forwards):
|
|
|
122
120
|
"numbers": [n + 1 for n in self.members],
|
|
123
121
|
}
|
|
124
122
|
|
|
125
|
-
def forwards_subclass_metadata_specific(self) ->
|
|
123
|
+
def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
|
|
126
124
|
"""Returns metadata specific to the Number object."""
|
|
127
125
|
return {}
|
|
128
126
|
|
|
@@ -140,7 +138,7 @@ class Ensemble(GivenAxis):
|
|
|
140
138
|
"""
|
|
141
139
|
return Node(self, [d.tree() for d in self.datasets])
|
|
142
140
|
|
|
143
|
-
def forwards_subclass_metadata_specific(self) ->
|
|
141
|
+
def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
|
|
144
142
|
"""Get the metadata specific to the forwards subclass.
|
|
145
143
|
|
|
146
144
|
Returns:
|
|
@@ -149,7 +147,7 @@ class Ensemble(GivenAxis):
|
|
|
149
147
|
return {}
|
|
150
148
|
|
|
151
149
|
|
|
152
|
-
def ensemble_factory(args:
|
|
150
|
+
def ensemble_factory(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Ensemble:
|
|
153
151
|
"""Factory function to create an Ensemble object.
|
|
154
152
|
|
|
155
153
|
Parameters
|
|
@@ -10,9 +10,6 @@
|
|
|
10
10
|
|
|
11
11
|
import logging
|
|
12
12
|
from typing import Any
|
|
13
|
-
from typing import Dict
|
|
14
|
-
from typing import Optional
|
|
15
|
-
from typing import Set
|
|
16
13
|
|
|
17
14
|
import numpy as np
|
|
18
15
|
from numpy.typing import NDArray
|
|
@@ -46,7 +43,7 @@ class MissingDatesFill(Forwards):
|
|
|
46
43
|
"""
|
|
47
44
|
super().__init__(dataset)
|
|
48
45
|
self._missing = set(dataset.missing)
|
|
49
|
-
self._warnings:
|
|
46
|
+
self._warnings: set[int] = set()
|
|
50
47
|
|
|
51
48
|
@debug_indexing
|
|
52
49
|
@expand_list_indexing
|
|
@@ -84,7 +81,7 @@ class MissingDatesFill(Forwards):
|
|
|
84
81
|
return np.stack([self[i] for i in range(*s.indices(self._len))])
|
|
85
82
|
|
|
86
83
|
@property
|
|
87
|
-
def missing(self) ->
|
|
84
|
+
def missing(self) -> set[int]:
|
|
88
85
|
"""Get the set of missing dates."""
|
|
89
86
|
return set()
|
|
90
87
|
|
|
@@ -153,7 +150,7 @@ class MissingDatesClosest(MissingDatesFill):
|
|
|
153
150
|
self.closest = closest
|
|
154
151
|
self._closest = {}
|
|
155
152
|
|
|
156
|
-
def _fill_missing(self, n: int, a:
|
|
153
|
+
def _fill_missing(self, n: int, a: int | None, b: int | None) -> NDArray[Any]:
|
|
157
154
|
"""Fill the missing date at the given index.
|
|
158
155
|
|
|
159
156
|
Parameters
|
|
@@ -189,7 +186,7 @@ class MissingDatesClosest(MissingDatesFill):
|
|
|
189
186
|
|
|
190
187
|
return self.forward[self._closest[n]]
|
|
191
188
|
|
|
192
|
-
def forwards_subclass_metadata_specific(self) ->
|
|
189
|
+
def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
|
|
193
190
|
"""Get metadata specific to the subclass.
|
|
194
191
|
|
|
195
192
|
Returns
|
|
@@ -224,7 +221,7 @@ class MissingDatesInterpolate(MissingDatesFill):
|
|
|
224
221
|
super().__init__(dataset)
|
|
225
222
|
self._alpha = {}
|
|
226
223
|
|
|
227
|
-
def _fill_missing(self, n: int, a:
|
|
224
|
+
def _fill_missing(self, n: int, a: int | None, b: int | None) -> NDArray[Any]:
|
|
228
225
|
"""Fill the missing date at the given index using interpolation.
|
|
229
226
|
|
|
230
227
|
Parameters
|
|
@@ -264,7 +261,7 @@ class MissingDatesInterpolate(MissingDatesFill):
|
|
|
264
261
|
alpha = self._alpha[n]
|
|
265
262
|
return self.forward[a] * (1 - alpha) + self.forward[b] * alpha
|
|
266
263
|
|
|
267
|
-
def forwards_subclass_metadata_specific(self) ->
|
|
264
|
+
def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
|
|
268
265
|
"""Get metadata specific to the subclass.
|
|
269
266
|
|
|
270
267
|
Returns
|
|
@@ -285,7 +282,7 @@ class MissingDatesInterpolate(MissingDatesFill):
|
|
|
285
282
|
return Node(self, [self.forward.tree()])
|
|
286
283
|
|
|
287
284
|
|
|
288
|
-
def fill_missing_dates_factory(dataset: Any, method: str, kwargs:
|
|
285
|
+
def fill_missing_dates_factory(dataset: Any, method: str, kwargs: dict[str, Any]) -> Dataset:
|
|
289
286
|
"""Factory function to create an instance of a class to fill missing dates.
|
|
290
287
|
|
|
291
288
|
Parameters
|
anemoi/datasets/data/forwards.py
CHANGED
|
@@ -14,10 +14,6 @@ import warnings
|
|
|
14
14
|
from abc import abstractmethod
|
|
15
15
|
from functools import cached_property
|
|
16
16
|
from typing import Any
|
|
17
|
-
from typing import Dict
|
|
18
|
-
from typing import List
|
|
19
|
-
from typing import Optional
|
|
20
|
-
from typing import Set
|
|
21
17
|
|
|
22
18
|
import numpy as np
|
|
23
19
|
from numpy.typing import NDArray
|
|
@@ -75,7 +71,7 @@ class Forwards(Dataset):
|
|
|
75
71
|
return self.forward[n]
|
|
76
72
|
|
|
77
73
|
@property
|
|
78
|
-
def name(self) ->
|
|
74
|
+
def name(self) -> str | None:
|
|
79
75
|
"""Returns the name of the forward dataset."""
|
|
80
76
|
if self._name is not None:
|
|
81
77
|
return self._name
|
|
@@ -112,26 +108,26 @@ class Forwards(Dataset):
|
|
|
112
108
|
return self.forward.longitudes
|
|
113
109
|
|
|
114
110
|
@property
|
|
115
|
-
def name_to_index(self) ->
|
|
111
|
+
def name_to_index(self) -> dict[str, int]:
|
|
116
112
|
"""Returns a dictionary mapping variable names to their indices."""
|
|
117
113
|
return self.forward.name_to_index
|
|
118
114
|
|
|
119
115
|
@property
|
|
120
|
-
def variables(self) ->
|
|
116
|
+
def variables(self) -> list[str]:
|
|
121
117
|
"""Returns the variables of the forward dataset."""
|
|
122
118
|
return self.forward.variables
|
|
123
119
|
|
|
124
120
|
@property
|
|
125
|
-
def variables_metadata(self) ->
|
|
121
|
+
def variables_metadata(self) -> dict[str, Any]:
|
|
126
122
|
"""Returns the metadata of the variables in the forward dataset."""
|
|
127
123
|
return self.forward.variables_metadata
|
|
128
124
|
|
|
129
125
|
@property
|
|
130
|
-
def statistics(self) ->
|
|
126
|
+
def statistics(self) -> dict[str, NDArray[Any]]:
|
|
131
127
|
"""Returns the statistics of the forward dataset."""
|
|
132
128
|
return self.forward.statistics
|
|
133
129
|
|
|
134
|
-
def statistics_tendencies(self, delta:
|
|
130
|
+
def statistics_tendencies(self, delta: datetime.timedelta | None = None) -> dict[str, NDArray[Any]]:
|
|
135
131
|
"""Returns the statistics tendencies of the forward dataset.
|
|
136
132
|
|
|
137
133
|
Parameters
|
|
@@ -159,7 +155,7 @@ class Forwards(Dataset):
|
|
|
159
155
|
return self.forward.dtype
|
|
160
156
|
|
|
161
157
|
@property
|
|
162
|
-
def missing(self) ->
|
|
158
|
+
def missing(self) -> set[int]:
|
|
163
159
|
"""Returns the missing data information of the forward dataset."""
|
|
164
160
|
return self.forward.missing
|
|
165
161
|
|
|
@@ -168,7 +164,7 @@ class Forwards(Dataset):
|
|
|
168
164
|
"""Returns the grids of the forward dataset."""
|
|
169
165
|
return self.forward.grids
|
|
170
166
|
|
|
171
|
-
def metadata_specific(self, **kwargs: Any) ->
|
|
167
|
+
def metadata_specific(self, **kwargs: Any) -> dict[str, Any]:
|
|
172
168
|
"""Returns metadata specific to the forward dataset.
|
|
173
169
|
|
|
174
170
|
Parameters
|
|
@@ -187,7 +183,7 @@ class Forwards(Dataset):
|
|
|
187
183
|
**kwargs,
|
|
188
184
|
)
|
|
189
185
|
|
|
190
|
-
def collect_supporting_arrays(self, collected:
|
|
186
|
+
def collect_supporting_arrays(self, collected: list[Any], *path: Any) -> None:
|
|
191
187
|
"""Collects supporting arrays from the forward dataset.
|
|
192
188
|
|
|
193
189
|
Parameters
|
|
@@ -199,7 +195,7 @@ class Forwards(Dataset):
|
|
|
199
195
|
"""
|
|
200
196
|
self.forward.collect_supporting_arrays(collected, *path)
|
|
201
197
|
|
|
202
|
-
def collect_input_sources(self, collected:
|
|
198
|
+
def collect_input_sources(self, collected: list[Any]) -> None:
|
|
203
199
|
"""Collects input sources from the forward dataset.
|
|
204
200
|
|
|
205
201
|
Parameters
|
|
@@ -225,11 +221,11 @@ class Forwards(Dataset):
|
|
|
225
221
|
return self.forward.source(index)
|
|
226
222
|
|
|
227
223
|
@abstractmethod
|
|
228
|
-
def forwards_subclass_metadata_specific(self) ->
|
|
224
|
+
def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
|
|
229
225
|
"""Returns metadata specific to the subclass."""
|
|
230
226
|
pass
|
|
231
227
|
|
|
232
|
-
def get_dataset_names(self, names:
|
|
228
|
+
def get_dataset_names(self, names: set[str]) -> None:
|
|
233
229
|
"""Collects the names of the datasets.
|
|
234
230
|
|
|
235
231
|
Parameters
|
|
@@ -240,7 +236,7 @@ class Forwards(Dataset):
|
|
|
240
236
|
self.forward.get_dataset_names(names)
|
|
241
237
|
|
|
242
238
|
@property
|
|
243
|
-
def constant_fields(self) ->
|
|
239
|
+
def constant_fields(self) -> list[str]:
|
|
244
240
|
"""Returns the constant fields of the forward dataset."""
|
|
245
241
|
return self.forward.constant_fields
|
|
246
242
|
|
|
@@ -248,7 +244,7 @@ class Forwards(Dataset):
|
|
|
248
244
|
class Combined(Forwards):
|
|
249
245
|
"""A class to combine multiple datasets into a single dataset."""
|
|
250
246
|
|
|
251
|
-
def __init__(self, datasets:
|
|
247
|
+
def __init__(self, datasets: list[Dataset]) -> None:
|
|
252
248
|
"""Initializes a Combined object.
|
|
253
249
|
|
|
254
250
|
Parameters
|
|
@@ -466,7 +462,7 @@ class Combined(Forwards):
|
|
|
466
462
|
self.check_same_variables(d1, d2)
|
|
467
463
|
self.check_same_dates(d1, d2)
|
|
468
464
|
|
|
469
|
-
def provenance(self) ->
|
|
465
|
+
def provenance(self) -> list[Any]:
|
|
470
466
|
"""Returns the provenance of the combined datasets.
|
|
471
467
|
|
|
472
468
|
Returns
|
|
@@ -487,7 +483,7 @@ class Combined(Forwards):
|
|
|
487
483
|
lst = ", ".join(repr(d) for d in self.datasets)
|
|
488
484
|
return f"{self.__class__.__name__}({lst})"
|
|
489
485
|
|
|
490
|
-
def metadata_specific(self, **kwargs: Any) ->
|
|
486
|
+
def metadata_specific(self, **kwargs: Any) -> dict[str, Any]:
|
|
491
487
|
"""Returns metadata specific to the combined datasets.
|
|
492
488
|
|
|
493
489
|
Parameters
|
|
@@ -508,7 +504,7 @@ class Combined(Forwards):
|
|
|
508
504
|
**kwargs,
|
|
509
505
|
)
|
|
510
506
|
|
|
511
|
-
def collect_supporting_arrays(self, collected:
|
|
507
|
+
def collect_supporting_arrays(self, collected: list[Any], *path: Any) -> None:
|
|
512
508
|
"""Collects supporting arrays from the combined datasets.
|
|
513
509
|
|
|
514
510
|
Parameters
|
|
@@ -524,7 +520,7 @@ class Combined(Forwards):
|
|
|
524
520
|
d.collect_supporting_arrays(collected, *path, name)
|
|
525
521
|
|
|
526
522
|
@property
|
|
527
|
-
def missing(self) ->
|
|
523
|
+
def missing(self) -> set[int]:
|
|
528
524
|
"""Returns the missing data information of the combined datasets.
|
|
529
525
|
|
|
530
526
|
Raises
|
|
@@ -534,7 +530,7 @@ class Combined(Forwards):
|
|
|
534
530
|
"""
|
|
535
531
|
raise NotImplementedError("missing() not implemented for Combined")
|
|
536
532
|
|
|
537
|
-
def get_dataset_names(self, names:
|
|
533
|
+
def get_dataset_names(self, names: set[str]) -> None:
|
|
538
534
|
"""Collects the names of the combined datasets.
|
|
539
535
|
|
|
540
536
|
Parameters
|
|
@@ -549,7 +545,7 @@ class Combined(Forwards):
|
|
|
549
545
|
class GivenAxis(Combined):
|
|
550
546
|
"""A class to combine datasets along a given axis."""
|
|
551
547
|
|
|
552
|
-
def __init__(self, datasets:
|
|
548
|
+
def __init__(self, datasets: list[Any], axis: int) -> None:
|
|
553
549
|
"""Initializes a GivenAxis object.
|
|
554
550
|
|
|
555
551
|
Parameters
|
|
@@ -656,10 +652,10 @@ class GivenAxis(Combined):
|
|
|
656
652
|
return np.concatenate([d[n] for d in self.datasets], axis=self.axis - 1)
|
|
657
653
|
|
|
658
654
|
@cached_property
|
|
659
|
-
def missing(self) ->
|
|
655
|
+
def missing(self) -> set[int]:
|
|
660
656
|
"""Returns the missing data information of the combined dataset along the given axis."""
|
|
661
657
|
offset = 0
|
|
662
|
-
result:
|
|
658
|
+
result: set[int] = set()
|
|
663
659
|
for d in self.datasets:
|
|
664
660
|
result.update(offset + m for m in d.missing)
|
|
665
661
|
if self.axis == 0: # Advance if axis is time
|
anemoi/datasets/data/grids.py
CHANGED
|
@@ -11,10 +11,6 @@
|
|
|
11
11
|
import logging
|
|
12
12
|
from functools import cached_property
|
|
13
13
|
from typing import Any
|
|
14
|
-
from typing import Dict
|
|
15
|
-
from typing import List
|
|
16
|
-
from typing import Optional
|
|
17
|
-
from typing import Tuple
|
|
18
14
|
|
|
19
15
|
import numpy as np
|
|
20
16
|
from numpy.typing import NDArray
|
|
@@ -25,171 +21,19 @@ from .dataset import FullIndex
|
|
|
25
21
|
from .dataset import Shape
|
|
26
22
|
from .dataset import TupleIndex
|
|
27
23
|
from .debug import Node
|
|
28
|
-
from .debug import debug_indexing
|
|
29
|
-
from .forwards import Combined
|
|
30
24
|
from .forwards import GivenAxis
|
|
31
25
|
from .indexing import apply_index_to_slices_changes
|
|
32
|
-
from .indexing import expand_list_indexing
|
|
33
26
|
from .indexing import index_to_slices
|
|
34
|
-
from .indexing import length_to_slices
|
|
35
|
-
from .indexing import update_tuple
|
|
36
27
|
from .misc import _auto_adjust
|
|
37
28
|
from .misc import _open
|
|
38
29
|
|
|
39
30
|
LOG = logging.getLogger(__name__)
|
|
40
31
|
|
|
41
32
|
|
|
42
|
-
class Concat(Combined):
|
|
43
|
-
"""A class to represent concatenated datasets."""
|
|
44
|
-
|
|
45
|
-
def __len__(self) -> int:
|
|
46
|
-
"""Returns the total length of the concatenated datasets.
|
|
47
|
-
|
|
48
|
-
Returns
|
|
49
|
-
-------
|
|
50
|
-
int
|
|
51
|
-
Total length of the concatenated datasets.
|
|
52
|
-
"""
|
|
53
|
-
return sum(len(i) for i in self.datasets)
|
|
54
|
-
|
|
55
|
-
@debug_indexing
|
|
56
|
-
@expand_list_indexing
|
|
57
|
-
def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
|
|
58
|
-
"""Retrieves a tuple of data from the concatenated datasets based on the given index.
|
|
59
|
-
|
|
60
|
-
Parameters
|
|
61
|
-
----------
|
|
62
|
-
index : TupleIndex
|
|
63
|
-
Index specifying the data to retrieve.
|
|
64
|
-
|
|
65
|
-
Returns
|
|
66
|
-
-------
|
|
67
|
-
NDArray[Any]
|
|
68
|
-
Concatenated data array from the specified index.
|
|
69
|
-
"""
|
|
70
|
-
index, changes = index_to_slices(index, self.shape)
|
|
71
|
-
# print(index, changes)
|
|
72
|
-
lengths = [d.shape[0] for d in self.datasets]
|
|
73
|
-
slices = length_to_slices(index[0], lengths)
|
|
74
|
-
# print("slies", slices)
|
|
75
|
-
result = [d[update_tuple(index, 0, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None]
|
|
76
|
-
result = np.concatenate(result, axis=0)
|
|
77
|
-
return apply_index_to_slices_changes(result, changes)
|
|
78
|
-
|
|
79
|
-
@debug_indexing
|
|
80
|
-
def __getitem__(self, n: FullIndex) -> NDArray[Any]:
|
|
81
|
-
"""Retrieves data from the concatenated datasets based on the given index.
|
|
82
|
-
|
|
83
|
-
Parameters
|
|
84
|
-
----------
|
|
85
|
-
n : FullIndex
|
|
86
|
-
Index specifying the data to retrieve.
|
|
87
|
-
|
|
88
|
-
Returns
|
|
89
|
-
-------
|
|
90
|
-
NDArray[Any]
|
|
91
|
-
Data array from the concatenated datasets based on the index.
|
|
92
|
-
"""
|
|
93
|
-
if isinstance(n, tuple):
|
|
94
|
-
return self._get_tuple(n)
|
|
95
|
-
|
|
96
|
-
if isinstance(n, slice):
|
|
97
|
-
return self._get_slice(n)
|
|
98
|
-
|
|
99
|
-
# TODO: optimize
|
|
100
|
-
k = 0
|
|
101
|
-
while n >= self.datasets[k]._len:
|
|
102
|
-
n -= self.datasets[k]._len
|
|
103
|
-
k += 1
|
|
104
|
-
return self.datasets[k][n]
|
|
105
|
-
|
|
106
|
-
@debug_indexing
|
|
107
|
-
def _get_slice(self, s: slice) -> NDArray[Any]:
|
|
108
|
-
"""Retrieves a slice of data from the concatenated datasets.
|
|
109
|
-
|
|
110
|
-
Parameters
|
|
111
|
-
----------
|
|
112
|
-
s : slice
|
|
113
|
-
Slice object specifying the range of data to retrieve.
|
|
114
|
-
|
|
115
|
-
Returns
|
|
116
|
-
-------
|
|
117
|
-
NDArray[Any]
|
|
118
|
-
Concatenated data array from the specified slice.
|
|
119
|
-
"""
|
|
120
|
-
result = []
|
|
121
|
-
|
|
122
|
-
lengths = [d.shape[0] for d in self.datasets]
|
|
123
|
-
slices = length_to_slices(s, lengths)
|
|
124
|
-
|
|
125
|
-
result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]
|
|
126
|
-
|
|
127
|
-
return np.concatenate(result)
|
|
128
|
-
|
|
129
|
-
def check_compatibility(self, d1: Dataset, d2: Dataset) -> None:
|
|
130
|
-
"""Check the compatibility of two datasets for concatenation.
|
|
131
|
-
|
|
132
|
-
Parameters
|
|
133
|
-
----------
|
|
134
|
-
d1 : Dataset
|
|
135
|
-
The first dataset.
|
|
136
|
-
d2 : Dataset
|
|
137
|
-
The second dataset.
|
|
138
|
-
"""
|
|
139
|
-
super().check_compatibility(d1, d2)
|
|
140
|
-
self.check_same_sub_shapes(d1, d2, drop_axis=0)
|
|
141
|
-
|
|
142
|
-
def check_same_lengths(self, d1: Dataset, d2: Dataset) -> None:
|
|
143
|
-
"""Check if the lengths of two datasets are the same.
|
|
144
|
-
|
|
145
|
-
Parameters
|
|
146
|
-
----------
|
|
147
|
-
d1 : Dataset
|
|
148
|
-
The first dataset.
|
|
149
|
-
d2 : Dataset
|
|
150
|
-
The second dataset.
|
|
151
|
-
"""
|
|
152
|
-
# Turned off because we are concatenating along the first axis
|
|
153
|
-
pass
|
|
154
|
-
|
|
155
|
-
def check_same_dates(self, d1: Dataset, d2: Dataset) -> None:
|
|
156
|
-
"""Check if the dates of two datasets are the same.
|
|
157
|
-
|
|
158
|
-
Parameters
|
|
159
|
-
----------
|
|
160
|
-
d1 : Dataset
|
|
161
|
-
The first dataset.
|
|
162
|
-
d2 : Dataset
|
|
163
|
-
The second dataset.
|
|
164
|
-
"""
|
|
165
|
-
# Turned off because we are concatenating along the dates axis
|
|
166
|
-
pass
|
|
167
|
-
|
|
168
|
-
@property
|
|
169
|
-
def dates(self) -> NDArray[np.datetime64]:
|
|
170
|
-
"""Returns the concatenated dates of all datasets."""
|
|
171
|
-
return np.concatenate([d.dates for d in self.datasets])
|
|
172
|
-
|
|
173
|
-
@property
|
|
174
|
-
def shape(self) -> Shape:
|
|
175
|
-
"""Returns the shape of the concatenated datasets."""
|
|
176
|
-
return (len(self),) + self.datasets[0].shape[1:]
|
|
177
|
-
|
|
178
|
-
def tree(self) -> Node:
|
|
179
|
-
"""Generates a hierarchical tree structure for the concatenated datasets.
|
|
180
|
-
|
|
181
|
-
Returns
|
|
182
|
-
-------
|
|
183
|
-
Node
|
|
184
|
-
A Node object representing the concatenated datasets.
|
|
185
|
-
"""
|
|
186
|
-
return Node(self, [d.tree() for d in self.datasets])
|
|
187
|
-
|
|
188
|
-
|
|
189
33
|
class GridsBase(GivenAxis):
|
|
190
34
|
"""A base class for handling grids in datasets."""
|
|
191
35
|
|
|
192
|
-
def __init__(self, datasets:
|
|
36
|
+
def __init__(self, datasets: list[Any], axis: int) -> None:
|
|
193
37
|
"""Initializes a GridsBase object.
|
|
194
38
|
|
|
195
39
|
Parameters
|
|
@@ -229,7 +73,7 @@ class GridsBase(GivenAxis):
|
|
|
229
73
|
# We don't check the resolution, because we want to be able to combine
|
|
230
74
|
pass
|
|
231
75
|
|
|
232
|
-
def metadata_specific(self, **kwargs: Any) ->
|
|
76
|
+
def metadata_specific(self, **kwargs: Any) -> dict[str, Any]:
|
|
233
77
|
"""Returns metadata specific to the GridsBase object.
|
|
234
78
|
|
|
235
79
|
Parameters
|
|
@@ -246,7 +90,7 @@ class GridsBase(GivenAxis):
|
|
|
246
90
|
multi_grids=True,
|
|
247
91
|
)
|
|
248
92
|
|
|
249
|
-
def collect_input_sources(self, collected:
|
|
93
|
+
def collect_input_sources(self, collected: list[Any]) -> None:
|
|
250
94
|
"""Collects input sources from the datasets.
|
|
251
95
|
|
|
252
96
|
Parameters
|
|
@@ -275,7 +119,7 @@ class Grids(GridsBase):
|
|
|
275
119
|
return np.concatenate([d.longitudes for d in self.datasets])
|
|
276
120
|
|
|
277
121
|
@property
|
|
278
|
-
def grids(self) ->
|
|
122
|
+
def grids(self) -> tuple[Any, ...]:
|
|
279
123
|
"""Returns the grids of all datasets."""
|
|
280
124
|
result = []
|
|
281
125
|
for d in self.datasets:
|
|
@@ -292,7 +136,7 @@ class Grids(GridsBase):
|
|
|
292
136
|
"""
|
|
293
137
|
return Node(self, [d.tree() for d in self.datasets], mode="concat")
|
|
294
138
|
|
|
295
|
-
def forwards_subclass_metadata_specific(self) ->
|
|
139
|
+
def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
|
|
296
140
|
"""Get the metadata specific to the forwards subclass.
|
|
297
141
|
|
|
298
142
|
Returns:
|
|
@@ -306,12 +150,12 @@ class Cutout(GridsBase):
|
|
|
306
150
|
|
|
307
151
|
def __init__(
|
|
308
152
|
self,
|
|
309
|
-
datasets:
|
|
153
|
+
datasets: list[Any],
|
|
310
154
|
axis: int = 3,
|
|
311
155
|
cropping_distance: float = 2.0,
|
|
312
156
|
neighbours: int = 5,
|
|
313
|
-
min_distance_km:
|
|
314
|
-
plot:
|
|
157
|
+
min_distance_km: float | None = None,
|
|
158
|
+
plot: bool | None = None,
|
|
315
159
|
) -> None:
|
|
316
160
|
"""Initializes a Cutout object for hierarchical management of Limited Area
|
|
317
161
|
Models (LAMs) and a global dataset, handling overlapping regions.
|
|
@@ -487,7 +331,7 @@ class Cutout(GridsBase):
|
|
|
487
331
|
|
|
488
332
|
return apply_index_to_slices_changes(result, changes)
|
|
489
333
|
|
|
490
|
-
def collect_supporting_arrays(self, collected:
|
|
334
|
+
def collect_supporting_arrays(self, collected: list[Any], *path: Any) -> None:
|
|
491
335
|
"""Collect supporting arrays, including masks for each LAM and the global dataset.
|
|
492
336
|
|
|
493
337
|
Parameters
|
|
@@ -577,7 +421,7 @@ class Cutout(GridsBase):
|
|
|
577
421
|
"""
|
|
578
422
|
return Node(self, [d.tree() for d in self.datasets])
|
|
579
423
|
|
|
580
|
-
def forwards_subclass_metadata_specific(self) ->
|
|
424
|
+
def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
|
|
581
425
|
"""Returns metadata specific to the Cutout object.
|
|
582
426
|
|
|
583
427
|
Returns
|
|
@@ -588,7 +432,7 @@ class Cutout(GridsBase):
|
|
|
588
432
|
return {}
|
|
589
433
|
|
|
590
434
|
|
|
591
|
-
def grids_factory(args:
|
|
435
|
+
def grids_factory(args: tuple[Any, ...], kwargs: dict) -> Dataset:
|
|
592
436
|
"""Factory function to create a Grids object.
|
|
593
437
|
|
|
594
438
|
Parameters
|
|
@@ -618,7 +462,7 @@ def grids_factory(args: Tuple[Any, ...], kwargs: dict) -> Dataset:
|
|
|
618
462
|
return Grids(datasets, axis=axis)._subset(**kwargs)
|
|
619
463
|
|
|
620
464
|
|
|
621
|
-
def cutout_factory(args:
|
|
465
|
+
def cutout_factory(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Dataset:
|
|
622
466
|
"""Factory function to create a Cutout object.
|
|
623
467
|
|
|
624
468
|
Parameters
|