anemoi-datasets 0.5.26__py3-none-any.whl → 0.5.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. anemoi/datasets/__init__.py +1 -2
  2. anemoi/datasets/_version.py +16 -3
  3. anemoi/datasets/commands/check.py +1 -1
  4. anemoi/datasets/commands/copy.py +1 -2
  5. anemoi/datasets/commands/create.py +1 -1
  6. anemoi/datasets/commands/inspect.py +27 -35
  7. anemoi/datasets/commands/recipe/__init__.py +93 -0
  8. anemoi/datasets/commands/recipe/format.py +55 -0
  9. anemoi/datasets/commands/recipe/migrate.py +555 -0
  10. anemoi/datasets/commands/validate.py +59 -0
  11. anemoi/datasets/compute/recentre.py +3 -6
  12. anemoi/datasets/create/__init__.py +64 -26
  13. anemoi/datasets/create/check.py +10 -12
  14. anemoi/datasets/create/chunks.py +1 -2
  15. anemoi/datasets/create/config.py +5 -6
  16. anemoi/datasets/create/input/__init__.py +44 -65
  17. anemoi/datasets/create/input/action.py +296 -238
  18. anemoi/datasets/create/input/context/__init__.py +71 -0
  19. anemoi/datasets/create/input/context/field.py +54 -0
  20. anemoi/datasets/create/input/data_sources.py +7 -9
  21. anemoi/datasets/create/input/misc.py +2 -75
  22. anemoi/datasets/create/input/repeated_dates.py +11 -130
  23. anemoi/datasets/{utils → create/input/result}/__init__.py +10 -1
  24. anemoi/datasets/create/input/{result.py → result/field.py} +36 -120
  25. anemoi/datasets/create/input/trace.py +1 -1
  26. anemoi/datasets/create/patch.py +1 -2
  27. anemoi/datasets/create/persistent.py +3 -5
  28. anemoi/datasets/create/size.py +1 -3
  29. anemoi/datasets/create/sources/accumulations.py +120 -145
  30. anemoi/datasets/create/sources/accumulations2.py +20 -53
  31. anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
  32. anemoi/datasets/create/sources/constants.py +39 -40
  33. anemoi/datasets/create/sources/empty.py +22 -19
  34. anemoi/datasets/create/sources/fdb.py +133 -0
  35. anemoi/datasets/create/sources/forcings.py +29 -29
  36. anemoi/datasets/create/sources/grib.py +94 -78
  37. anemoi/datasets/create/sources/grib_index.py +57 -55
  38. anemoi/datasets/create/sources/hindcasts.py +57 -59
  39. anemoi/datasets/create/sources/legacy.py +10 -62
  40. anemoi/datasets/create/sources/mars.py +121 -149
  41. anemoi/datasets/create/sources/netcdf.py +28 -25
  42. anemoi/datasets/create/sources/opendap.py +28 -26
  43. anemoi/datasets/create/sources/patterns.py +4 -6
  44. anemoi/datasets/create/sources/recentre.py +46 -48
  45. anemoi/datasets/create/sources/repeated_dates.py +44 -0
  46. anemoi/datasets/create/sources/source.py +26 -51
  47. anemoi/datasets/create/sources/tendencies.py +68 -98
  48. anemoi/datasets/create/sources/xarray.py +4 -6
  49. anemoi/datasets/create/sources/xarray_support/__init__.py +40 -36
  50. anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -12
  51. anemoi/datasets/create/sources/xarray_support/field.py +20 -16
  52. anemoi/datasets/create/sources/xarray_support/fieldlist.py +11 -15
  53. anemoi/datasets/create/sources/xarray_support/flavour.py +42 -42
  54. anemoi/datasets/create/sources/xarray_support/grid.py +15 -9
  55. anemoi/datasets/create/sources/xarray_support/metadata.py +19 -128
  56. anemoi/datasets/create/sources/xarray_support/patch.py +4 -6
  57. anemoi/datasets/create/sources/xarray_support/time.py +10 -13
  58. anemoi/datasets/create/sources/xarray_support/variable.py +21 -21
  59. anemoi/datasets/create/sources/xarray_zarr.py +28 -25
  60. anemoi/datasets/create/sources/zenodo.py +43 -41
  61. anemoi/datasets/create/statistics/__init__.py +3 -6
  62. anemoi/datasets/create/testing.py +4 -0
  63. anemoi/datasets/create/typing.py +1 -2
  64. anemoi/datasets/create/utils.py +0 -43
  65. anemoi/datasets/create/zarr.py +7 -2
  66. anemoi/datasets/data/__init__.py +15 -6
  67. anemoi/datasets/data/complement.py +7 -12
  68. anemoi/datasets/data/concat.py +5 -8
  69. anemoi/datasets/data/dataset.py +48 -47
  70. anemoi/datasets/data/debug.py +7 -9
  71. anemoi/datasets/data/ensemble.py +4 -6
  72. anemoi/datasets/data/fill_missing.py +7 -10
  73. anemoi/datasets/data/forwards.py +22 -26
  74. anemoi/datasets/data/grids.py +12 -168
  75. anemoi/datasets/data/indexing.py +9 -12
  76. anemoi/datasets/data/interpolate.py +7 -15
  77. anemoi/datasets/data/join.py +8 -12
  78. anemoi/datasets/data/masked.py +6 -11
  79. anemoi/datasets/data/merge.py +5 -9
  80. anemoi/datasets/data/misc.py +41 -45
  81. anemoi/datasets/data/missing.py +11 -16
  82. anemoi/datasets/data/observations/__init__.py +8 -14
  83. anemoi/datasets/data/padded.py +3 -5
  84. anemoi/datasets/data/records/backends/__init__.py +2 -2
  85. anemoi/datasets/data/rescale.py +5 -12
  86. anemoi/datasets/data/rolling_average.py +141 -0
  87. anemoi/datasets/data/select.py +13 -16
  88. anemoi/datasets/data/statistics.py +4 -7
  89. anemoi/datasets/data/stores.py +22 -29
  90. anemoi/datasets/data/subset.py +8 -11
  91. anemoi/datasets/data/unchecked.py +7 -11
  92. anemoi/datasets/data/xy.py +25 -21
  93. anemoi/datasets/dates/__init__.py +15 -18
  94. anemoi/datasets/dates/groups.py +7 -10
  95. anemoi/datasets/dumper.py +76 -0
  96. anemoi/datasets/grids.py +4 -185
  97. anemoi/datasets/schemas/recipe.json +131 -0
  98. anemoi/datasets/testing.py +93 -7
  99. anemoi/datasets/validate.py +598 -0
  100. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/METADATA +7 -4
  101. anemoi_datasets-0.5.28.dist-info/RECORD +134 -0
  102. anemoi/datasets/create/filter.py +0 -48
  103. anemoi/datasets/create/input/concat.py +0 -164
  104. anemoi/datasets/create/input/context.py +0 -89
  105. anemoi/datasets/create/input/empty.py +0 -54
  106. anemoi/datasets/create/input/filter.py +0 -118
  107. anemoi/datasets/create/input/function.py +0 -233
  108. anemoi/datasets/create/input/join.py +0 -130
  109. anemoi/datasets/create/input/pipe.py +0 -66
  110. anemoi/datasets/create/input/step.py +0 -177
  111. anemoi/datasets/create/input/template.py +0 -162
  112. anemoi_datasets-0.5.26.dist-info/RECORD +0 -131
  113. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/WHEEL +0 -0
  114. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/entry_points.txt +0 -0
  115. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/licenses/LICENSE +0 -0
  116. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,598 @@
1
+ # (C) Copyright 2025- Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+
11
+ import logging
12
+ import math
13
+ from collections import defaultdict
14
+
15
+ import numpy as np
16
+
17
+ from anemoi.datasets.data.dataset import Dataset
18
+ from anemoi.datasets.testing import default_test_indexing
19
+
20
LOG = logging.getLogger(__name__)

# List of methods called during training. To update the list, run training with ANEMOI_DATASETS_TRACE=1

TRAINING_METHODS = [
    "__getitem__",
    "__len__",
    "latitudes",
    "longitudes",
    "metadata",  # Accessed when checkpointing
    "missing",
    "name_to_index",
    "shape",
    "statistics",
    "supporting_arrays",  # Accessed when checkpointing
    "variables",
]

# Methods also exercised by training in some setups, kept separate from the core list.
EXTRA_TRAINING_METHODS = [
    "statistics_tendencies",
]

# Methods used interactively for debugging/inspection.
DEBUGGING_METHODS = [
    "plot",
    "to_index",
    "tree",
    "source",
]

# Metadata accessors that are part of the public API.
PUBLIC_METADATA_METHODS = [
    "arguments",
    "dtype",
    "end_date",
    "resolution",
    "start_date",
    "field_shape",
    "frequency",
    "dates",
    "typed_variables",
    "variables_metadata",
]

# Metadata accessors considered internal to the package.
PRIVATE_METADATA_METHODS = [
    "computed_constant_fields",
    "constant_fields",
    "dataset_metadata",
    "label",
    "metadata_specific",
    "provenance",
]

INTERNAL_METHODS = [
    "mutate",
    "swap_with_parent",
    "dates_interval_to_indices",
]

EXPERIMENTAL_METHODS = [
    "get_dataset_names",
    "name",
    "grids",
]

# Methods not (yet) assigned to any of the categories above.
OTHER_METHODS = [
    "collect_input_sources",
    "collect_supporting_arrays",
    "sub_shape",
]


# Collect every *_METHODS list defined above into {category-name: list}.
# NOTE: relies on globals() at this exact point in the module, so any new
# category list must be defined before this line to be picked up.
METHODS_CATEGORIES = {k: v for k, v in list(globals().items()) if k.endswith("_METHODS")}


# Flat set of every method name across all categories.
METHODS = set(sum(METHODS_CATEGORIES.values(), []))


# Call arguments for the names that are callable methods. A name that is
# absent from KWARGS is expected to be a plain attribute/property (see
# `validate`, which uses presence here to decide whether to call the result).
KWARGS = {
    "__len__": {},
    "__getitem__": {"index": 0},
    "get_dataset_names": {"names": set()},
    "metadata": {},
    "metadata_specific": {},
    "mutate": {},
    "plot": {"date": 0, "variable": 0},
    "provenance": {},
    "source": {"index": 0},
    "statistics_tendencies": {},
    "sub_shape": {},
    "supporting_arrays": {},
    "swap_with_parent": {},
    "to_index": {"date": 0, "variable": 0},
    "tree": {},
}
112
+
113
+
114
class Unknown:
    """Placeholder status for a method that has no recorded result.

    ``success`` is defined so callers that test ``r.success`` (as
    ``Report.summary`` does on entries missing from the report) do not hit
    an AttributeError; previously this class only defined ``emoji``.
    """

    emoji = "❓"
    success = False
116
+
117
+
118
class Success:
    """Status object recorded when a check passes."""

    emoji = "✅"
    success = True

    def __repr__(self):
        return "Success"
124
+
125
+
126
class Error:
    """Base status for a failed check, wrapping the triggering message.

    ``message`` may be a plain string or an exception instance.
    """

    success = False

    def __init__(self, message):
        self.message = message

    def __repr__(self):
        # Prefer the str() form; fall back to repr() when it is empty,
        # and finally to a generic label.
        text = str(self.message)
        if text:
            return text
        text = repr(self.message)
        if text:
            return text
        return "Error"
134
+
135
+
136
class Failure(Error):
    """Status recorded when running a check raised an unexpected exception."""

    emoji = "💥"
138
+
139
+
140
class Internal(Error):
    """Status recorded for problems in the validation harness itself
    (e.g. the method lists or KWARGS being out of date)."""

    emoji = "💣"
142
+
143
+
144
class Invalid(Error):
    """Status recorded when a method ran but its result failed validation."""

    emoji = "❌"
146
+
147
+
148
class Report:
    """Collects per-method validation outcomes and prints a summary."""

    def __init__(self):
        # Method name -> Success/Failure/Internal/Invalid instance.
        self.report = {}
        # Method name -> the Dataset class attribute (kept for its docstring).
        self.methods = {}
        # Method name -> list of non-fatal warning messages.
        self.warnings = defaultdict(list)

    def method(self, name, method):
        """Record the Dataset class attribute backing `name`."""
        self.methods[name] = method

    def success(self, name):
        """Mark `name` as having passed validation."""
        self.report[name] = Success()

    def failure(self, name, message):
        """Mark `name` as having raised an unexpected exception."""
        self.report[name] = Failure(message)

    def internal(self, name, message):
        """Mark `name` as a problem in the validation harness itself."""
        self.report[name] = Internal(message)

    def invalid(self, name, exception):
        """Mark `name` as having returned an invalid result."""
        self.report[name] = Invalid(exception)

    def warning(self, name, message):
        """Attach a non-fatal warning to `name`."""
        self.warnings[name].append(message)

    def summary(self, detailed=False):
        """Print a per-category summary of all outcomes.

        With `detailed`, also print the docstring of every method that did
        not succeed (when one was recorded via `method`).
        """

        # Width of the widest reported name, used to align the output.
        maxlen = max(len(name) for name in self.report.keys())

        for name, methods in METHODS_CATEGORIES.items():
            print()
            print(f"{name.title().replace('_', ' ')}:")
            print("-" * (len(name) + 1))
            print()

            for method in methods:
                # Entries never reported fall back to Unknown.
                r = self.report.get(method, Unknown())
                msg = repr(r)
                if not msg.endswith("."):
                    msg += "."
                print(f"{r.emoji} {method.ljust(maxlen)}: {msg}")

                for w in self.warnings.get(method, []):
                    print(" " * (maxlen + 4), "⚠️", w)

                if r.success:
                    continue

                if not detailed:
                    continue

                if method not in self.methods:
                    continue

                proc = self.methods[method]

                # Show the method's docstring, boxed and indented to align
                # with the status column.
                doc = proc.__doc__
                if doc:
                    width = 80
                    indent = maxlen + 4
                    doc = "\n".join(["=" * width, "", doc, "=" * width])
                    indented_doc = "\n".join(" " * indent + line for line in doc.splitlines())
                    print()
                    print(indented_doc)
                    print()
                    print()

        print()
216
+
217
+
218
def _no_validate(report, dataset, name, result):
    """Fallback validator used when no ``validate_<name>`` function exists:
    records a warning instead of checking anything."""
    message = f"Validation for {name} not implemented. Result: {type(result)}"
    report.warning(name, message)
220
+
221
+
222
def validate_variables(report, dataset, name, result):
    """Validate the variables of the dataset.

    Must be a list/tuple of strings whose length matches the variable
    dimension ``dataset.shape[1]``.
    """

    if not isinstance(result, (list, tuple)):
        raise ValueError(f"Result is not a list or tuple {type(result)}")

    expected = dataset.shape[1]
    if len(result) != expected:
        raise ValueError(f"Result has wrong length: {len(result)} != {expected}")

    for entry in result:
        if isinstance(entry, str):
            continue
        raise ValueError(f"`{entry}` is not a string")
234
+
235
+
236
def validate_latitudes(report, dataset, name, result):
    """Validate the latitudes of the dataset.

    Checks type, length (grid dimension ``dataset.shape[3]``), NaN content
    and the [-90, 90] degree range; warns if all values could be radians.
    """

    if not isinstance(result, np.ndarray):
        raise ValueError(f"Result is not a np.ndarray {type(result)}")

    if len(result) != dataset.shape[3]:
        raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}")

    # NaN values are reported as "invalid" rather than raised. This check
    # must come before the finiteness check: NaN is also non-finite, so the
    # previous ordering made this branch unreachable.
    if np.isnan(result).any():
        report.invalid(name, ValueError("Result contains NaN values"))
        return

    if not np.all(np.isfinite(result)):
        raise ValueError("Result contains non-finite values")

    if not np.all((result >= -90) & (result <= 90)):
        raise ValueError("Result contains values outside the range [-90, 90]")

    if np.all((result >= -np.pi) & (result <= np.pi)):
        report.warning(name, "All latitudes are in the range [-π, π]. Are they in radians?")
257
+
258
+
259
def validate_longitudes(report, dataset, name, result):
    """Validate the longitudes of the dataset.

    Checks type, length (grid dimension ``dataset.shape[3]``), NaN content
    and the [-180, 360] degree range; warns if all values could be radians.
    """

    if not isinstance(result, np.ndarray):
        raise ValueError(f"Result is not a np.ndarray {type(result)}")

    if len(result) != dataset.shape[3]:
        # Fixed: the message previously printed dataset.shape[2] while the
        # check compares against dataset.shape[3].
        raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}")

    # NaN values are reported as "invalid" rather than raised. This check
    # must come before the finiteness check: NaN is also non-finite, so the
    # previous ordering made this branch unreachable.
    if np.isnan(result).any():
        report.invalid(name, ValueError("Result contains NaN values"))
        return

    if not np.all(np.isfinite(result)):
        raise ValueError("Result contains non-finite values")

    if not np.all((result >= -180) & (result <= 360)):
        raise ValueError("Result contains values outside the range [-180, 360]")

    if np.all((result >= -np.pi) & (result <= 2 * np.pi)):
        report.warning(name, "All longitudes are in the range [-π, 2π]. Are they in radians?")
280
+
281
+
282
def validate_statistics(report, dataset, name, result):
    """Validate the statistics of the dataset.

    Expects a dict with 1-D float arrays for "mean", "stdev", "minimum" and
    "maximum", each of length ``len(dataset.variables)``.
    """

    if not isinstance(result, dict):
        raise ValueError(f"Result is not a dict {type(result)}")

    for key in ["mean", "stdev", "minimum", "maximum"]:

        if key not in result:
            raise ValueError(f"Result does not contain `{key}`")

        if not isinstance(result[key], np.ndarray):
            raise ValueError(f"Result[{key}] is not a np.ndarray {type(result[key])}")

        if len(result[key].shape) != 1:
            raise ValueError(f"Result[{key}] has wrong shape: {len(result[key].shape)} != 1")

        if result[key].shape[0] != len(dataset.variables):
            raise ValueError(f"Result[{key}] has wrong length: {result[key].shape[0]} != {len(dataset.variables)}")

        # NaN values are reported as "invalid" rather than raised. This check
        # must come before the finiteness check: NaN is also non-finite, so
        # the previous ordering made this branch unreachable.
        if np.isnan(result[key]).any():
            report.invalid(name, ValueError(f"Result[{key}] contains NaN values"))
            continue

        if not np.all(np.isfinite(result[key])):
            raise ValueError(f"Result[{key}] contains non-finite values")
307
+
308
+
309
def validate_shape(report, dataset, name, result):
    """Validate the shape of the dataset.

    Must be a 4-tuple (dates, variables, ensemble, grid) consistent with
    ``len(dataset)``, ``dataset.variables`` and ``dataset.latitudes``.
    """

    if not isinstance(result, tuple):
        raise ValueError(f"Result is not a tuple {type(result)}")

    if len(result) != 4:
        # Fixed: the message previously printed len(dataset.shape) instead of
        # the expected rank, which made the error self-referential.
        raise ValueError(f"Result has wrong length: {len(result)} != 4")

    if result[0] != len(dataset):
        raise ValueError(f"Result[0] has wrong length: {result[0]} != {len(dataset)}")

    if result[1] != len(dataset.variables):
        raise ValueError(f"Result[1] has wrong length: {result[1]} != {len(dataset.variables)}")

    # result[2] is the ensemble dimension; it is deliberately not checked for now.

    if result[3] != len(dataset.latitudes):
        raise ValueError(f"Result[3] has wrong length: {result[3]} != {len(dataset.latitudes)}")
329
+
330
+
331
def validate_supporting_arrays(report, dataset, name, result):
    """Validate the supporting arrays of the dataset.

    Must be a dict containing "latitudes" and "longitudes" ndarrays that
    match the dataset's own coordinates element-wise.
    """

    if not isinstance(result, dict):
        raise ValueError(f"Result is not a dict {type(result)}")

    if "latitudes" not in result:
        raise ValueError("Result does not contain `latitudes`")

    if "longitudes" not in result:
        raise ValueError("Result does not contain `longitudes`")

    lats = result["latitudes"]
    lons = result["longitudes"]

    if not isinstance(lats, np.ndarray):
        raise ValueError(f"Result[latitudes] is not a np.ndarray {type(result['latitudes'])}")

    if not isinstance(lons, np.ndarray):
        raise ValueError(f"Result[longitudes] is not a np.ndarray {type(result['longitudes'])}")

    if np.any(lats != dataset.latitudes):
        raise ValueError("Result[latitudes] does not match dataset.latitudes")

    if np.any(lons != dataset.longitudes):
        raise ValueError("Result[longitudes] does not match dataset.longitudes")
354
+
355
+
356
def validate_dates(report, dataset, name, result):
    """Validate the dates of the dataset.

    Must be a 1-D datetime64 array matching ``dataset.dates`` in length,
    with no NaT values, strictly increasing and at a constant frequency.
    """

    if not isinstance(result, np.ndarray):
        raise ValueError(f"Result is not a np.ndarray {type(result)}")

    if len(result.shape) != 1:
        raise ValueError(f"Result has wrong shape: {len(result.shape)} != 1")

    if result.shape[0] != len(dataset.dates):
        raise ValueError(f"Result has wrong length: {result.shape[0]} != {len(dataset.dates)}")

    if not np.issubdtype(result.dtype, np.datetime64):
        raise ValueError(f"Result is not a datetime64 array {result.dtype}")

    # NaT (not-a-time) is the datetime64 analogue of NaN. Note that
    # np.isfinite is not defined for datetime64 arrays and raised a
    # TypeError here before; np.isnat is the correct predicate.
    if np.isnat(result).any():
        report.invalid(name, ValueError("Result contains NaT values"))
        return

    for d1, d2 in zip(result[:-1], result[1:]):
        if d1 >= d2:
            raise ValueError(f"Result contains non-increasing dates: {d1} >= {d2}")

    frequency = np.diff(result)
    # Guard against the empty diff of a single-date array, which would make
    # frequency[0] raise an IndexError.
    if frequency.size and not np.all(frequency == frequency[0]):
        raise ValueError("Result contains non-constant frequency")
388
+
389
+
390
def validate_metadata(report, dataset, name, result):
    """Validate the metadata of the dataset (must be a dict)."""

    if isinstance(result, dict):
        return
    raise ValueError(f"Result is not a dict {type(result)}")
395
+
396
+
397
def validate_missing(report, dataset, name, result):
    """Validate the missing-date indices of the dataset.

    Must be a set of ints within ``range(len(dataset))``.
    """

    if not isinstance(result, set):
        raise ValueError(f"Result is not a set {type(result)}")

    if any(not isinstance(item, int) for item in result):
        raise ValueError("Result contains non-integer values")

    if result:
        if min(result) < 0:
            raise ValueError("Result contains negative values")

        if max(result) >= len(dataset):
            raise ValueError(f"Result contains values greater than {len(dataset)}")
412
+
413
+
414
def validate_name_to_index(report, dataset, name, result):
    """Validate the name-to-index mapping of the dataset.

    Must be a dict mapping each variable name to its position, forming a
    bijection consistent with ``dataset.variables``.
    """

    if not isinstance(result, dict):
        raise ValueError(f"Result is not a dict {type(result)}")

    count = len(dataset.variables)

    # Forward direction: every variable name maps to a valid index.
    for key in dataset.variables:
        if key not in result:
            raise ValueError(f"Result does not contain `{key}`")

        index = result[key]

        if not isinstance(index, int):
            raise ValueError(f"Result[{key}] is not an int {type(index)}")

        if not 0 <= index < count:
            raise ValueError(f"Result[{key}] is out of bounds: {index}")

    # Reverse direction: every index maps back to the matching name.
    index_to_name = {v: k for k, v in result.items()}
    for i, expected in enumerate(dataset.variables):
        if i not in index_to_name:
            raise ValueError(f"Result does not contain index `{i}`")

        actual = index_to_name[i]

        if not isinstance(actual, str):
            raise ValueError(f"Result[{i}] is not a string {type(actual)}")

        if actual != expected:
            raise ValueError(
                f"Result[{i}] does not match dataset.variables[{i}]: {actual} != {expected}"
            )
442
+
443
+
444
def validate___getitem__(report, dataset, name, result):
    """Validate the __getitem__ method of the dataset.

    A single sample must be an ndarray shaped like ``dataset.shape`` minus
    the leading (dates) dimension.
    """

    if not isinstance(result, np.ndarray):
        raise ValueError(f"Result is not a np.ndarray {type(result)}")

    expected = dataset.shape[1:]
    if result.shape != expected:
        raise ValueError(f"Result has wrong shape: {result.shape} != {expected}")
452
+
453
+
454
def validate___len__(report, dataset, name, result):
    """Validate the __len__ method of the dataset.

    Must be an int equal to both the leading shape dimension and the number
    of dates.
    """

    if not isinstance(result, int):
        raise ValueError(f"Result is not an int {type(result)}")

    if result != dataset.shape[0]:
        raise ValueError(f"Result has wrong length: {result} != {len(dataset)}")

    n_dates = len(dataset.dates)
    if result != n_dates:
        raise ValueError(f"Result has wrong length: {result} != {n_dates}")
465
+
466
+
467
def validate_start_date(report, dataset, name, result):
    """Validate the start date of the dataset (first entry of ``dataset.dates``)."""

    if not isinstance(result, np.datetime64):
        raise ValueError(f"Result is not a datetime64 {type(result)}")

    expected = dataset.dates[0]
    if result != expected:
        raise ValueError(f"Result has wrong start date: {result} != {expected}")
475
+
476
+
477
def validate_end_date(report, dataset, name, result):
    """Validate the end date of the dataset (last entry of ``dataset.dates``)."""

    if not isinstance(result, np.datetime64):
        raise ValueError(f"Result is not a datetime64 {type(result)}")

    expected = dataset.dates[-1]
    if result != expected:
        raise ValueError(f"Result has wrong end date: {result} != {expected}")
485
+
486
+
487
def validate_field_shape(report, dataset, name, result):
    """Validate the field shape of the dataset.

    Must be a tuple whose product equals the flattened grid size
    ``dataset.shape[-1]``.
    """

    if not isinstance(result, tuple):
        raise ValueError(f"Result is not a tuple {type(result)}")

    points = math.prod(result)
    if points != dataset.shape[-1]:
        raise ValueError(f"Result has wrong shape: {result} != {dataset.shape[-1]}")
495
+
496
+
497
def validate(report, dataset, name, kwargs=None):
    """Run the check for a single method/attribute `name` of `dataset`.

    `kwargs` must be the call arguments for callable methods (an entry in
    KWARGS), or None for plain attributes/properties. All outcomes are
    recorded on `report`; this function itself never raises.
    """

    try:

        # Dispatch to a specific validator if one exists in this module,
        # otherwise fall back to a warning-only stub.
        validate_fn = globals().get(f"validate_{name}", _no_validate)

        # Check if the method is still in the Dataset class
        try:
            report.method(name, getattr(Dataset, name))
        except AttributeError:
            report.internal(name, "Attribute not found in Dataset class. Please update the list of methods.")
            return

        # Check if the method is supported by the dataset instance
        try:
            result = getattr(dataset, name)
        except AttributeError as e:
            report.failure(name, e)
            return

        # Check if the method is callable
        # (a callable must have an entry in KWARGS, a plain attribute must not)
        if callable(result):
            if kwargs is None:
                report.internal(
                    name, f"`{name}` is a callable method, not an attribute. Please update KWARGS accordingly."
                )
                return
        else:
            if kwargs is not None:
                report.internal(name, f"`{name}` is not callable. Please remove entry from KWARGS.")
                return

        if kwargs is not None:
            result = result(**kwargs)

        # Generic NaN screen applied to any array-valued result.
        # NOTE(review): for datetime64 results this relies on np.isnan
        # supporting datetime64 (NaT) — confirm against the numpy version in use.
        if isinstance(result, np.ndarray) and np.isnan(result).any():
            report.invalid(name, ValueError("Result contains NaN values"))
            return

        try:
            validate_fn(report, dataset, name, result)
        except Exception as e:
            report.invalid(name, e)
            return

        report.success(name)

    except Exception as e:
        # Any unexpected error anywhere above is recorded, not propagated.
        report.failure(name, e)
546
+
547
+
548
def validate_dtype(report, dataset, name, result):
    """Validate the dtype of the dataset (must be a np.dtype instance)."""

    if isinstance(result, np.dtype):
        return
    raise ValueError(f"Result is not a np.dtype {type(result)}")
553
+
554
+
555
def validate_dataset(dataset, costly_checks=False, detailed=False):
    """Validate the dataset.

    Runs every known method check and prints a summary; with
    `costly_checks`, also verifies iteration against indexing.
    """

    report = Report()

    if costly_checks:
        # This check is expensive as it loads the entire dataset into memory
        # so we make it optional
        default_test_indexing(dataset)

        for index, sample in enumerate(dataset):
            expected = dataset[index]
            assert (sample == expected).all(), f"Dataset indexing failed at index {index}: {sample} != {expected}"

    for name in METHODS:
        validate(report, dataset, name, kwargs=KWARGS.get(name))

    report.summary(detailed=detailed)
573
+
574
+
575
if __name__ == "__main__":
    # Self-check of the category lists: report methods that appear both in
    # OTHER_METHODS and in a named category, and any pair of categories
    # sharing a method.
    methods = METHODS_CATEGORIES.copy()
    methods.pop("OTHER_METHODS")

    # `o` ends up holding the OTHER_METHODS entries not claimed by any category.
    o = set(OTHER_METHODS)
    overlap = False
    for m in methods:
        if set(methods[m]).intersection(set(OTHER_METHODS)):
            print(
                f"WARNING: {m} contains methods from OTHER_METHODS: {set(methods[m]).intersection(set(OTHER_METHODS))}"
            )
            o = o - set(methods[m])
            overlap = True

    # Pairwise overlap check between the named categories.
    for m in methods:
        for n in methods:
            if n is not m:
                if set(methods[m]).intersection(set(methods[n])):
                    print(
                        f"WARNING: {m} and {n} have methods in common: {set(methods[m]).intersection(set(methods[n]))}"
                    )

    if overlap:
        print(sorted(o))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: anemoi-datasets
3
- Version: 0.5.26
3
+ Version: 0.5.28
4
4
  Summary: A package to hold various functions to support training of ML models on ECMWF data.
5
5
  Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
6
6
  License: Apache License
@@ -216,21 +216,23 @@ Classifier: Intended Audience :: Developers
216
216
  Classifier: License :: OSI Approved :: Apache Software License
217
217
  Classifier: Operating System :: OS Independent
218
218
  Classifier: Programming Language :: Python :: 3 :: Only
219
- Classifier: Programming Language :: Python :: 3.9
220
219
  Classifier: Programming Language :: Python :: 3.10
221
220
  Classifier: Programming Language :: Python :: 3.11
222
221
  Classifier: Programming Language :: Python :: 3.12
223
222
  Classifier: Programming Language :: Python :: 3.13
224
223
  Classifier: Programming Language :: Python :: Implementation :: CPython
225
224
  Classifier: Programming Language :: Python :: Implementation :: PyPy
226
- Requires-Python: >=3.9
225
+ Requires-Python: >=3.10
227
226
  License-File: LICENSE
228
227
  Requires-Dist: anemoi-transform>=0.1.10
229
- Requires-Dist: anemoi-utils[provenance]>=0.4.26
228
+ Requires-Dist: anemoi-utils[provenance]>=0.4.38
230
229
  Requires-Dist: cfunits
230
+ Requires-Dist: glom
231
+ Requires-Dist: jsonschema
231
232
  Requires-Dist: numcodecs<0.16
232
233
  Requires-Dist: numpy
233
234
  Requires-Dist: pyyaml
235
+ Requires-Dist: ruamel-yaml
234
236
  Requires-Dist: semantic-version
235
237
  Requires-Dist: tqdm
236
238
  Requires-Dist: zarr<=2.18.4
@@ -262,6 +264,7 @@ Requires-Dist: requests; extra == "remote"
262
264
  Provides-Extra: tests
263
265
  Requires-Dist: anemoi-datasets[xarray]; extra == "tests"
264
266
  Requires-Dist: pytest; extra == "tests"
267
+ Requires-Dist: pytest-skip-slow; extra == "tests"
265
268
  Requires-Dist: pytest-xdist; extra == "tests"
266
269
  Provides-Extra: xarray
267
270
  Requires-Dist: adlfs; extra == "xarray"