anemoi-datasets 0.5.27__py3-none-any.whl → 0.5.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/recipe/__init__.py +93 -0
- anemoi/datasets/commands/recipe/format.py +55 -0
- anemoi/datasets/commands/recipe/migrate.py +555 -0
- anemoi/datasets/create/__init__.py +46 -13
- anemoi/datasets/create/config.py +52 -53
- anemoi/datasets/create/input/__init__.py +43 -63
- anemoi/datasets/create/input/action.py +296 -236
- anemoi/datasets/create/input/context/__init__.py +71 -0
- anemoi/datasets/create/input/context/field.py +54 -0
- anemoi/datasets/create/input/data_sources.py +2 -1
- anemoi/datasets/create/input/misc.py +0 -71
- anemoi/datasets/create/input/repeated_dates.py +0 -114
- anemoi/datasets/create/input/result/__init__.py +17 -0
- anemoi/datasets/create/input/{result.py → result/field.py} +10 -92
- anemoi/datasets/create/sources/accumulate.py +517 -0
- anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
- anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
- anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +149 -0
- anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
- anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
- anemoi/datasets/create/sources/constants.py +39 -38
- anemoi/datasets/create/sources/empty.py +26 -22
- anemoi/datasets/create/sources/forcings.py +29 -28
- anemoi/datasets/create/sources/grib.py +92 -72
- anemoi/datasets/create/sources/grib_index.py +102 -54
- anemoi/datasets/create/sources/hindcasts.py +56 -55
- anemoi/datasets/create/sources/legacy.py +10 -62
- anemoi/datasets/create/sources/mars.py +159 -154
- anemoi/datasets/create/sources/netcdf.py +28 -24
- anemoi/datasets/create/sources/opendap.py +28 -24
- anemoi/datasets/create/sources/recentre.py +42 -41
- anemoi/datasets/create/sources/repeated_dates.py +44 -0
- anemoi/datasets/create/sources/source.py +26 -48
- anemoi/datasets/create/sources/xarray_support/__init__.py +30 -24
- anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
- anemoi/datasets/create/sources/xarray_support/field.py +4 -4
- anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
- anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
- anemoi/datasets/create/sources/xarray_zarr.py +28 -24
- anemoi/datasets/create/sources/zenodo.py +43 -39
- anemoi/datasets/create/utils.py +0 -42
- anemoi/datasets/data/complement.py +26 -17
- anemoi/datasets/data/dataset.py +12 -0
- anemoi/datasets/data/grids.py +0 -152
- anemoi/datasets/data/masked.py +74 -13
- anemoi/datasets/data/missing.py +5 -0
- anemoi/datasets/data/rolling_average.py +141 -0
- anemoi/datasets/data/stores.py +7 -9
- anemoi/datasets/dates/__init__.py +2 -0
- anemoi/datasets/dumper.py +76 -0
- anemoi/datasets/grids.py +1 -178
- anemoi/datasets/schemas/recipe.json +131 -0
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/METADATA +9 -6
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/RECORD +59 -57
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/filter.py +0 -47
- anemoi/datasets/create/input/concat.py +0 -161
- anemoi/datasets/create/input/context.py +0 -86
- anemoi/datasets/create/input/empty.py +0 -53
- anemoi/datasets/create/input/filter.py +0 -117
- anemoi/datasets/create/input/function.py +0 -232
- anemoi/datasets/create/input/join.py +0 -129
- anemoi/datasets/create/input/pipe.py +0 -66
- anemoi/datasets/create/input/step.py +0 -173
- anemoi/datasets/create/input/template.py +0 -161
- anemoi/datasets/create/sources/accumulations.py +0 -1062
- anemoi/datasets/create/sources/accumulations2.py +0 -647
- anemoi/datasets/create/sources/tendencies.py +0 -198
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/licenses/LICENSE +0 -0
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/utils.py
CHANGED

@@ -96,48 +96,6 @@ def to_datetime(*args: Any, **kwargs: Any) -> datetime.datetime:
     return to_datetime_(*args, **kwargs)
 
 
-def make_list_int(value: str | list | tuple | int) -> list[int]:
-    """Convert a string, list, tuple, or integer to a list of integers.
-
-    Parameters
-    ----------
-    value : str or list or tuple or int
-        The value to convert.
-
-    Returns
-    -------
-    list[int]
-        A list of integers.
-
-    Raises
-    ------
-    ValueError
-        If the value cannot be converted to a list of integers.
-    """
-    # Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers.
-    # Moved to anemoi.utils.humanize
-    # replace with from anemoi.utils.humanize import make_list_int
-    # when anemoi-utils is released and pyproject.toml is updated
-    if isinstance(value, str):
-        if "/" not in value:
-            return [value]
-        bits = value.split("/")
-        if len(bits) == 3 and bits[1].lower() == "to":
-            value = list(range(int(bits[0]), int(bits[2]) + 1, 1))
-
-        elif len(bits) == 5 and bits[1].lower() == "to" and bits[3].lower() == "by":
-            value = list(range(int(bits[0]), int(bits[2]) + int(bits[4]), int(bits[4])))
-
-    if isinstance(value, list):
-        return value
-    if isinstance(value, tuple):
-        return value
-    if isinstance(value, int):
-        return [value]
-
-    raise ValueError(f"Cannot make list from {value}")
-
-
 def normalize_and_check_dates(
     dates: list[datetime.datetime],
     start: datetime.datetime,
anemoi/datasets/data/complement.py
CHANGED

@@ -293,21 +293,29 @@ class ComplementNearest(Complement):
         index, previous = update_tuple(index, variable_index, slice(None))
         source_index = [self._source.name_to_index[x] for x in self.variables[previous]]
         source_data = self._source[index[0], source_index, index[2], ...]
-        […15 removed lines whose content is not preserved in this rendering]
+        if any(self._nearest_grid_points >= source_data.shape[-1]):
+            target_shape = source_data.shape[:-1] + self._target.shape[-1:]
+            target_data = np.full(target_shape, np.nan, dtype=self._target.dtype)
+            cond = self._nearest_grid_points < source_data.shape[-1]
+            reachable = np.where(cond)[0]
+            nearest_reachable = self._nearest_grid_points[cond]
+            target_data[..., reachable] = source_data[..., nearest_reachable]
+            result = target_data[..., index[3]]
+        else:
+            target_data = source_data[..., self._nearest_grid_points]
+            epsilon = 1e-8  # prevent division by zero
+            weights = 1.0 / (self._distances + epsilon)
+            weights = weights.astype(target_data.dtype)
+            weights /= weights.sum(axis=1, keepdims=True)  # normalize
+
+            # Reshape weights to broadcast correctly
+            # Add leading singleton dimensions so it matches target_data shape
+            while weights.ndim < target_data.ndim:
+                weights = np.expand_dims(weights, axis=0)
+
+            # Compute weighted average along the last dimension
+            final_point = np.sum(target_data * weights, axis=-1)
+            result = final_point[..., index[3]]
 
         return apply_index_to_slices_changes(result, changes)
 

@@ -353,8 +361,9 @@ def complement_factory(args: tuple, kwargs: dict) -> Dataset:
     }[interpolation]
 
     if interpolation == "nearest":
-        k = kwargs.pop("k",
-        […1 removed line whose content is not preserved in this rendering]
+        k = kwargs.pop("k", 1)
+        max_distance = kwargs.pop("max_distance", None)
+        complement = Class(target=target, source=source, k=k, max_distance=max_distance)._subset(**kwargs)
 
     else:
         complement = Class(target=target, source=source)._subset(**kwargs)
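For orientation, a minimal usage sketch of the options wired in above, assuming the existing open_dataset(complement=..., source=...) interface; the dataset paths are placeholders, not from this diff:

    from anemoi.datasets import open_dataset

    ds = open_dataset(
        complement="lam.zarr",      # placeholder: dataset missing some variables
        source="global.zarr",       # placeholder: dataset supplying them
        interpolation="nearest",
        k=4,                        # inverse-distance average over the 4 nearest source points
        max_distance=None,          # forwarded to ComplementNearest (new in this release)
    )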
anemoi/datasets/data/dataset.py
CHANGED

@@ -245,6 +245,12 @@ class Dataset(ABC, Sized):
 
             return Statistics(self, open_dataset(statistics))._subset(**kwargs).mutate()
 
+        if "mask" in kwargs:
+            from .masked import Masking
+
+            mask_file = kwargs.pop("mask")
+            return Masking(self, mask_file)._subset(**kwargs).mutate()
+
         # Note: trim_edge should go before thinning
         if "trim_edge" in kwargs:
            from .masked import TrimEdge

@@ -293,6 +299,12 @@ class Dataset(ABC, Sized):
         if skip_missing_dates:
             return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate()
 
+        if "rolling_average" in kwargs:
+            from .rolling_average import RollingAverage
+
+            rolling_average = kwargs.pop("rolling_average")
+            return RollingAverage(self, rolling_average)._subset(**kwargs).mutate()
+
         if "interpolate_frequency" in kwargs:
             from .interpolate import InterpolateFrequency
 
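A hedged sketch of how the two new _subset keywords are reached from open_dataset; the file names and window value are illustrative:

    from anemoi.datasets import open_dataset

    # "mask" wraps the dataset in the new Masking class (boolean .npy file)
    ds = open_dataset("dataset.zarr", mask="keep.npy")

    # "rolling_average" wraps it in the new RollingAverage class
    ds = open_dataset("dataset.zarr", rolling_average=(-2, 2, "freq"))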
anemoi/datasets/data/grids.py
CHANGED

@@ -21,167 +21,15 @@ from .dataset import FullIndex
 from .dataset import Shape
 from .dataset import TupleIndex
 from .debug import Node
-from .debug import debug_indexing
-from .forwards import Combined
 from .forwards import GivenAxis
 from .indexing import apply_index_to_slices_changes
-from .indexing import expand_list_indexing
 from .indexing import index_to_slices
-from .indexing import length_to_slices
-from .indexing import update_tuple
 from .misc import _auto_adjust
 from .misc import _open
 
 LOG = logging.getLogger(__name__)
 
 
-class Concat(Combined):
-    """A class to represent concatenated datasets."""
-
-    def __len__(self) -> int:
-        """Returns the total length of the concatenated datasets.
-
-        Returns
-        -------
-        int
-            Total length of the concatenated datasets.
-        """
-        return sum(len(i) for i in self.datasets)
-
-    @debug_indexing
-    @expand_list_indexing
-    def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
-        """Retrieves a tuple of data from the concatenated datasets based on the given index.
-
-        Parameters
-        ----------
-        index : TupleIndex
-            Index specifying the data to retrieve.
-
-        Returns
-        -------
-        NDArray[Any]
-            Concatenated data array from the specified index.
-        """
-        index, changes = index_to_slices(index, self.shape)
-        # print(index, changes)
-        lengths = [d.shape[0] for d in self.datasets]
-        slices = length_to_slices(index[0], lengths)
-        # print("slies", slices)
-        result = [d[update_tuple(index, 0, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None]
-        result = np.concatenate(result, axis=0)
-        return apply_index_to_slices_changes(result, changes)
-
-    @debug_indexing
-    def __getitem__(self, n: FullIndex) -> NDArray[Any]:
-        """Retrieves data from the concatenated datasets based on the given index.
-
-        Parameters
-        ----------
-        n : FullIndex
-            Index specifying the data to retrieve.
-
-        Returns
-        -------
-        NDArray[Any]
-            Data array from the concatenated datasets based on the index.
-        """
-        if isinstance(n, tuple):
-            return self._get_tuple(n)
-
-        if isinstance(n, slice):
-            return self._get_slice(n)
-
-        # TODO: optimize
-        k = 0
-        while n >= self.datasets[k]._len:
-            n -= self.datasets[k]._len
-            k += 1
-        return self.datasets[k][n]
-
-    @debug_indexing
-    def _get_slice(self, s: slice) -> NDArray[Any]:
-        """Retrieves a slice of data from the concatenated datasets.
-
-        Parameters
-        ----------
-        s : slice
-            Slice object specifying the range of data to retrieve.
-
-        Returns
-        -------
-        NDArray[Any]
-            Concatenated data array from the specified slice.
-        """
-        result = []
-
-        lengths = [d.shape[0] for d in self.datasets]
-        slices = length_to_slices(s, lengths)
-
-        result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]
-
-        return np.concatenate(result)
-
-    def check_compatibility(self, d1: Dataset, d2: Dataset) -> None:
-        """Check the compatibility of two datasets for concatenation.
-
-        Parameters
-        ----------
-        d1 : Dataset
-            The first dataset.
-        d2 : Dataset
-            The second dataset.
-        """
-        super().check_compatibility(d1, d2)
-        self.check_same_sub_shapes(d1, d2, drop_axis=0)
-
-    def check_same_lengths(self, d1: Dataset, d2: Dataset) -> None:
-        """Check if the lengths of two datasets are the same.
-
-        Parameters
-        ----------
-        d1 : Dataset
-            The first dataset.
-        d2 : Dataset
-            The second dataset.
-        """
-        # Turned off because we are concatenating along the first axis
-        pass
-
-    def check_same_dates(self, d1: Dataset, d2: Dataset) -> None:
-        """Check if the dates of two datasets are the same.
-
-        Parameters
-        ----------
-        d1 : Dataset
-            The first dataset.
-        d2 : Dataset
-            The second dataset.
-        """
-        # Turned off because we are concatenating along the dates axis
-        pass
-
-    @property
-    def dates(self) -> NDArray[np.datetime64]:
-        """Returns the concatenated dates of all datasets."""
-        return np.concatenate([d.dates for d in self.datasets])
-
-    @property
-    def shape(self) -> Shape:
-        """Returns the shape of the concatenated datasets."""
-        return (len(self),) + self.datasets[0].shape[1:]
-
-    def tree(self) -> Node:
-        """Generates a hierarchical tree structure for the concatenated datasets.
-
-        Returns
-        -------
-        Node
-            A Node object representing the concatenated datasets.
-        """
-        return Node(self, [d.tree() for d in self.datasets])
-
-
 class GridsBase(GivenAxis):
     """A base class for handling grids in datasets."""
 
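The removed Concat class split a request along the dates axis with length_to_slices and concatenated the per-dataset results; a small sketch of that scheme, with illustrative lengths:

    from anemoi.datasets.data.indexing import length_to_slices

    lengths = [10, 10, 5]                    # member dataset lengths along axis 0
    slices = length_to_slices(slice(8, 13), lengths)
    # expected: [slice(8, 10), slice(0, 3), None] -- two rows from the first
    # dataset, three from the second, none from the third, then np.concatenate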
anemoi/datasets/data/masked.py
CHANGED

@@ -10,6 +10,7 @@
 
 import logging
 from functools import cached_property
+from pathlib import Path
 from typing import Any
 
 import numpy as np

@@ -66,6 +67,12 @@ class Masked(Forwards):
         """Get the masked longitudes."""
         return self.forward.longitudes[self.mask]
 
+    @property
+    def grids(self) -> TupleIndex:
+        """Returns the number of grid points after masking"""
+        grids = np.sum(self.mask)
+        return (grids,)
+
     @debug_indexing
     def __getitem__(self, index: FullIndex) -> NDArray[Any]:
         """Get the masked data at the specified index.

@@ -150,19 +157,9 @@ class Thinning(Masked):
             if len(shape) != 2:
                 raise ValueError("Thinning only works latitude/longitude fields")
 
-            […4 removed lines whose content is not preserved in this rendering]
-            latitudes = forward_latitudes.reshape(shape)
-            longitudes = forward_longitudes.reshape(shape)
-            latitudes = latitudes[::thinning, ::thinning].flatten()
-            longitudes = longitudes[::thinning, ::thinning].flatten()
-
-            # TODO: This is not very efficient
-
-            mask = [lat in latitudes and lon in longitudes for lat, lon in zip(forward_latitudes, forward_longitudes)]
-            mask = np.array(mask, dtype=bool)
+            mask = np.full(shape, False, dtype=bool)
+            mask[::thinning, ::thinning] = True
+            mask = mask.flatten()
         else:
             mask = None
 

@@ -200,6 +197,70 @@ class Thinning(Masked):
         """
         return dict(thinning=self.thinning, method=self.method)
 
+    @property
+    def field_shape(self) -> Shape:
+        """Returns the field shape of the dataset."""
+        if self.thinning is None:
+            return self.forward.field_shape
+        x, y = self.forward.field_shape
+        x = (x + self.thinning - 1) // self.thinning
+        y = (y + self.thinning - 1) // self.thinning
+        return x, y
+
+
+class Masking(Masked):
+    """A class that applies a precomputed boolean mask from a .npy file."""
+
+    def __init__(self, forward: Dataset, mask_file: str) -> None:
+        """Initialize the Masking class.
+
+        Parameters
+        ----------
+        forward : Dataset
+            The dataset to be masked.
+        mask_file : str
+            Path to a .npy file containing a boolean mask of same shape as fields.
+        """
+        self.mask_file = mask_file
+
+        # Check path
+        if not Path(self.mask_file).exists():
+            raise FileNotFoundError(f"Mask file not found: {self.mask_file}")
+        # Load mask
+        try:
+            mask = np.load(self.mask_file)
+        except Exception as e:
+            raise ValueError(f"Could not load data from {mask_file}: {e}")
+
+        if mask.dtype != bool:
+            raise ValueError(f"Mask file {mask_file} does not contain boolean values.")
+        if mask.shape != forward.field_shape:
+            raise ValueError(f"Mask length {mask.shape} does not match field size {forward.field_shape}.")
+        if sum(mask) == 0:
+            LOG.warning(f"Mask in {mask_file} eliminates all points in field.")
+
+        super().__init__(forward, mask)
+
+    def tree(self) -> Node:
+        """Get the tree representation of the dataset.
+
+        Returns
+        -------
+        Node
+            The tree representation of the dataset.
+        """
+        return Node(self, [self.forward.tree()], mask_file=self.mask_file)
+
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
+        """Get the metadata specific to the Masking subclass.
+
+        Returns
+        -------
+        Dict[str, Any]
+            The metadata specific to the Masking subclass.
+        """
+        return dict(mask_file=self.mask_file)
+
 
 class Cropping(Masked):
     """A class to represent a cropped dataset."""
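A hedged sketch of producing a mask file that passes the checks in Masking.__init__: a boolean array with the dataset's field_shape, saved with np.save; the paths are placeholders:

    import numpy as np
    from anemoi.datasets import open_dataset

    ds = open_dataset("dataset.zarr")            # placeholder path
    mask = np.zeros(ds.field_shape, dtype=bool)
    mask[10:20, 30:40] = True                    # keep only this patch of grid points
    np.save("keep.npy", mask)

    masked = open_dataset("dataset.zarr", mask="keep.npy")
    # masked.grids -> (mask.sum(),) via the new Masked.grids property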
anemoi/datasets/data/missing.py
CHANGED

@@ -440,3 +440,8 @@ class MissingDataset(Forwards):
             Metadata specific to the subclass.
         """
         return {"start": self.start, "end": self.end}
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        """Return the shape of the dataset."""
+        return (len(self),) + self.forward.shape[1:]
anemoi/datasets/data/rolling_average.py
ADDED

@@ -0,0 +1,141 @@
+# (C) Copyright 2025 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+from functools import cached_property
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+
+from anemoi.datasets.data.indexing import expand_list_indexing
+
+from .dataset import Dataset
+from .dataset import FullIndex
+from .debug import Node
+from .debug import debug_indexing
+from .forwards import Forwards
+
+LOG = logging.getLogger(__name__)
+
+
+class RollingAverage(Forwards):
+    """A class to represent a dataset with interpolated frequency."""
+
+    def __init__(self, dataset: Dataset, window: str | tuple[int, int, str]) -> None:
+        """Initialize the RollingAverage class.
+
+        Parameters
+        ----------
+        dataset : Dataset
+            The dataset to be averaged with a rolling window.
+        window : (int, int, str)
+            The rolling average window (start, end, 'freq').
+            'freq' means the window is in number of time steps in the dataset.
+            Both start and end are inclusive, i.e. window = (-2, 2, 'freq') means a window of 5 time steps.
+            For now, only 'freq' is supported, in the future other units may be supported.
+            Windows such as "[-2h, +2h]" are not supported yet.
+        """
+        super().__init__(dataset)
+        if not (isinstance(window, (list, tuple)) and len(window) == 3):
+            raise ValueError(f"Window must be (int, int, str), got {window}")
+        if not isinstance(window[0], int) or not isinstance(window[1], int) or not isinstance(window[2], str):
+            raise ValueError(f"Window must be (int, int, str), got {window}")
+        if window[2] not in ["freq", "frequency"]:
+            raise NotImplementedError(f"Window must be (int, int, 'freq'), got {window}")
+
+        # window = (0, 0, 'freq') means no change
+        self.i_start = -window[0]
+        self.i_end = window[1] + 1
+        if self.i_start <= 0:
+            raise ValueError(f"Window start must be negative, got {window}")
+        if self.i_end <= 0:
+            raise ValueError(f"Window end must be positive, got {window}")
+
+        self.window_str = f"-{self.i_start}-to-{self.i_end}"
+
+    @property
+    def shape(self):
+        shape = list(self.forward.shape)
+        shape[0] = len(self)
+        return tuple(shape)
+
+    @debug_indexing
+    @expand_list_indexing
+    def __getitem__(self, n: FullIndex) -> NDArray[Any]:
+        def f(array):
+            return np.nanmean(array, axis=0)
+
+        if isinstance(n, slice):
+            n = (n,)
+
+        if isinstance(n, tuple):
+            first = n[0]
+            if len(n) > 1:
+                rest = n[1:]
+            else:
+                rest = ()
+
+            if isinstance(first, int):
+                slice_ = slice(first, first + self.i_start + self.i_end)
+                data = self.forward[(slice_,) + rest]
+                return f(data)
+
+            if isinstance(first, slice):
+                first = list(range(first.start or 0, first.stop or len(self), first.step or 1))
+
+            if isinstance(first, (list, tuple)):
+                first = [i if i >= 0 else len(self) + i for i in first]
+                if any(i >= len(self) for i in first):
+                    raise IndexError(f"Index out of range: {first}")
+                slices = [slice(i, i + self.i_start + self.i_end) for i in first]
+                data = [self.forward[(slice_,) + rest] for slice_ in slices]
+                res = [f(d) for d in data]
+                return np.array(res)
+
+            assert False, f"Expected int, slice, list or tuple as first element of tuple, got {type(first)}"
+
+        assert isinstance(n, int), f"Expected int, slice, tuple, got {type(n)}"
+
+        if n < 0:
+            n = len(self) + n
+        if n >= len(self):
+            raise IndexError(f"Index out of range: {n}")
+
+        slice_ = slice(n, n + self.i_start + self.i_end)
+        data = self.forward[slice_]
+        return f(data)
+
+    def __len__(self) -> int:
+        return len(self.forward) - (self.i_end + self.i_start - 1)
+
+    @cached_property
+    def dates(self) -> NDArray[np.datetime64]:
+        """Get the interpolated dates."""
+        dates = self.forward.dates
+        return dates[self.i_start : len(dates) - self.i_end + 1]
+
+    def tree(self) -> Node:
+        return Node(self, [self.forward.tree()], window=self.window_str)
+
+    @cached_property
+    def missing(self) -> set[int]:
+        """Get the missing data indices."""
+        result = []
+
+        for i in self.forward.missing:
+            for j in range(0, self.i_end + self.i_start):
+                result.append(i + j)
+
+        result = {x for x in result if x < self._len}
+        return result
+
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
+        return {}
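A worked example of the window arithmetic in __init__ and __getitem__ above, using the (-2, 2, 'freq') window from the docstring:

    window = (-2, 2, "freq")
    i_start, i_end = -window[0], window[1] + 1   # 2 and 3, both must be > 0
    assert i_start + i_end == 5                  # item n averages forward[n : n + 5]
    # len(RollingAverage(ds, window)) == len(ds) - (i_end + i_start - 1) == len(ds) - 4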
anemoi/datasets/data/stores.py
CHANGED

@@ -85,22 +85,20 @@ class S3Store(ReadOnlyStore):
     options using the anemoi configs.
     """
 
-    def __init__(self, url: str
-        """Initialize the S3Store with a URL
-        from anemoi.utils.remote.s3 import s3_client
+    def __init__(self, url: str) -> None:
+        """Initialize the S3Store with a URL."""
 
-
-        self.s3 = s3_client(self.bucket, region=region)
+        self.url = url
 
     def __getitem__(self, key: str) -> bytes:
         """Retrieve an item from the store."""
+        from anemoi.utils.remote.s3 import get_object
+
         try:
-
-        except
+            return get_object(os.path.join(self.url, key))
+        except FileNotFoundError:
             raise KeyError(key)
 
-        return response["Body"].read()
-
 
 class DebugStore(ReadOnlyStore):
     """A store to debug the zarr loading."""
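The rewrite drops the eager client setup in favour of one get_object call per key; a sketch of the resulting access pattern, with a placeholder URL:

    from anemoi.datasets.data.stores import S3Store

    store = S3Store("s3://my-bucket/dataset.zarr")      # placeholder URL
    chunk = store["data/0.0.0"]
    # -> get_object("s3://my-bucket/dataset.zarr/data/0.0.0"); raises KeyError
    #    when the object is missing (FileNotFoundError is translated above)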
anemoi/datasets/dates/__init__.py
CHANGED

@@ -58,6 +58,8 @@ def extend(x: str | list[Any] | tuple[Any, ...]) -> Iterator[datetime.datetime]:
 class DatesProvider:
     """Base class for date generation.
 
+    Examples
+    --------
     >>> DatesProvider.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-02 00:00", "frequency": "1d"}).values
     [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 2, 0, 0)]
 
anemoi/datasets/dumper.py
ADDED

@@ -0,0 +1,76 @@
+# (C) Copyright 2025 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import datetime
+import io
+import logging
+
+import ruamel.yaml
+
+LOG = logging.getLogger(__name__)
+
+
+def represent_date(dumper, data):
+
+    if isinstance(data, datetime.datetime):
+        if data.tzinfo is None:
+            data = data.replace(tzinfo=datetime.timezone.utc)
+        data = data.astimezone(datetime.timezone.utc)
+        iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z"
+    else:
+        iso_str = data.isoformat()
+
+    return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str)
+
+
+# --- Represent multiline strings with | style ---
+def represent_multiline_str(dumper, data):
+    if "\n" in data:
+        return dumper.represent_scalar("tag:yaml.org,2002:str", data.strip(), style="|")
+    return dumper.represent_scalar("tag:yaml.org,2002:str", data)
+
+
+# --- Represent short lists inline (flow style) ---
+def represent_inline_list(dumper, data):
+
+    if not all(isinstance(i, (str, int, float, bool, type(None))) for i in data):
+        return dumper.represent_sequence("tag:yaml.org,2002:seq", data)
+
+    return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True)
+
+
+def yaml_dump(obj, order=None, stream=None, **kwargs):
+
+    if order:
+
+        def _ordering(k):
+            return order.index(k) if k in order else len(order)
+
+        obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))}
+
+    yaml = ruamel.yaml.YAML()
+    yaml.width = 120  # wrap long flow sequences
+
+    yaml.Representer.add_representer(datetime.date, represent_date)
+    yaml.Representer.add_representer(datetime.datetime, represent_date)
+    yaml.Representer.add_representer(str, represent_multiline_str)
+    yaml.Representer.add_representer(list, represent_inline_list)
+
+    data = ruamel.yaml.comments.CommentedMap()
+    for i, (k, v) in enumerate(obj.items()):
+        data[k] = v
+        if i > 0:
+            data.yaml_set_comment_before_after_key(key=k, before="\n")
+
+    if stream:
+        yaml.dump(data, stream=stream, **kwargs)
+
+    stream = io.StringIO()
+    yaml.dump(data, stream=stream, **kwargs)
+    return stream.getvalue()