anemoi-datasets 0.5.26__py3-none-any.whl → 0.5.27__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (105):
  1. anemoi/datasets/__init__.py +1 -2
  2. anemoi/datasets/_version.py +16 -3
  3. anemoi/datasets/commands/check.py +1 -1
  4. anemoi/datasets/commands/copy.py +1 -2
  5. anemoi/datasets/commands/create.py +1 -1
  6. anemoi/datasets/commands/inspect.py +27 -35
  7. anemoi/datasets/commands/validate.py +59 -0
  8. anemoi/datasets/compute/recentre.py +3 -6
  9. anemoi/datasets/create/__init__.py +22 -25
  10. anemoi/datasets/create/check.py +10 -12
  11. anemoi/datasets/create/chunks.py +1 -2
  12. anemoi/datasets/create/config.py +3 -6
  13. anemoi/datasets/create/filter.py +1 -2
  14. anemoi/datasets/create/input/__init__.py +1 -2
  15. anemoi/datasets/create/input/action.py +3 -5
  16. anemoi/datasets/create/input/concat.py +5 -8
  17. anemoi/datasets/create/input/context.py +3 -6
  18. anemoi/datasets/create/input/data_sources.py +5 -8
  19. anemoi/datasets/create/input/empty.py +1 -2
  20. anemoi/datasets/create/input/filter.py +2 -3
  21. anemoi/datasets/create/input/function.py +1 -2
  22. anemoi/datasets/create/input/join.py +4 -5
  23. anemoi/datasets/create/input/misc.py +4 -6
  24. anemoi/datasets/create/input/repeated_dates.py +13 -18
  25. anemoi/datasets/create/input/result.py +29 -33
  26. anemoi/datasets/create/input/step.py +4 -8
  27. anemoi/datasets/create/input/template.py +3 -4
  28. anemoi/datasets/create/input/trace.py +1 -1
  29. anemoi/datasets/create/patch.py +1 -2
  30. anemoi/datasets/create/persistent.py +3 -5
  31. anemoi/datasets/create/size.py +1 -3
  32. anemoi/datasets/create/sources/accumulations.py +47 -52
  33. anemoi/datasets/create/sources/accumulations2.py +4 -8
  34. anemoi/datasets/create/sources/constants.py +1 -3
  35. anemoi/datasets/create/sources/empty.py +1 -2
  36. anemoi/datasets/create/sources/fdb.py +133 -0
  37. anemoi/datasets/create/sources/forcings.py +1 -2
  38. anemoi/datasets/create/sources/grib.py +6 -10
  39. anemoi/datasets/create/sources/grib_index.py +13 -15
  40. anemoi/datasets/create/sources/hindcasts.py +2 -5
  41. anemoi/datasets/create/sources/legacy.py +1 -1
  42. anemoi/datasets/create/sources/mars.py +17 -21
  43. anemoi/datasets/create/sources/netcdf.py +1 -2
  44. anemoi/datasets/create/sources/opendap.py +1 -3
  45. anemoi/datasets/create/sources/patterns.py +4 -6
  46. anemoi/datasets/create/sources/recentre.py +8 -11
  47. anemoi/datasets/create/sources/source.py +3 -6
  48. anemoi/datasets/create/sources/tendencies.py +2 -5
  49. anemoi/datasets/create/sources/xarray.py +4 -6
  50. anemoi/datasets/create/sources/xarray_support/__init__.py +12 -13
  51. anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -12
  52. anemoi/datasets/create/sources/xarray_support/field.py +16 -12
  53. anemoi/datasets/create/sources/xarray_support/fieldlist.py +11 -15
  54. anemoi/datasets/create/sources/xarray_support/flavour.py +42 -42
  55. anemoi/datasets/create/sources/xarray_support/grid.py +15 -9
  56. anemoi/datasets/create/sources/xarray_support/metadata.py +19 -128
  57. anemoi/datasets/create/sources/xarray_support/patch.py +4 -6
  58. anemoi/datasets/create/sources/xarray_support/time.py +10 -13
  59. anemoi/datasets/create/sources/xarray_support/variable.py +21 -21
  60. anemoi/datasets/create/sources/xarray_zarr.py +1 -2
  61. anemoi/datasets/create/sources/zenodo.py +3 -5
  62. anemoi/datasets/create/statistics/__init__.py +3 -6
  63. anemoi/datasets/create/testing.py +4 -0
  64. anemoi/datasets/create/typing.py +1 -2
  65. anemoi/datasets/create/utils.py +1 -2
  66. anemoi/datasets/create/zarr.py +7 -2
  67. anemoi/datasets/data/__init__.py +15 -6
  68. anemoi/datasets/data/complement.py +7 -12
  69. anemoi/datasets/data/concat.py +5 -8
  70. anemoi/datasets/data/dataset.py +42 -47
  71. anemoi/datasets/data/debug.py +7 -9
  72. anemoi/datasets/data/ensemble.py +4 -6
  73. anemoi/datasets/data/fill_missing.py +7 -10
  74. anemoi/datasets/data/forwards.py +22 -26
  75. anemoi/datasets/data/grids.py +12 -16
  76. anemoi/datasets/data/indexing.py +9 -12
  77. anemoi/datasets/data/interpolate.py +7 -15
  78. anemoi/datasets/data/join.py +8 -12
  79. anemoi/datasets/data/masked.py +6 -11
  80. anemoi/datasets/data/merge.py +5 -9
  81. anemoi/datasets/data/misc.py +41 -45
  82. anemoi/datasets/data/missing.py +11 -16
  83. anemoi/datasets/data/observations/__init__.py +8 -14
  84. anemoi/datasets/data/padded.py +3 -5
  85. anemoi/datasets/data/records/backends/__init__.py +2 -2
  86. anemoi/datasets/data/rescale.py +5 -12
  87. anemoi/datasets/data/select.py +13 -16
  88. anemoi/datasets/data/statistics.py +4 -7
  89. anemoi/datasets/data/stores.py +16 -21
  90. anemoi/datasets/data/subset.py +8 -11
  91. anemoi/datasets/data/unchecked.py +7 -11
  92. anemoi/datasets/data/xy.py +25 -21
  93. anemoi/datasets/dates/__init__.py +13 -18
  94. anemoi/datasets/dates/groups.py +7 -10
  95. anemoi/datasets/grids.py +5 -9
  96. anemoi/datasets/testing.py +93 -7
  97. anemoi/datasets/validate.py +598 -0
  98. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.27.dist-info}/METADATA +4 -4
  99. anemoi_datasets-0.5.27.dist-info/RECORD +134 -0
  100. anemoi/datasets/utils/__init__.py +0 -8
  101. anemoi_datasets-0.5.26.dist-info/RECORD +0 -131
  102. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.27.dist-info}/WHEEL +0 -0
  103. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.27.dist-info}/entry_points.txt +0 -0
  104. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.27.dist-info}/licenses/LICENSE +0 -0
  105. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.27.dist-info}/top_level.txt +0 -0
anemoi/datasets/testing.py
@@ -12,19 +12,17 @@
 
 import logging
 from typing import Any
-from typing import List
-from typing import Optional
 
 LOG = logging.getLogger(__name__)
 
 
 def assert_field_list(
-    fs: List[Any],
-    size: Optional[int] = None,
-    start: Optional[Any] = None,
-    end: Optional[Any] = None,
+    fs: list[Any],
+    size: int | None = None,
+    start: Any | None = None,
+    end: Any | None = None,
     constant: bool = False,
-    skip: Optional[Any] = None,
+    skip: Any | None = None,
 ) -> None:
     """Asserts various properties of a list of fields.
 
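The hunk above is a pure typing-syntax modernisation (PEP 585/604): `typing.List` and `typing.Optional` give way to built-in generics and `X | None` unions, so call sites are unchanged. A minimal call sketch, assuming `fields` is an earthkit-data-style list of fields produced by a source (the argument values are illustrative only):

    # Hypothetical usage sketch; `fields` stands in for any list of fields
    # returned by a source.
    from anemoi.datasets.testing import assert_field_list

    assert_field_list(fields, size=24, constant=False)  # raises AssertionError on mismatch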
anemoi/datasets/testing.py
@@ -85,3 +83,91 @@ def assert_field_list(
     assert south >= -90, south
     assert east <= 360, east
     assert west >= -180, west
+
+
+class IndexTester:
+    """Class to test indexing of datasets."""
+
+    def __init__(self, ds: Any) -> None:
+        """Initialise the IndexTester.
+
+        Parameters
+        ----------
+        ds : Any
+            Dataset.
+        """
+        self.ds = ds
+        self.np = ds[:]  # Numpy array
+
+        assert self.ds.shape == self.np.shape, (self.ds.shape, self.np.shape)
+        assert (self.ds == self.np).all()
+
+    def __getitem__(self, index: Any) -> None:
+        """Test indexing.
+
+        Parameters
+        ----------
+        index : Any
+            Index.
+        """
+        LOG.info("IndexTester: %s", index)
+        if self.ds[index] is None:
+            assert False, (self.ds, index)
+
+        if not (self.ds[index] == self.np[index]).all():
+            assert (self.ds[index] == self.np[index]).all()
+
+
+def default_test_indexing(ds):
+
+    t = IndexTester(ds)
+
+    t[0:10, :, 0]
+    t[:, 0:3, 0]
+    # t[:, :, 0]
+    t[0:10, 0:3, 0]
+    t[:, :, :]
+
+    if ds.shape[1] > 2:  # Variable dimension
+        t[:, (1, 2), :]
+        t[:, (1, 2)]
+
+    t[0]
+    t[0, :]
+    t[0, 0, :]
+    t[0, 0, 0, :]
+
+    if ds.shape[2] > 1:  # Ensemble dimension
+        t[0:10, :, (0, 1)]
+
+    for i in range(3):
+        t[i]
+        start = 5 * i
+        end = len(ds) - 5 * i
+        step = len(ds) // 10
+
+        t[start:end:step]
+        t[start:end]
+        t[start:]
+        t[:end]
+        t[::step]
+
+
+class Trace:
+
+    def __init__(self, ds):
+        self.ds = ds
+        self.f = open("trace.txt", "a")
+
+    def __getattr__(self, name: str) -> Any:
+
+        print(name, file=self.f, flush=True)
+        return getattr(self.ds, name)
+
+    def __len__(self) -> int:
+        print("__len__", file=self.f, flush=True)
+        return len(self.ds)
+
+    def __getitem__(self, index: Any) -> Any:
+        print("__getitem__", file=self.f, flush=True)
+        return self.ds[index]
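`IndexTester` checks every indexing expression against the equivalent operation on the fully materialised NumPy array (`ds[:]`), and `Trace` appends the name of each accessed attribute to `trace.txt`; per the comment in the new validate module below, such a trace is how the `TRAINING_METHODS` list is kept up to date. A usage sketch, assuming a small dataset (the path is a placeholder, and note that `default_test_indexing` loads the whole dataset into memory):

    # Usage sketch; "small.zarr" is a placeholder path.
    from anemoi.datasets import open_dataset
    from anemoi.datasets.testing import Trace, default_test_indexing

    ds = open_dataset("small.zarr")
    default_test_indexing(ds)  # raises AssertionError if ds[...] and ds[:] disagree

    traced = Trace(ds)         # appends every accessed attribute name to trace.txt
    _ = len(traced)
    _ = traced[0]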
anemoi/datasets/validate.py (new file)
@@ -0,0 +1,598 @@
+# (C) Copyright 2025- Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+import math
+from collections import defaultdict
+
+import numpy as np
+
+from anemoi.datasets.data.dataset import Dataset
+from anemoi.datasets.testing import default_test_indexing
+
+LOG = logging.getLogger(__name__)
+# List of methods called during training. To update the list, run training with ANEMOI_DATASETS_TRACE=1
+
+TRAINING_METHODS = [
+    "__getitem__",
+    "__len__",
+    "latitudes",
+    "longitudes",
+    "metadata",  # Accessed when checkpointing
+    "missing",
+    "name_to_index",
+    "shape",
+    "statistics",
+    "supporting_arrays",  # Accessed when checkpointing
+    "variables",
+]
+
+EXTRA_TRAINING_METHODS = [
+    "statistics_tendencies",
+]
+
+DEBUGGING_METHODS = [
+    "plot",
+    "to_index",
+    "tree",
+    "source",
+]
+
+PUBLIC_METADATA_METHODS = [
+    "arguments",
+    "dtype",
+    "end_date",
+    "resolution",
+    "start_date",
+    "field_shape",
+    "frequency",
+    "dates",
+    "typed_variables",
+    "variables_metadata",
+]
+
+PRIVATE_METADATA_METHODS = [
+    "computed_constant_fields",
+    "constant_fields",
+    "dataset_metadata",
+    "label",
+    "metadata_specific",
+    "provenance",
+]
+
+INTERNAL_METHODS = [
+    "mutate",
+    "swap_with_parent",
+    "dates_interval_to_indices",
+]
+
+EXPERIMENTAL_METHODS = [
+    "get_dataset_names",
+    "name",
+    "grids",
+]
+
+OTHER_METHODS = [
+    "collect_input_sources",
+    "collect_supporting_arrays",
+    "sub_shape",
+]
+
+
+METHODS_CATEGORIES = {k: v for k, v in list(globals().items()) if k.endswith("_METHODS")}
+
+
+METHODS = set(sum(METHODS_CATEGORIES.values(), []))
+
+
+KWARGS = {
+    "__len__": {},
+    "__getitem__": {"index": 0},
+    "get_dataset_names": {"names": set()},
+    "metadata": {},
+    "metadata_specific": {},
+    "mutate": {},
+    "plot": {"date": 0, "variable": 0},
+    "provenance": {},
+    "source": {"index": 0},
+    "statistics_tendencies": {},
+    "sub_shape": {},
+    "supporting_arrays": {},
+    "swap_with_parent": {},
+    "to_index": {"date": 0, "variable": 0},
+    "tree": {},
+}
+
+
+class Unknown:
+    emoji = "❓"
+
+
+class Success:
+    emoji = "✅"
+    success = True
+
+    def __repr__(self):
+        return "Success"
+
+
+class Error:
+    success = False
+
+    def __init__(self, message):
+        self.message = message
+
+    def __repr__(self):
+        return str(self.message) or repr(self.message) or "Error"
+
+
+class Failure(Error):
+    emoji = "💥"
+
+
+class Internal(Error):
+    emoji = "💣"
+
+
+class Invalid(Error):
+    emoji = "❌"
+
+
+class Report:
+
+    def __init__(self):
+        self.report = {}
+        self.methods = {}
+        self.warnings = defaultdict(list)
+
+    def method(self, name, method):
+        self.methods[name] = method
+
+    def success(self, name):
+        self.report[name] = Success()
+
+    def failure(self, name, message):
+        self.report[name] = Failure(message)
+
+    def internal(self, name, message):
+        self.report[name] = Internal(message)
+
+    def invalid(self, name, exception):
+        self.report[name] = Invalid(exception)
+
+    def warning(self, name, message):
+        self.warnings[name].append(message)
+
+    def summary(self, detailed=False):
+
+        maxlen = max(len(name) for name in self.report.keys())
+
+        for name, methods in METHODS_CATEGORIES.items():
+            print()
+            print(f"{name.title().replace('_', ' ')}:")
+            print("-" * (len(name) + 1))
+            print()
+
+            for method in methods:
+                r = self.report.get(method, Unknown())
+                msg = repr(r)
+                if not msg.endswith("."):
+                    msg += "."
+                print(f"{r.emoji} {method.ljust(maxlen)}: {msg}")
+
+                for w in self.warnings.get(method, []):
+                    print(" " * (maxlen + 4), "⚠️", w)
+
+                if r.success:
+                    continue
+
+                if not detailed:
+                    continue
+
+                if method not in self.methods:
+                    continue
+
+                proc = self.methods[method]
+
+                doc = proc.__doc__
+                if doc:
+                    width = 80
+                    indent = maxlen + 4
+                    doc = "\n".join(["=" * width, "", doc, "=" * width])
+                    indented_doc = "\n".join(" " * indent + line for line in doc.splitlines())
+                    print()
+                    print(indented_doc)
+                    print()
+                    print()
+
+        print()
+
+
+def _no_validate(report, dataset, name, result):
+    report.warning(name, f"Validation for {name} not implemented. Result: {type(result)}")
+
+
+def validate_variables(report, dataset, name, result):
+    """Validate the variables of the dataset."""
+
+    if not isinstance(result, (list, tuple)):
+        raise ValueError(f"Result is not a list or tuple {type(result)}")
+
+    if len(result) != dataset.shape[1]:
+        raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[1]}")
+
+    for value in result:
+        if not isinstance(value, str):
+            raise ValueError(f"`{value}` is not a string")
+
+
+def validate_latitudes(report, dataset, name, result):
+    """Validate the latitudes of the dataset."""
+
+    if not isinstance(result, np.ndarray):
+        raise ValueError(f"Result is not a np.ndarray {type(result)}")
+
+    if len(result) != dataset.shape[3]:
+        raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[3]}")
+
+    if not np.all(np.isfinite(result)):
+        raise ValueError("Result contains non-finite values")
+
+    if np.isnan(result).any():
+        report.invalid(name, ValueError("Result contains NaN values"))
+        return
+
+    if not np.all((result >= -90) & (result <= 90)):
+        raise ValueError("Result contains values outside the range [-90, 90]")
+
+    if np.all((result >= -np.pi) & (result <= np.pi)):
+        report.warning(name, "All latitudes are in the range [-π, π]. Are they in radians?")
+
+
+def validate_longitudes(report, dataset, name, result):
+    """Validate the longitudes of the dataset."""
+
+    if not isinstance(result, np.ndarray):
+        raise ValueError(f"Result is not a np.ndarray {type(result)}")
+
+    if len(result) != dataset.shape[3]:
+        raise ValueError(f"Result has wrong length: {len(result)} != {dataset.shape[2]}")
+
+    if not np.all(np.isfinite(result)):
+        raise ValueError("Result contains non-finite values")
+
+    if np.isnan(result).any():
+        report.invalid(name, ValueError("Result contains NaN values"))
+        return
+
+    if not np.all((result >= -180) & (result <= 360)):
+        raise ValueError("Result contains values outside the range [-180, 360]")
+
+    if np.all((result >= -np.pi) & (result <= 2 * np.pi)):
+        report.warning(name, "All longitudes are in the range [-π, 2π]. Are they in radians?")
+
+
+def validate_statistics(report, dataset, name, result):
+    """Validate the statistics of the dataset."""
+
+    if not isinstance(result, dict):
+        raise ValueError(f"Result is not a dict {type(result)}")
+
+    for key in ["mean", "stdev", "minimum", "maximum"]:
+
+        if key not in result:
+            raise ValueError(f"Result does not contain `{key}`")
+
+        if not isinstance(result[key], np.ndarray):
+            raise ValueError(f"Result[{key}] is not a np.ndarray {type(result[key])}")
+
+        if len(result[key].shape) != 1:
+            raise ValueError(f"Result[{key}] has wrong shape: {len(result[key].shape)} != 1")
+
+        if result[key].shape[0] != len(dataset.variables):
+            raise ValueError(f"Result[{key}] has wrong length: {result[key].shape[0]} != {len(dataset.variables)}")
+
+        if not np.all(np.isfinite(result[key])):
+            raise ValueError(f"Result[{key}] contains non-finite values")
+
+        if np.isnan(result[key]).any():
+            report.invalid(name, ValueError(f"Result[{key}] contains NaN values"))
+
+
+def validate_shape(report, dataset, name, result):
+    """Validate the shape of the dataset."""
+
+    if not isinstance(result, tuple):
+        raise ValueError(f"Result is not a tuple {type(result)}")
+
+    if len(result) != 4:
+        raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.shape)}")
+
+    if result[0] != len(dataset):
+        raise ValueError(f"Result[0] has wrong length: {result[0]} != {len(dataset)}")
+
+    if result[1] != len(dataset.variables):
+        raise ValueError(f"Result[1] has wrong length: {result[1]} != {len(dataset.variables)}")
+
+    if result[2] != 1:  # We ignore ensemble dimension for now
+        pass
+
+    if result[3] != len(dataset.latitudes):
+        raise ValueError(f"Result[3] has wrong length: {result[3]} != {len(dataset.latitudes)}")
+
+
+def validate_supporting_arrays(report, dataset, name, result):
+    """Validate the supporting arrays of the dataset."""
+
+    if not isinstance(result, dict):
+        raise ValueError(f"Result is not a dict {type(result)}")
+
+    if "latitudes" not in result:
+        raise ValueError("Result does not contain `latitudes`")
+
+    if "longitudes" not in result:
+        raise ValueError("Result does not contain `longitudes`")
+
+    if not isinstance(result["latitudes"], np.ndarray):
+        raise ValueError(f"Result[latitudes] is not a np.ndarray {type(result['latitudes'])}")
+
+    if not isinstance(result["longitudes"], np.ndarray):
+        raise ValueError(f"Result[longitudes] is not a np.ndarray {type(result['longitudes'])}")
+
+    if np.any(result["latitudes"] != dataset.latitudes):
+        raise ValueError("Result[latitudes] does not match dataset.latitudes")
+
+    if np.any(result["longitudes"] != dataset.longitudes):
+        raise ValueError("Result[longitudes] does not match dataset.longitudes")
+
+
+def validate_dates(report, dataset, name, result):
+    """Validate the dates of the dataset."""
+
+    if not isinstance(result, np.ndarray):
+        raise ValueError(f"Result is not a np.ndarray {type(result)}")
+
+    if len(result.shape) != 1:
+        raise ValueError(f"Result has wrong shape: {len(result.shape)} != 1")
+
+    if result.shape[0] != len(dataset.dates):
+        raise ValueError(f"Result has wrong length: {result.shape[0]} != {len(dataset.dates)}")
+
+    if not np.issubdtype(result.dtype, np.datetime64):
+        raise ValueError(f"Result is not a datetime64 array {result.dtype}")
+
+    if len(result) != len(dataset.dates):
+        raise ValueError(f"Result has wrong length: {len(result)} != {len(dataset.dates)}")
+
+    if not np.all(np.isfinite(result)):
+        raise ValueError("Result contains non-finite values")
+
+    if np.isnan(result).any():
+        report.invalid(name, ValueError("Result contains NaN values"))
+        return
+
+    for d1, d2 in zip(result[:-1], result[1:]):
+        if d1 >= d2:
+            raise ValueError(f"Result contains non-increasing dates: {d1} >= {d2}")
+
+    frequency = np.diff(result)
+    if not np.all(frequency == frequency[0]):
+        raise ValueError("Result contains non-constant frequency")
+
+
+def validate_metadata(report, dataset, name, result):
+    """Validate the metadata of the dataset."""
+
+    if not isinstance(result, dict):
+        raise ValueError(f"Result is not a dict {type(result)}")
+
+
+def validate_missing(report, dataset, name, result):
+    """Validate the missing values of the dataset."""
+
+    if not isinstance(result, set):
+        raise ValueError(f"Result is not a set {type(result)}")
+
+    if not all(isinstance(item, int) for item in result):
+        raise ValueError("Result contains non-integer values")
+
+    if len(result) > 0:
+        if min(result) < 0:
+            raise ValueError("Result contains negative values")
+
+        if max(result) >= len(dataset):
+            raise ValueError(f"Result contains values greater than {len(dataset)}")
+
+
+def validate_name_to_index(report, dataset, name, result):
+    """Validate the name to index mapping of the dataset."""
+
+    if not isinstance(result, dict):
+        raise ValueError(f"Result is not a dict {type(result)}")
+
+    for key in dataset.variables:
+        if key not in result:
+            raise ValueError(f"Result does not contain `{key}`")
+
+        if not isinstance(result[key], int):
+            raise ValueError(f"Result[{key}] is not an int {type(result[key])}")
+
+        if result[key] < 0 or result[key] >= len(dataset.variables):
+            raise ValueError(f"Result[{key}] is out of bounds: {result[key]}")
+
+    index_to_name = {v: k for k, v in result.items()}
+    for i in range(len(dataset.variables)):
+        if i not in index_to_name:
+            raise ValueError(f"Result does not contain index `{i}`")
+
+        if not isinstance(index_to_name[i], str):
+            raise ValueError(f"Result[{i}] is not a string {type(index_to_name[i])}")
+
+        if index_to_name[i] != dataset.variables[i]:
+            raise ValueError(
+                f"Result[{i}] does not match dataset.variables[{i}]: {index_to_name[i]} != {dataset.variables[i]}"
+            )
+
+
+def validate___getitem__(report, dataset, name, result):
+    """Validate the __getitem__ method of the dataset."""
+
+    if not isinstance(result, np.ndarray):
+        raise ValueError(f"Result is not a np.ndarray {type(result)}")
+
+    if result.shape != dataset.shape[1:]:
+        raise ValueError(f"Result has wrong shape: {result.shape} != {dataset.shape[1:]}")
+
+
+def validate___len__(report, dataset, name, result):
+    """Validate the __len__ method of the dataset."""
+
+    if not isinstance(result, int):
+        raise ValueError(f"Result is not an int {type(result)}")
+
+    if result != dataset.shape[0]:
+        raise ValueError(f"Result has wrong length: {result} != {len(dataset)}")
+
+    if result != len(dataset.dates):
+        raise ValueError(f"Result has wrong length: {result} != {len(dataset.dates)}")
+
+
+def validate_start_date(report, dataset, name, result):
+    """Validate the start date of the dataset."""
+
+    if not isinstance(result, np.datetime64):
+        raise ValueError(f"Result is not a datetime64 {type(result)}")
+
+    if result != dataset.dates[0]:
+        raise ValueError(f"Result has wrong start date: {result} != {dataset.dates[0]}")
+
+
+def validate_end_date(report, dataset, name, result):
+    """Validate the end date of the dataset."""
+
+    if not isinstance(result, np.datetime64):
+        raise ValueError(f"Result is not a datetime64 {type(result)}")
+
+    if result != dataset.dates[-1]:
+        raise ValueError(f"Result has wrong end date: {result} != {dataset.dates[-1]}")
+
+
+def validate_field_shape(report, dataset, name, result):
+    """Validate the field shape of the dataset."""
+
+    if not isinstance(result, tuple):
+        raise ValueError(f"Result is not a tuple {type(result)}")
+
+    if math.prod(result) != dataset.shape[-1]:
+        raise ValueError(f"Result has wrong shape: {result} != {dataset.shape[-1]}")
+
+
+def validate(report, dataset, name, kwargs=None):
+
+    try:
+
+        validate_fn = globals().get(f"validate_{name}", _no_validate)
+
+        # Check if the method is still in the Dataset class
+        try:
+            report.method(name, getattr(Dataset, name))
+        except AttributeError:
+            report.internal(name, "Attribute not found in Dataset class. Please update the list of methods.")
+            return
+
+        # Check if the method is supported by the dataset instance
+        try:
+            result = getattr(dataset, name)
+        except AttributeError as e:
+            report.failure(name, e)
+            return
+
+        # Check if the method is callable
+        if callable(result):
+            if kwargs is None:
+                report.internal(
+                    name, f"`{name}` is a callable method, not an attribute. Please update KWARGS accordingly."
+                )
+                return
+        else:
+            if kwargs is not None:
+                report.internal(name, f"`{name}` is not callable. Please remove entry from KWARGS.")
+                return
+
+        if kwargs is not None:
+            result = result(**kwargs)
+
+        if isinstance(result, np.ndarray) and np.isnan(result).any():
+            report.invalid(name, ValueError("Result contains NaN values"))
+            return
+
+        try:
+            validate_fn(report, dataset, name, result)
+        except Exception as e:
+            report.invalid(name, e)
+            return
+
+        report.success(name)
+
+    except Exception as e:
+        report.failure(name, e)
+
+
+def validate_dtype(report, dataset, name, result):
+    """Validate the dtype of the dataset."""
+
+    if not isinstance(result, np.dtype):
+        raise ValueError(f"Result is not a np.dtype {type(result)}")
+
+
+def validate_dataset(dataset, costly_checks=False, detailed=False):
+    """Validate the dataset."""
+
+    report = Report()
+
+    if costly_checks:
+        # This check is expensive as it loads the entire dataset into memory
+        # so we make it optional
+        default_test_indexing(dataset)
+
+        for i, x in enumerate(dataset):
+            y = dataset[i]
+            assert (x == y).all(), f"Dataset indexing failed at index {i}: {x} != {y}"
+
+    for name in METHODS:
+        validate(report, dataset, name, kwargs=KWARGS.get(name))
+
+    report.summary(detailed=detailed)
+
+
+if __name__ == "__main__":
+    methods = METHODS_CATEGORIES.copy()
+    methods.pop("OTHER_METHODS")
+
+    o = set(OTHER_METHODS)
+    overlap = False
+    for m in methods:
+        if set(methods[m]).intersection(set(OTHER_METHODS)):
+            print(
+                f"WARNING: {m} contains methods from OTHER_METHODS: {set(methods[m]).intersection(set(OTHER_METHODS))}"
+            )
+            o = o - set(methods[m])
+            overlap = True
+
+    for m in methods:
+        for n in methods:
+            if n is not m:
+                if set(methods[m]).intersection(set(methods[n])):
+                    print(
+                        f"WARNING: {m} and {n} have methods in common: {set(methods[m]).intersection(set(methods[n]))}"
+                    )
+
+    if overlap:
+        print(sorted(o))
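`validate_dataset` drives one check per method named in `METHODS_CATEGORIES` (the matching `validate_*` function is looked up by name via `globals()`, falling back to `_no_validate`) and prints an emoji-coded report per category. A usage sketch; the path is a placeholder, and the flags mirror the signature shown above (`costly_checks` additionally runs the full-dataset indexing test from `testing.py`):

    # Usage sketch; "dataset.zarr" is a placeholder path.
    from anemoi.datasets import open_dataset
    from anemoi.datasets.validate import validate_dataset

    ds = open_dataset("dataset.zarr")
    validate_dataset(ds, costly_checks=False, detailed=True)  # prints the summary report

The new commands/validate.py (+59 lines, listed above but not shown in this diff) presumably exposes the same check as an `anemoi-datasets validate` subcommand; its exact flags are not visible here.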