anemoi-datasets 0.5.26__py3-none-any.whl → 0.5.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/__init__.py +1 -2
- anemoi/datasets/_version.py +16 -3
- anemoi/datasets/commands/check.py +1 -1
- anemoi/datasets/commands/copy.py +1 -2
- anemoi/datasets/commands/create.py +1 -1
- anemoi/datasets/commands/inspect.py +27 -35
- anemoi/datasets/commands/recipe/__init__.py +93 -0
- anemoi/datasets/commands/recipe/format.py +55 -0
- anemoi/datasets/commands/recipe/migrate.py +555 -0
- anemoi/datasets/commands/validate.py +59 -0
- anemoi/datasets/compute/recentre.py +3 -6
- anemoi/datasets/create/__init__.py +64 -26
- anemoi/datasets/create/check.py +10 -12
- anemoi/datasets/create/chunks.py +1 -2
- anemoi/datasets/create/config.py +5 -6
- anemoi/datasets/create/input/__init__.py +44 -65
- anemoi/datasets/create/input/action.py +296 -238
- anemoi/datasets/create/input/context/__init__.py +71 -0
- anemoi/datasets/create/input/context/field.py +54 -0
- anemoi/datasets/create/input/data_sources.py +7 -9
- anemoi/datasets/create/input/misc.py +2 -75
- anemoi/datasets/create/input/repeated_dates.py +11 -130
- anemoi/datasets/{utils → create/input/result}/__init__.py +10 -1
- anemoi/datasets/create/input/{result.py → result/field.py} +36 -120
- anemoi/datasets/create/input/trace.py +1 -1
- anemoi/datasets/create/patch.py +1 -2
- anemoi/datasets/create/persistent.py +3 -5
- anemoi/datasets/create/size.py +1 -3
- anemoi/datasets/create/sources/accumulations.py +120 -145
- anemoi/datasets/create/sources/accumulations2.py +20 -53
- anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
- anemoi/datasets/create/sources/constants.py +39 -40
- anemoi/datasets/create/sources/empty.py +22 -19
- anemoi/datasets/create/sources/fdb.py +133 -0
- anemoi/datasets/create/sources/forcings.py +29 -29
- anemoi/datasets/create/sources/grib.py +94 -78
- anemoi/datasets/create/sources/grib_index.py +57 -55
- anemoi/datasets/create/sources/hindcasts.py +57 -59
- anemoi/datasets/create/sources/legacy.py +10 -62
- anemoi/datasets/create/sources/mars.py +121 -149
- anemoi/datasets/create/sources/netcdf.py +28 -25
- anemoi/datasets/create/sources/opendap.py +28 -26
- anemoi/datasets/create/sources/patterns.py +4 -6
- anemoi/datasets/create/sources/recentre.py +46 -48
- anemoi/datasets/create/sources/repeated_dates.py +44 -0
- anemoi/datasets/create/sources/source.py +26 -51
- anemoi/datasets/create/sources/tendencies.py +68 -98
- anemoi/datasets/create/sources/xarray.py +4 -6
- anemoi/datasets/create/sources/xarray_support/__init__.py +40 -36
- anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -12
- anemoi/datasets/create/sources/xarray_support/field.py +20 -16
- anemoi/datasets/create/sources/xarray_support/fieldlist.py +11 -15
- anemoi/datasets/create/sources/xarray_support/flavour.py +42 -42
- anemoi/datasets/create/sources/xarray_support/grid.py +15 -9
- anemoi/datasets/create/sources/xarray_support/metadata.py +19 -128
- anemoi/datasets/create/sources/xarray_support/patch.py +4 -6
- anemoi/datasets/create/sources/xarray_support/time.py +10 -13
- anemoi/datasets/create/sources/xarray_support/variable.py +21 -21
- anemoi/datasets/create/sources/xarray_zarr.py +28 -25
- anemoi/datasets/create/sources/zenodo.py +43 -41
- anemoi/datasets/create/statistics/__init__.py +3 -6
- anemoi/datasets/create/testing.py +4 -0
- anemoi/datasets/create/typing.py +1 -2
- anemoi/datasets/create/utils.py +0 -43
- anemoi/datasets/create/zarr.py +7 -2
- anemoi/datasets/data/__init__.py +15 -6
- anemoi/datasets/data/complement.py +7 -12
- anemoi/datasets/data/concat.py +5 -8
- anemoi/datasets/data/dataset.py +48 -47
- anemoi/datasets/data/debug.py +7 -9
- anemoi/datasets/data/ensemble.py +4 -6
- anemoi/datasets/data/fill_missing.py +7 -10
- anemoi/datasets/data/forwards.py +22 -26
- anemoi/datasets/data/grids.py +12 -168
- anemoi/datasets/data/indexing.py +9 -12
- anemoi/datasets/data/interpolate.py +7 -15
- anemoi/datasets/data/join.py +8 -12
- anemoi/datasets/data/masked.py +6 -11
- anemoi/datasets/data/merge.py +5 -9
- anemoi/datasets/data/misc.py +41 -45
- anemoi/datasets/data/missing.py +11 -16
- anemoi/datasets/data/observations/__init__.py +8 -14
- anemoi/datasets/data/padded.py +3 -5
- anemoi/datasets/data/records/backends/__init__.py +2 -2
- anemoi/datasets/data/rescale.py +5 -12
- anemoi/datasets/data/rolling_average.py +141 -0
- anemoi/datasets/data/select.py +13 -16
- anemoi/datasets/data/statistics.py +4 -7
- anemoi/datasets/data/stores.py +22 -29
- anemoi/datasets/data/subset.py +8 -11
- anemoi/datasets/data/unchecked.py +7 -11
- anemoi/datasets/data/xy.py +25 -21
- anemoi/datasets/dates/__init__.py +15 -18
- anemoi/datasets/dates/groups.py +7 -10
- anemoi/datasets/dumper.py +76 -0
- anemoi/datasets/grids.py +4 -185
- anemoi/datasets/schemas/recipe.json +131 -0
- anemoi/datasets/testing.py +93 -7
- anemoi/datasets/validate.py +598 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/METADATA +7 -4
- anemoi_datasets-0.5.28.dist-info/RECORD +134 -0
- anemoi/datasets/create/filter.py +0 -48
- anemoi/datasets/create/input/concat.py +0 -164
- anemoi/datasets/create/input/context.py +0 -89
- anemoi/datasets/create/input/empty.py +0 -54
- anemoi/datasets/create/input/filter.py +0 -118
- anemoi/datasets/create/input/function.py +0 -233
- anemoi/datasets/create/input/join.py +0 -130
- anemoi/datasets/create/input/pipe.py +0 -66
- anemoi/datasets/create/input/step.py +0 -177
- anemoi/datasets/create/input/template.py +0 -162
- anemoi_datasets-0.5.26.dist-info/RECORD +0 -131
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/WHEEL +0 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/licenses/LICENSE +0 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,598 @@
|
|
|
1
|
+
# (C) Copyright 2025- Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
import math
|
|
13
|
+
from collections import defaultdict
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
from anemoi.datasets.data.dataset import Dataset
|
|
18
|
+
from anemoi.datasets.testing import default_test_indexing
|
|
19
|
+
|
|
20
|
+
LOG = logging.getLogger(__name__)
# List of methods called during training. To update the list, run training with ANEMOI_DATASETS_TRACE=1

TRAINING_METHODS = [
    "__getitem__",
    "__len__",
    "latitudes",
    "longitudes",
    "metadata",  # Accessed when checkpointing
    "missing",
    "name_to_index",
    "shape",
    "statistics",
    "supporting_arrays",  # Accessed when checkpointing
    "variables",
]

# Methods only some training configurations call.
EXTRA_TRAINING_METHODS = [
    "statistics_tendencies",
]

# Methods intended for interactive inspection/debugging of datasets.
DEBUGGING_METHODS = [
    "plot",
    "to_index",
    "tree",
    "source",
]

# Metadata accessors considered part of the public API.
PUBLIC_METADATA_METHODS = [
    "arguments",
    "dtype",
    "end_date",
    "resolution",
    "start_date",
    "field_shape",
    "frequency",
    "dates",
    "typed_variables",
    "variables_metadata",
]

# Metadata accessors considered internal to anemoi-datasets.
PRIVATE_METADATA_METHODS = [
    "computed_constant_fields",
    "constant_fields",
    "dataset_metadata",
    "label",
    "metadata_specific",
    "provenance",
]

# Plumbing methods used by the dataset machinery itself.
INTERNAL_METHODS = [
    "mutate",
    "swap_with_parent",
    "dates_interval_to_indices",
]

# Methods whose interface may still change.
EXPERIMENTAL_METHODS = [
    "get_dataset_names",
    "name",
    "grids",
]

# Methods not fitting any category above (overlap is checked in __main__).
OTHER_METHODS = [
    "collect_input_sources",
    "collect_supporting_arrays",
    "sub_shape",
]


# Collect every *_METHODS list defined above into {list_name: method_names}.
METHODS_CATEGORIES = {k: v for k, v in list(globals().items()) if k.endswith("_METHODS")}


# Flat set of every method name across all categories.
METHODS = set(sum(METHODS_CATEGORIES.values(), []))


# Keyword arguments used to invoke each entry that is a callable method;
# entries absent from this dict are treated as plain attributes.
KWARGS = {
    "__len__": {},
    "__getitem__": {"index": 0},
    "get_dataset_names": {"names": set()},
    "metadata": {},
    "metadata_specific": {},
    "mutate": {},
    "plot": {"date": 0, "variable": 0},
    "provenance": {},
    "source": {"index": 0},
    "statistics_tendencies": {},
    "sub_shape": {},
    "supporting_arrays": {},
    "swap_with_parent": {},
    "to_index": {"date": 0, "variable": 0},
    "tree": {},
}
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class Unknown:
    """Outcome for a method that was never validated."""

    emoji = "❓"
    # `Report.summary` reads `.success` on every outcome, including the
    # `Unknown()` default it substitutes for unreported methods; without
    # this attribute the summary would crash with AttributeError.
    success = False
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class Success:
    """Outcome recorded when a validation check passed."""

    emoji = "✅"
    success = True

    def __repr__(self):
        # Fixed label shown in the summary listing.
        return "Success"
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class Error:
    """Base class for failed outcomes; subclasses pick the emoji."""

    success = False

    def __init__(self, message):
        # Typically an exception instance or a plain string.
        self.message = message

    def __repr__(self):
        # Prefer str(message); fall back to repr(message), then a
        # generic label when both are empty/falsy.
        text = str(self.message)
        if not text:
            text = repr(self.message)
        return text or "Error"
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class Failure(Error):
    """Outcome for an unexpected exception raised while validating."""

    emoji = "💥"
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class Internal(Error):
    """Outcome for an inconsistency in this validation script itself
    (e.g. stale entries in the method lists or KWARGS)."""

    emoji = "💣"
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class Invalid(Error):
    """Outcome for a value that failed its validation check."""

    emoji = "❌"
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class Report:
    """Collects per-method validation outcomes and prints a summary.

    Outcomes are one of Success/Failure/Internal/Invalid; warnings are
    accumulated separately and never change the outcome.
    """

    def __init__(self):
        # name -> outcome object (Success, Failure, Internal, Invalid)
        self.report = {}
        # name -> the Dataset class attribute, kept so `summary(detailed=True)`
        # can print its docstring
        self.methods = {}
        # name -> list of warning strings
        self.warnings = defaultdict(list)

    def method(self, name, method):
        """Remember the Dataset class attribute behind `name`."""
        self.methods[name] = method

    def success(self, name):
        """Record that `name` validated successfully."""
        self.report[name] = Success()

    def failure(self, name, message):
        """Record an unexpected exception raised while validating `name`."""
        self.report[name] = Failure(message)

    def internal(self, name, message):
        """Record an inconsistency in the validation script itself."""
        self.report[name] = Internal(message)

    def invalid(self, name, exception):
        """Record that the value returned for `name` failed validation."""
        self.report[name] = Invalid(exception)

    def warning(self, name, message):
        """Attach a non-fatal warning to `name`."""
        self.warnings[name].append(message)

    def summary(self, detailed=False):
        """Print outcomes grouped by METHODS_CATEGORIES.

        With `detailed=True`, also print the docstring of each method
        that did not succeed.
        """

        # Width of the widest method name, for column alignment.
        maxlen = max(len(name) for name in self.report.keys())

        for name, methods in METHODS_CATEGORIES.items():
            print()
            print(f"{name.title().replace('_', ' ')}:")
            print("-" * (len(name) + 1))
            print()

            for method in methods:
                # Methods never validated fall back to Unknown.
                r = self.report.get(method, Unknown())
                msg = repr(r)
                if not msg.endswith("."):
                    msg += "."
                print(f"{r.emoji} {method.ljust(maxlen)}: {msg}")

                for w in self.warnings.get(method, []):
                    print(" " * (maxlen + 4), "⚠️", w)

                if r.success:
                    continue

                if not detailed:
                    continue

                if method not in self.methods:
                    continue

                proc = self.methods[method]

                # Print the method's docstring framed by '=' rulers,
                # indented to line up with the outcome column.
                doc = proc.__doc__
                if doc:
                    width = 80
                    indent = maxlen + 4
                    doc = "\n".join(["=" * width, "", doc, "=" * width])
                    indented_doc = "\n".join(" " * indent + line for line in doc.splitlines())
                    print()
                    print(indented_doc)
                    print()
                    print()

        print()
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _no_validate(report, dataset, name, result):
    """Fallback when no `validate_<name>` function exists: record a warning."""
    message = f"Validation for {name} not implemented. Result: {type(result)}"
    report.warning(name, message)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def validate_variables(report, dataset, name, result):
    """Check that the variables are a list/tuple of strings whose length
    matches the variable axis (`dataset.shape[1]`).

    Raises ValueError on any violation.
    """
    if not isinstance(result, (list, tuple)):
        raise ValueError(f"Result is not a list or tuple {type(result)}")

    expected = dataset.shape[1]
    if len(result) != expected:
        raise ValueError(f"Result has wrong length: {len(result)} != {expected}")

    non_strings = [value for value in result if not isinstance(value, str)]
    if non_strings:
        value = non_strings[0]
        raise ValueError(f"`{value}` is not a string")
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def validate_latitudes(report, dataset, name, result):
    """Validate the latitudes of the dataset.

    Parameters
    ----------
    report : Report
        Collector for warnings and soft failures.
    dataset : Dataset
        Dataset being validated; `shape[3]` is the grid size.
    name : str
        Name of the attribute under validation.
    result : np.ndarray
        The latitudes returned by the dataset.

    Raises
    ------
    ValueError
        If the result is not an ndarray of the expected length, or
        contains infinite or out-of-range values.
    """

    if not isinstance(result, np.ndarray):
        raise ValueError(f"Result is not a np.ndarray {type(result)}")

    if len(result) != dataset.shape[3]:
        raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}")

    # NaN must be tested *before* the finite check: NaN is non-finite, so
    # the original order raised on NaNs and made this soft-failure branch
    # unreachable. NaNs are reported as Invalid instead of raising.
    if np.isnan(result).any():
        report.invalid(name, ValueError("Result contains NaN values"))
        return

    if not np.all(np.isfinite(result)):
        raise ValueError("Result contains non-finite values")

    if not np.all((result >= -90) & (result <= 90)):
        raise ValueError("Result contains values outside the range [-90, 90]")

    # Heuristic: everything within [-π, π] suggests radians, not degrees.
    if np.all((result >= -np.pi) & (result <= np.pi)):
        report.warning(name, "All latitudes are in the range [-π, π]. Are they in radians?")
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def validate_longitudes(report, dataset, name, result):
    """Validate the longitudes of the dataset.

    Parameters
    ----------
    report : Report
        Collector for warnings and soft failures.
    dataset : Dataset
        Dataset being validated; `shape[3]` is the grid size.
    name : str
        Name of the attribute under validation.
    result : np.ndarray
        The longitudes returned by the dataset.

    Raises
    ------
    ValueError
        If the result is not an ndarray of the expected length, or
        contains infinite or out-of-range values.
    """

    if not isinstance(result, np.ndarray):
        raise ValueError(f"Result is not a np.ndarray {type(result)}")

    if len(result) != dataset.shape[3]:
        # Bug fix: the message previously printed dataset.shape[2] while
        # the check compares against dataset.shape[3].
        raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}")

    # NaN must be tested *before* the finite check: NaN is non-finite, so
    # the original order raised on NaNs and made this soft-failure branch
    # unreachable. NaNs are reported as Invalid instead of raising.
    if np.isnan(result).any():
        report.invalid(name, ValueError("Result contains NaN values"))
        return

    if not np.all(np.isfinite(result)):
        raise ValueError("Result contains non-finite values")

    if not np.all((result >= -180) & (result <= 360)):
        raise ValueError("Result contains values outside the range [-180, 360]")

    # Heuristic: everything within [-π, 2π] suggests radians, not degrees.
    if np.all((result >= -np.pi) & (result <= 2 * np.pi)):
        report.warning(name, "All longitudes are in the range [-π, 2π]. Are they in radians?")
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def validate_statistics(report, dataset, name, result):
    """Validate the statistics of the dataset.

    Expects a dict with keys "mean", "stdev", "minimum", "maximum",
    each a 1D ndarray with one entry per variable.

    Raises
    ------
    ValueError
        On missing keys, wrong types/shapes, or infinite values.
        NaNs are reported to `report` as Invalid instead of raising.
    """

    if not isinstance(result, dict):
        raise ValueError(f"Result is not a dict {type(result)}")

    for key in ["mean", "stdev", "minimum", "maximum"]:

        if key not in result:
            raise ValueError(f"Result does not contain `{key}`")

        if not isinstance(result[key], np.ndarray):
            raise ValueError(f"Result[{key}] is not a np.ndarray {type(result[key])}")

        if len(result[key].shape) != 1:
            raise ValueError(f"Result[{key}] has wrong shape: {len(result[key].shape)} != 1")

        if result[key].shape[0] != len(dataset.variables):
            raise ValueError(f"Result[{key}] has wrong length: {result[key].shape[0]} != {len(dataset.variables)}")

        # NaN must be tested *before* the finite check: NaN is non-finite,
        # so the original order raised on NaNs and made this soft-failure
        # branch unreachable. NaNs are reported as Invalid instead.
        if np.isnan(result[key]).any():
            report.invalid(name, ValueError(f"Result[{key}] contains NaN values"))
            continue

        if not np.all(np.isfinite(result[key])):
            raise ValueError(f"Result[{key}] contains non-finite values")
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def validate_shape(report, dataset, name, result):
    """Check the 4-tuple shape (dates, variables, ensembles, grid).

    Entries 0, 1 and 3 must agree with the corresponding dataset
    properties; the ensemble dimension (index 2) is deliberately not
    enforced.
    """
    if not isinstance(result, tuple):
        raise ValueError(f"Result is not a tuple {type(result)}")

    if len(result) != 4:
        raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.shape)}")

    checks = (
        (0, len(dataset)),
        (1, len(dataset.variables)),
        # Index 2 (ensemble dimension) is intentionally skipped for now.
        (3, len(dataset.latitudes)),
    )
    for axis, expected in checks:
        if result[axis] != expected:
            raise ValueError(f"Result[{axis}] has wrong length: {result[axis]} != {expected}")
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def validate_supporting_arrays(report, dataset, name, result):
    """Check that the supporting arrays include latitudes and longitudes
    that exactly match the dataset's own coordinates."""
    if not isinstance(result, dict):
        raise ValueError(f"Result is not a dict {type(result)}")

    coordinate_keys = ("latitudes", "longitudes")

    for key in coordinate_keys:
        if key not in result:
            raise ValueError(f"Result does not contain `{key}`")

    for key in coordinate_keys:
        if not isinstance(result[key], np.ndarray):
            raise ValueError(f"Result[{key}] is not a np.ndarray {type(result[key])}")

    for key in coordinate_keys:
        if np.any(result[key] != getattr(dataset, key)):
            raise ValueError(f"Result[{key}] does not match dataset.{key}")
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def validate_dates(report, dataset, name, result):
    """Validate the dates of the dataset.

    Expects a 1D datetime64 ndarray, strictly increasing, with a constant
    step, and as long as `dataset.dates`.

    Raises
    ------
    ValueError
        On wrong type/shape/length, non-increasing dates, or a
        non-constant frequency. NaT values are reported to `report`
        as Invalid instead of raising.
    """

    if not isinstance(result, np.ndarray):
        raise ValueError(f"Result is not a np.ndarray {type(result)}")

    if len(result.shape) != 1:
        raise ValueError(f"Result has wrong shape: {len(result.shape)} != 1")

    if result.shape[0] != len(dataset.dates):
        raise ValueError(f"Result has wrong length: {result.shape[0]} != {len(dataset.dates)}")

    if not np.issubdtype(result.dtype, np.datetime64):
        raise ValueError(f"Result is not a datetime64 array {result.dtype}")

    # Bug fix: np.isfinite/np.isnan do not support datetime64 and raise a
    # TypeError, so the original checks could never run; np.isnat is the
    # correct missing-value test for datetime64 arrays.
    if np.isnat(result).any():
        report.invalid(name, ValueError("Result contains NaT values"))
        return

    for d1, d2 in zip(result[:-1], result[1:]):
        if d1 >= d2:
            raise ValueError(f"Result contains non-increasing dates: {d1} >= {d2}")

    # Guard: np.diff of a single-date array is empty and frequency[0]
    # would raise IndexError.
    if result.size > 1:
        frequency = np.diff(result)
        if not np.all(frequency == frequency[0]):
            raise ValueError("Result contains non-constant frequency")
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def validate_metadata(report, dataset, name, result):
    """Check that the metadata is a dictionary (no deeper validation yet)."""
    if isinstance(result, dict):
        return
    raise ValueError(f"Result is not a dict {type(result)}")
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def validate_missing(report, dataset, name, result):
    """Check that the missing indices form a set of ints in [0, len(dataset))."""
    if not isinstance(result, set):
        raise ValueError(f"Result is not a set {type(result)}")

    if any(not isinstance(item, int) for item in result):
        raise ValueError("Result contains non-integer values")

    if result:
        lowest, highest = min(result), max(result)
        if lowest < 0:
            raise ValueError("Result contains negative values")
        if highest >= len(dataset):
            raise ValueError(f"Result contains values greater than {len(dataset)}")
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def validate_name_to_index(report, dataset, name, result):
    """Check that name_to_index is a bijection between the dataset's
    variable names and the positions 0..n-1, in order."""
    if not isinstance(result, dict):
        raise ValueError(f"Result is not a dict {type(result)}")

    count = len(dataset.variables)

    # Forward direction: every variable maps to an in-range integer index.
    for key in dataset.variables:
        if key not in result:
            raise ValueError(f"Result does not contain `{key}`")
        index = result[key]
        if not isinstance(index, int):
            raise ValueError(f"Result[{key}] is not an int {type(index)}")
        if not 0 <= index < count:
            raise ValueError(f"Result[{key}] is out of bounds: {index}")

    # Reverse direction: every position maps back to the matching name.
    index_to_name = {v: k for k, v in result.items()}
    for i, expected in enumerate(dataset.variables):
        if i not in index_to_name:
            raise ValueError(f"Result does not contain index `{i}`")
        found = index_to_name[i]
        if not isinstance(found, str):
            raise ValueError(f"Result[{i}] is not a string {type(found)}")
        if found != expected:
            raise ValueError(
                f"Result[{i}] does not match dataset.variables[{i}]: {found} != {expected}"
            )
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def validate___getitem__(report, dataset, name, result):
    """Check that indexing yields an ndarray shaped like one sample
    (the dataset shape minus the leading date axis)."""
    if not isinstance(result, np.ndarray):
        raise ValueError(f"Result is not a np.ndarray {type(result)}")

    expected = dataset.shape[1:]
    if result.shape != expected:
        raise ValueError(f"Result has wrong shape: {result.shape} != {expected}")
|
|
452
|
+
|
|
453
|
+
|
|
454
|
+
def validate___len__(report, dataset, name, result):
    """Validate the __len__ method of the dataset.

    The length must be an int that agrees both with the leading axis of
    the shape and with the number of dates.

    Raises
    ------
    ValueError
        On a non-int result or a length mismatch.
    """

    if not isinstance(result, int):
        raise ValueError(f"Result is not an int {type(result)}")

    # Bug fix: the message previously printed len(dataset) while the
    # check compares against dataset.shape[0]; print the value actually
    # compared.
    if result != dataset.shape[0]:
        raise ValueError(f"Result has wrong length: {result} != {dataset.shape[0]}")

    if result != len(dataset.dates):
        raise ValueError(f"Result has wrong length: {result} != {len(dataset.dates)}")
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def validate_start_date(report, dataset, name, result):
    """Check that start_date is a datetime64 equal to the first date."""
    if not isinstance(result, np.datetime64):
        raise ValueError(f"Result is not a datetime64 {type(result)}")

    first = dataset.dates[0]
    if result != first:
        raise ValueError(f"Result has wrong start date: {result} != {first}")
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def validate_end_date(report, dataset, name, result):
    """Check that end_date is a datetime64 equal to the last date."""
    if not isinstance(result, np.datetime64):
        raise ValueError(f"Result is not a datetime64 {type(result)}")

    last = dataset.dates[-1]
    if result != last:
        raise ValueError(f"Result has wrong end date: {result} != {last}")
|
|
485
|
+
|
|
486
|
+
|
|
487
|
+
def validate_field_shape(report, dataset, name, result):
    """Check that the field shape is a tuple whose product equals the
    flattened grid size (the last entry of the dataset shape)."""
    if not isinstance(result, tuple):
        raise ValueError(f"Result is not a tuple {type(result)}")

    grid_size = dataset.shape[-1]
    if math.prod(result) != grid_size:
        raise ValueError(f"Result has wrong shape: {result} != {grid_size}")
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
def validate(report, dataset, name, kwargs=None):
    """Validate one method/attribute `name` of `dataset` into `report`.

    Looks up a `validate_<name>` function in this module's globals
    (falling back to `_no_validate`), fetches the attribute from the
    dataset, calls it with `kwargs` when it is a callable method, and
    records the outcome on `report`. Never raises: any unexpected
    exception is recorded as a Failure.
    """

    try:

        # Dispatch to the module-level validator for this name, if any.
        validate_fn = globals().get(f"validate_{name}", _no_validate)

        # Check if the method is still in the Dataset class
        try:
            report.method(name, getattr(Dataset, name))
        except AttributeError:
            report.internal(name, "Attribute not found in Dataset class. Please update the list of methods.")
            return

        # Check if the method is supported by the dataset instance
        try:
            result = getattr(dataset, name)
        except AttributeError as e:
            report.failure(name, e)
            return

        # Check if the method is callable
        # Callables must have an entry in KWARGS (and vice versa);
        # a mismatch means the method lists are stale.
        if callable(result):
            if kwargs is None:
                report.internal(
                    name, f"`{name}` is a callable method, not an attribute. Please update KWARGS accordingly."
                )
                return
        else:
            if kwargs is not None:
                report.internal(name, f"`{name}` is not callable. Please remove entry from KWARGS.")
                return

        if kwargs is not None:
            result = result(**kwargs)

        # Generic NaN screen for array-valued results.
        # NOTE(review): np.isnan raises TypeError for non-float dtypes
        # (e.g. datetime64); the outer except turns that into a Failure —
        # confirm this is intended for date-valued attributes.
        if isinstance(result, np.ndarray) and np.isnan(result).any():
            report.invalid(name, ValueError("Result contains NaN values"))
            return

        try:
            validate_fn(report, dataset, name, result)
        except Exception as e:
            report.invalid(name, e)
            return

        report.success(name)

    except Exception as e:
        # Catch-all so one broken method cannot abort the whole validation.
        report.failure(name, e)
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
def validate_dtype(report, dataset, name, result):
    """Check that the dataset dtype is a numpy dtype instance."""
    if isinstance(result, np.dtype):
        return
    raise ValueError(f"Result is not a np.dtype {type(result)}")
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
def validate_dataset(dataset, costly_checks=False, detailed=False):
    """Validate the dataset.

    Runs every validator listed in METHODS against `dataset`, collects
    the outcomes in a Report, and prints the summary.

    Parameters
    ----------
    dataset : Dataset
        The (opened) dataset to validate.
    costly_checks : bool
        When True, also exercise full indexing (loads the entire
        dataset into memory).
    detailed : bool
        Passed through to `Report.summary`; prints method docstrings
        for methods that did not succeed.
    """

    report = Report()

    if costly_checks:
        # This check is expensive as it loads the entire dataset into memory
        # so we make it optional
        default_test_indexing(dataset)

        # Iteration and integer indexing must agree element-wise.
        for i, x in enumerate(dataset):
            y = dataset[i]
            assert (x == y).all(), f"Dataset indexing failed at index {i}: {x} != {y}"

    for name in METHODS:
        validate(report, dataset, name, kwargs=KWARGS.get(name))

    report.summary(detailed=detailed)
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
if __name__ == "__main__":
    # Self-check of the method category lists: warn about names that
    # appear both in OTHER_METHODS and a real category, or in two
    # categories at once.
    methods = METHODS_CATEGORIES.copy()
    methods.pop("OTHER_METHODS")

    # Names still only in OTHER_METHODS after removing overlaps.
    o = set(OTHER_METHODS)
    overlap = False
    for m in methods:
        if set(methods[m]).intersection(set(OTHER_METHODS)):
            print(
                f"WARNING: {m} contains methods from OTHER_METHODS: {set(methods[m]).intersection(set(OTHER_METHODS))}"
            )
            o = o - set(methods[m])
            overlap = True

    # Pairwise overlap check between the remaining categories.
    for m in methods:
        for n in methods:
            if n is not m:
                if set(methods[m]).intersection(set(methods[n])):
                    print(
                        f"WARNING: {m} and {n} have methods in common: {set(methods[m]).intersection(set(methods[n]))}"
                    )

    if overlap:
        print(sorted(o))
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: anemoi-datasets
|
|
3
|
-
Version: 0.5.
|
|
3
|
+
Version: 0.5.28
|
|
4
4
|
Summary: A package to hold various functions to support training of ML models on ECMWF data.
|
|
5
5
|
Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
|
|
6
6
|
License: Apache License
|
|
@@ -216,21 +216,23 @@ Classifier: Intended Audience :: Developers
|
|
|
216
216
|
Classifier: License :: OSI Approved :: Apache Software License
|
|
217
217
|
Classifier: Operating System :: OS Independent
|
|
218
218
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
219
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
220
219
|
Classifier: Programming Language :: Python :: 3.10
|
|
221
220
|
Classifier: Programming Language :: Python :: 3.11
|
|
222
221
|
Classifier: Programming Language :: Python :: 3.12
|
|
223
222
|
Classifier: Programming Language :: Python :: 3.13
|
|
224
223
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
225
224
|
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
226
|
-
Requires-Python: >=3.
|
|
225
|
+
Requires-Python: >=3.10
|
|
227
226
|
License-File: LICENSE
|
|
228
227
|
Requires-Dist: anemoi-transform>=0.1.10
|
|
229
|
-
Requires-Dist: anemoi-utils[provenance]>=0.4.
|
|
228
|
+
Requires-Dist: anemoi-utils[provenance]>=0.4.38
|
|
230
229
|
Requires-Dist: cfunits
|
|
230
|
+
Requires-Dist: glom
|
|
231
|
+
Requires-Dist: jsonschema
|
|
231
232
|
Requires-Dist: numcodecs<0.16
|
|
232
233
|
Requires-Dist: numpy
|
|
233
234
|
Requires-Dist: pyyaml
|
|
235
|
+
Requires-Dist: ruamel-yaml
|
|
234
236
|
Requires-Dist: semantic-version
|
|
235
237
|
Requires-Dist: tqdm
|
|
236
238
|
Requires-Dist: zarr<=2.18.4
|
|
@@ -262,6 +264,7 @@ Requires-Dist: requests; extra == "remote"
|
|
|
262
264
|
Provides-Extra: tests
|
|
263
265
|
Requires-Dist: anemoi-datasets[xarray]; extra == "tests"
|
|
264
266
|
Requires-Dist: pytest; extra == "tests"
|
|
267
|
+
Requires-Dist: pytest-skip-slow; extra == "tests"
|
|
265
268
|
Requires-Dist: pytest-xdist; extra == "tests"
|
|
266
269
|
Provides-Extra: xarray
|
|
267
270
|
Requires-Dist: adlfs; extra == "xarray"
|