anemoi-datasets 0.5.16__py3-none-any.whl → 0.5.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. anemoi/datasets/__init__.py +4 -1
  2. anemoi/datasets/__main__.py +12 -2
  3. anemoi/datasets/_version.py +9 -4
  4. anemoi/datasets/commands/cleanup.py +17 -2
  5. anemoi/datasets/commands/compare.py +18 -2
  6. anemoi/datasets/commands/copy.py +196 -14
  7. anemoi/datasets/commands/create.py +50 -7
  8. anemoi/datasets/commands/finalise-additions.py +17 -2
  9. anemoi/datasets/commands/finalise.py +17 -2
  10. anemoi/datasets/commands/init-additions.py +17 -2
  11. anemoi/datasets/commands/init.py +16 -2
  12. anemoi/datasets/commands/inspect.py +283 -62
  13. anemoi/datasets/commands/load-additions.py +16 -2
  14. anemoi/datasets/commands/load.py +16 -2
  15. anemoi/datasets/commands/patch.py +17 -2
  16. anemoi/datasets/commands/publish.py +17 -2
  17. anemoi/datasets/commands/scan.py +31 -3
  18. anemoi/datasets/compute/recentre.py +47 -11
  19. anemoi/datasets/create/__init__.py +612 -85
  20. anemoi/datasets/create/check.py +142 -20
  21. anemoi/datasets/create/chunks.py +64 -4
  22. anemoi/datasets/create/config.py +185 -21
  23. anemoi/datasets/create/filter.py +50 -0
  24. anemoi/datasets/create/filters/__init__.py +33 -0
  25. anemoi/datasets/create/filters/empty.py +37 -0
  26. anemoi/datasets/create/filters/legacy.py +93 -0
  27. anemoi/datasets/create/filters/noop.py +37 -0
  28. anemoi/datasets/create/filters/orog_to_z.py +58 -0
  29. anemoi/datasets/create/{functions/filters → filters}/pressure_level_relative_humidity_to_specific_humidity.py +33 -10
  30. anemoi/datasets/create/{functions/filters → filters}/pressure_level_specific_humidity_to_relative_humidity.py +32 -8
  31. anemoi/datasets/create/filters/rename.py +205 -0
  32. anemoi/datasets/create/{functions/filters → filters}/rotate_winds.py +43 -28
  33. anemoi/datasets/create/{functions/filters → filters}/single_level_dewpoint_to_relative_humidity.py +32 -9
  34. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_dewpoint.py +33 -9
  35. anemoi/datasets/create/{functions/filters → filters}/single_level_relative_humidity_to_specific_humidity.py +55 -7
  36. anemoi/datasets/create/{functions/filters → filters}/single_level_specific_humidity_to_relative_humidity.py +98 -37
  37. anemoi/datasets/create/filters/speeddir_to_uv.py +95 -0
  38. anemoi/datasets/create/{functions/filters → filters}/sum.py +24 -27
  39. anemoi/datasets/create/filters/transform.py +53 -0
  40. anemoi/datasets/create/{functions/filters → filters}/unrotate_winds.py +27 -18
  41. anemoi/datasets/create/filters/uv_to_speeddir.py +94 -0
  42. anemoi/datasets/create/{functions/filters → filters}/wz_to_w.py +51 -33
  43. anemoi/datasets/create/input/__init__.py +76 -5
  44. anemoi/datasets/create/input/action.py +149 -13
  45. anemoi/datasets/create/input/concat.py +81 -10
  46. anemoi/datasets/create/input/context.py +39 -4
  47. anemoi/datasets/create/input/data_sources.py +72 -6
  48. anemoi/datasets/create/input/empty.py +21 -3
  49. anemoi/datasets/create/input/filter.py +60 -12
  50. anemoi/datasets/create/input/function.py +154 -37
  51. anemoi/datasets/create/input/join.py +86 -14
  52. anemoi/datasets/create/input/misc.py +67 -17
  53. anemoi/datasets/create/input/pipe.py +33 -6
  54. anemoi/datasets/create/input/repeated_dates.py +189 -41
  55. anemoi/datasets/create/input/result.py +202 -87
  56. anemoi/datasets/create/input/step.py +119 -22
  57. anemoi/datasets/create/input/template.py +100 -13
  58. anemoi/datasets/create/input/trace.py +62 -7
  59. anemoi/datasets/create/patch.py +52 -4
  60. anemoi/datasets/create/persistent.py +134 -17
  61. anemoi/datasets/create/size.py +15 -1
  62. anemoi/datasets/create/source.py +51 -0
  63. anemoi/datasets/create/sources/__init__.py +36 -0
  64. anemoi/datasets/create/{functions/sources → sources}/accumulations.py +296 -30
  65. anemoi/datasets/create/{functions/sources → sources}/constants.py +27 -2
  66. anemoi/datasets/create/{functions/sources → sources}/eccc_fstd.py +7 -3
  67. anemoi/datasets/create/sources/empty.py +37 -0
  68. anemoi/datasets/create/{functions/sources → sources}/forcings.py +25 -1
  69. anemoi/datasets/create/sources/grib.py +297 -0
  70. anemoi/datasets/create/{functions/sources → sources}/hindcasts.py +38 -4
  71. anemoi/datasets/create/sources/legacy.py +93 -0
  72. anemoi/datasets/create/{functions/sources → sources}/mars.py +168 -20
  73. anemoi/datasets/create/sources/netcdf.py +42 -0
  74. anemoi/datasets/create/sources/opendap.py +43 -0
  75. anemoi/datasets/create/{functions/sources/__init__.py → sources/patterns.py} +35 -4
  76. anemoi/datasets/create/sources/recentre.py +150 -0
  77. anemoi/datasets/create/{functions/sources → sources}/source.py +27 -5
  78. anemoi/datasets/create/{functions/sources → sources}/tendencies.py +64 -7
  79. anemoi/datasets/create/sources/xarray.py +92 -0
  80. anemoi/datasets/create/sources/xarray_kerchunk.py +36 -0
  81. anemoi/datasets/create/sources/xarray_support/README.md +1 -0
  82. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/__init__.py +109 -8
  83. anemoi/datasets/create/sources/xarray_support/coordinates.py +442 -0
  84. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/field.py +94 -16
  85. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/fieldlist.py +90 -25
  86. anemoi/datasets/create/sources/xarray_support/flavour.py +1036 -0
  87. anemoi/datasets/create/{functions/sources/xarray → sources/xarray_support}/grid.py +92 -31
  88. anemoi/datasets/create/sources/xarray_support/metadata.py +395 -0
  89. anemoi/datasets/create/sources/xarray_support/patch.py +91 -0
  90. anemoi/datasets/create/sources/xarray_support/time.py +391 -0
  91. anemoi/datasets/create/sources/xarray_support/variable.py +331 -0
  92. anemoi/datasets/create/sources/xarray_zarr.py +41 -0
  93. anemoi/datasets/create/{functions/sources → sources}/zenodo.py +34 -5
  94. anemoi/datasets/create/statistics/__init__.py +233 -44
  95. anemoi/datasets/create/statistics/summary.py +52 -6
  96. anemoi/datasets/create/testing.py +76 -0
  97. anemoi/datasets/create/{functions/filters/noop.py → typing.py} +6 -3
  98. anemoi/datasets/create/utils.py +97 -6
  99. anemoi/datasets/create/writer.py +26 -4
  100. anemoi/datasets/create/zarr.py +170 -23
  101. anemoi/datasets/data/__init__.py +51 -4
  102. anemoi/datasets/data/complement.py +191 -40
  103. anemoi/datasets/data/concat.py +141 -16
  104. anemoi/datasets/data/dataset.py +552 -61
  105. anemoi/datasets/data/debug.py +197 -26
  106. anemoi/datasets/data/ensemble.py +93 -8
  107. anemoi/datasets/data/fill_missing.py +165 -18
  108. anemoi/datasets/data/forwards.py +428 -56
  109. anemoi/datasets/data/grids.py +323 -97
  110. anemoi/datasets/data/indexing.py +112 -19
  111. anemoi/datasets/data/interpolate.py +92 -12
  112. anemoi/datasets/data/join.py +158 -19
  113. anemoi/datasets/data/masked.py +129 -15
  114. anemoi/datasets/data/merge.py +137 -23
  115. anemoi/datasets/data/misc.py +172 -16
  116. anemoi/datasets/data/missing.py +233 -29
  117. anemoi/datasets/data/rescale.py +111 -10
  118. anemoi/datasets/data/select.py +168 -26
  119. anemoi/datasets/data/statistics.py +67 -6
  120. anemoi/datasets/data/stores.py +149 -64
  121. anemoi/datasets/data/subset.py +159 -25
  122. anemoi/datasets/data/unchecked.py +168 -57
  123. anemoi/datasets/data/xy.py +168 -25
  124. anemoi/datasets/dates/__init__.py +191 -16
  125. anemoi/datasets/dates/groups.py +189 -47
  126. anemoi/datasets/grids.py +270 -31
  127. anemoi/datasets/testing.py +28 -1
  128. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/METADATA +9 -6
  129. anemoi_datasets-0.5.17.dist-info/RECORD +137 -0
  130. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/WHEEL +1 -1
  131. anemoi/datasets/create/functions/__init__.py +0 -66
  132. anemoi/datasets/create/functions/filters/__init__.py +0 -9
  133. anemoi/datasets/create/functions/filters/empty.py +0 -17
  134. anemoi/datasets/create/functions/filters/orog_to_z.py +0 -58
  135. anemoi/datasets/create/functions/filters/rename.py +0 -79
  136. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +0 -78
  137. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +0 -56
  138. anemoi/datasets/create/functions/sources/empty.py +0 -15
  139. anemoi/datasets/create/functions/sources/grib.py +0 -150
  140. anemoi/datasets/create/functions/sources/netcdf.py +0 -15
  141. anemoi/datasets/create/functions/sources/opendap.py +0 -15
  142. anemoi/datasets/create/functions/sources/recentre.py +0 -60
  143. anemoi/datasets/create/functions/sources/xarray/coordinates.py +0 -255
  144. anemoi/datasets/create/functions/sources/xarray/flavour.py +0 -472
  145. anemoi/datasets/create/functions/sources/xarray/metadata.py +0 -148
  146. anemoi/datasets/create/functions/sources/xarray/patch.py +0 -44
  147. anemoi/datasets/create/functions/sources/xarray/time.py +0 -177
  148. anemoi/datasets/create/functions/sources/xarray/variable.py +0 -188
  149. anemoi/datasets/create/functions/sources/xarray_kerchunk.py +0 -42
  150. anemoi/datasets/create/functions/sources/xarray_zarr.py +0 -15
  151. anemoi/datasets/utils/fields.py +0 -47
  152. anemoi_datasets-0.5.16.dist-info/RECORD +0 -129
  153. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/entry_points.txt +0 -0
  154. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info/licenses}/LICENSE +0 -0
  155. {anemoi_datasets-0.5.16.dist-info → anemoi_datasets-0.5.17.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/statistics/__init__.py

@@ -16,10 +16,15 @@ import os
 import pickle
 import shutil
 import socket
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Union
 
 import numpy as np
 import tqdm
 from anemoi.utils.provenance import gather_provenance_info
+from numpy.typing import NDArray
 
 from ..check import check_data_values
 from .summary import Summary
@@ -27,15 +32,18 @@ from .summary import Summary
 LOG = logging.getLogger(__name__)
 
 
-def default_statistics_dates(dates):
+def default_statistics_dates(dates: list[datetime.datetime]) -> tuple[datetime.datetime, datetime.datetime]:
     """Calculate default statistics dates based on the given list of dates.
 
-    Args:
-        dates (list): List of datetime objects representing dates.
-
-    Returns:
-        tuple: A tuple containing the default start and end dates.
+    Parameters
+    ----------
+    dates : list of datetime.datetime
+        List of datetime objects representing dates.
 
+    Returns
+    -------
+    tuple of datetime.datetime
+        A tuple containing the default start and end dates.
     """
 
     def to_datetime(d):
@@ -69,7 +77,19 @@ def default_statistics_dates(dates):
     return dates[0], end
 
 
-def to_datetime(date):
+def to_datetime(date: Union[str, datetime.datetime]) -> np.datetime64:
+    """Convert a date to numpy datetime64 format.
+
+    Parameters
+    ----------
+    date : str or datetime.datetime
+        The date to convert.
+
+    Returns
+    -------
+    numpy.datetime64
+        The converted date.
+    """
     if isinstance(date, str):
         return np.datetime64(date)
     if isinstance(date, datetime.datetime):
@@ -77,11 +97,43 @@ def to_datetime(date):
     return date
 
 
-def to_datetimes(dates):
+def to_datetimes(dates: list[Union[str, datetime.datetime]]) -> list[np.datetime64]:
+    """Convert a list of dates to numpy datetime64 format.
+
+    Parameters
+    ----------
+    dates : list of str or datetime.datetime
+        List of dates to convert.
+
+    Returns
+    -------
+    list of numpy.datetime64
+        List of converted dates.
+    """
     return [to_datetime(d) for d in dates]
 
 
-def fix_variance(x, name, count, sums, squares):
+def fix_variance(x: float, name: str, count: NDArray[Any], sums: NDArray[Any], squares: NDArray[Any]) -> float:
+    """Fix negative variance values due to numerical errors.
+
+    Parameters
+    ----------
+    x : float
+        The variance value.
+    name : str
+        The variable name.
+    count : numpy.ndarray
+        The count array.
+    sums : numpy.ndarray
+        The sums array.
+    squares : numpy.ndarray
+        The squares array.
+
+    Returns
+    -------
+    float
+        The fixed variance value.
+    """
     assert count.shape == sums.shape == squares.shape
     assert isinstance(x, float)
 
@@ -112,7 +164,42 @@ def fix_variance(x, name, count, sums, squares):
     return 0
 
 
-def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
+def check_variance(
+    x: NDArray[Any],
+    variables_names: list[str],
+    minimum: NDArray[Any],
+    maximum: NDArray[Any],
+    mean: NDArray[Any],
+    count: NDArray[Any],
+    sums: NDArray[Any],
+    squares: NDArray[Any],
+) -> None:
+    """Check for negative variance values and raise an error if found.
+
+    Parameters
+    ----------
+    x : numpy.ndarray
+        The variance array.
+    variables_names : list of str
+        List of variable names.
+    minimum : numpy.ndarray
+        The minimum values array.
+    maximum : numpy.ndarray
+        The maximum values array.
+    mean : numpy.ndarray
+        The mean values array.
+    count : numpy.ndarray
+        The count array.
+    sums : numpy.ndarray
+        The sums array.
+    squares : numpy.ndarray
+        The squares array.
+
+    Raises
+    ------
+    ValueError
+        If negative variance is found.
+    """
     if (x >= 0).all():
         return
     print(x)
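Note: `fix_variance` and `check_variance` exist because the one-pass formula `var = squares/count - (sums/count)**2` (inferred from the sums/squares/count bookkeeping above, not shown verbatim in this diff) can round to a small negative number. A minimal, self-contained sketch of that cancellation:

    import numpy as np

    # A near-constant field with a large mean: both terms of
    # squares/count - (sums/count)**2 are ~1e16 and nearly equal,
    # so the subtraction is dominated by rounding error.
    x = np.full(1000, 1.0e8) + np.random.rand(1000) * 1.0e-4
    count = float(len(x))
    sums = float(x.sum())
    squares = float((x * x).sum())
    var = squares / count - (sums / count) ** 2
    print(var)  # tiny, possibly negative -- what fix_variance repairs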
@@ -133,8 +220,25 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
     raise ValueError("Negative variance")
 
 
-def compute_statistics(array, check_variables_names=None, allow_nans=False):
-    """Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary."""
+def compute_statistics(
+    array: NDArray[Any], check_variables_names: Optional[List[str]] = None, allow_nans: bool = False
+) -> dict[str, np.ndarray]:
+    """Compute statistics for a given array, provides minimum, maximum, sum, squares, count and has_nans as a dictionary.
+
+    Parameters
+    ----------
+    array : numpy.ndarray
+        The array to compute statistics for.
+    check_variables_names : list of str, optional
+        List of variable names to check. Defaults to None.
+    allow_nans : bool, optional
+        Whether to allow NaN values. Defaults to False.
+
+    Returns
+    -------
+    dict of str to numpy.ndarray
+        A dictionary containing the computed statistics.
+    """
     LOG.info(f"Computing statistics for {array.shape} array")
     nvars = array.shape[1]
 
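A hypothetical call, for orientation (axis 1 is the variable axis, per `nvars = array.shape[1]`; the remaining axes are an assumption here):

    import numpy as np

    from anemoi.datasets.create.statistics import compute_statistics

    chunk = np.random.rand(4, 2, 1, 1000)  # (dates, variables, ensemble, grid) -- assumed layout
    stats = compute_statistics(chunk)  # check_variables_names left as None
    print(sorted(stats))  # count, has_nans, maximum, minimum, squares, sums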
@@ -180,16 +284,34 @@ def compute_statistics(array, check_variables_names=None, allow_nans=False):
 
 
 class TmpStatistics:
+    """Temporary statistics storage class."""
+
     version = 3
     # Used in parrallel, during data loading,
     # to write statistics in pickled npz files.
     # can provide statistics for a subset of dates.
 
-    def __init__(self, dirname, overwrite=False):
+    def __init__(self, dirname: str, overwrite: bool = False) -> None:
+        """Initialize TmpStatistics.
+
+        Parameters
+        ----------
+        dirname : str
+            Directory name for storing statistics.
+        overwrite : bool, optional
+            Whether to overwrite existing files. Defaults to False.
+        """
        self.dirname = dirname
        self.overwrite = overwrite
 
-    def add_provenance(self, **kwargs):
+    def add_provenance(self, **kwargs: dict) -> None:
+        """Add provenance information.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional provenance information.
+        """
         self.create(exist_ok=True)
         path = os.path.join(self.dirname, "provenance.json")
         if os.path.exists(path):
@@ -198,16 +320,35 @@ class TmpStatistics:
         with open(path, "w") as f:
             json.dump(out, f)
 
-    def create(self, exist_ok):
+    def create(self, exist_ok: bool) -> None:
+        """Create the directory for storing statistics.
+
+        Parameters
+        ----------
+        exist_ok : bool
+            Whether to ignore if the directory already exists.
+        """
         os.makedirs(self.dirname, exist_ok=exist_ok)
 
-    def delete(self):
+    def delete(self) -> None:
+        """Delete the directory for storing statistics."""
         try:
             shutil.rmtree(self.dirname)
         except FileNotFoundError:
             pass
 
-    def write(self, key, data, dates):
+    def write(self, key: str, data: any, dates: list[datetime.datetime]) -> None:
+        """Write statistics data to a file.
+
+        Parameters
+        ----------
+        key : str
+            The key for the data.
+        data : any
+            The data to write.
+        dates : list of datetime.datetime
+            List of dates associated with the data.
+        """
         self.create(exist_ok=True)
         h = hashlib.sha256(str(dates).encode("utf-8")).hexdigest()
         path = os.path.join(self.dirname, f"{h}.npz")
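As the code above shows, each chunk's `.npz` file is keyed by a SHA-256 of the stringified dates, so the same date chunk always maps to the same path. A small sketch of that naming scheme (directory and dates are illustrative):

    import hashlib
    import os

    dates = ["2020-01-01T00:00", "2020-01-01T06:00"]  # one chunk's dates
    h = hashlib.sha256(str(dates).encode("utf-8")).hexdigest()
    print(os.path.join("/tmp/stats", f"{h}.npz"))  # deterministic per chunk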
@@ -222,7 +363,14 @@ class TmpStatistics:
 
         LOG.debug(f"Written statistics data for {len(dates)} dates in {path} ({dates})")
 
-    def _gather_data(self):
+    def _gather_data(self) -> tuple[str, list[datetime.datetime], dict]:
+        """Gather data from stored files.
+
+        Yields
+        ------
+        tuple of str, list of datetime.datetime, dict
+            A tuple containing key, dates, and data.
+        """
         # use glob to read all pickles
         files = glob.glob(self.dirname + "/*.npz")
         LOG.debug(f"Reading stats data, found {len(files)} files in {self.dirname}")
@@ -231,37 +379,67 @@ class TmpStatistics:
         with open(f, "rb") as f:
             yield pickle.load(f)
 
-    def get_aggregated(self, *args, **kwargs):
+    def get_aggregated(self, *args: Any, **kwargs: Any) -> Summary:
+        """Get aggregated statistics.
+
+        Parameters
+        ----------
+        *args : Any
+            Additional arguments.
+        **kwargs : Any
+            Additional keyword arguments.
+
+        Returns
+        -------
+        Summary
+            The aggregated statistics summary.
+        """
         aggregator = StatAggregator(self, *args, **kwargs)
         return aggregator.aggregate()
 
-    def __str__(self):
-        return f"TmpStatistics({self.dirname})"
-
-
-def normalise_date(d):
-    if isinstance(d, str):
-        d = np.datetime64(d)
-    return d
+    def __str__(self) -> str:
+        """String representation of TmpStatistics.
 
-
-def normalise_dates(dates):
-    return [normalise_date(d) for d in dates]
+        Returns
+        -------
+        str
+            The string representation.
+        """
+        return f"TmpStatistics({self.dirname})"
 
 
 class StatAggregator:
+    """Statistics aggregator class."""
+
     NAMES = ["minimum", "maximum", "sums", "squares", "count", "has_nans"]
 
-    def __init__(self, owner, dates, variables_names, allow_nans):
+    def __init__(
+        self, owner: TmpStatistics, dates: list[datetime.datetime], variables_names: list[str], allow_nans: bool
+    ) -> None:
+        """Initialize StatAggregator.
+
+        Parameters
+        ----------
+        owner : TmpStatistics
+            The owner TmpStatistics instance.
+        dates : list of datetime.datetime
+            List of dates to aggregate.
+        variables_names : list of str
+            List of variable names.
+        allow_nans : bool
+            Whether to allow NaN values.
+        """
         dates = sorted(dates)
         dates = to_datetimes(dates)
         assert dates, "No dates selected"
         self.owner = owner
         self.dates = dates
+        self._number_of_dates = len(dates)
+        self._set_of_dates = set(dates)
         self.variables_names = variables_names
         self.allow_nans = allow_nans
 
-        self.shape = (len(self.dates), len(self.variables_names))
+        self.shape = (self._number_of_dates, len(self.variables_names))
         LOG.debug(f"Aggregating statistics on shape={self.shape}. Variables : {self.variables_names}")
 
         self.minimum = np.full(self.shape, np.nan, dtype=np.float64)
@@ -273,12 +451,16 @@ class StatAggregator:
 
         self._read()
 
-    def _read(self):
+    def _read(self) -> None:
+        """Read and aggregate statistics data from files."""
+
         def check_type(a, b):
-            a = list(a)
-            b = list(b)
-            a = a[0] if a else None
-            b = b[0] if b else None
+            if not isinstance(a, set):
+                a = set(list(a))
+            if not isinstance(b, set):
+                b = set(list(b))
+            a = next(iter(a)) if a else None
+            b = next(iter(b)) if b else None
             assert type(a) is type(b), (type(a), type(b))
 
         found = set()
@@ -294,20 +476,20 @@ class StatAggregator:
             for n in self.NAMES:
                 assert n in stats, (n, list(stats.keys()))
             _dates = to_datetimes(_dates)
-            check_type(_dates, self.dates)
+            check_type(_dates, self._set_of_dates)
             if found:
-                check_type(found, self.dates)
+                check_type(found, self._set_of_dates)
                 assert found.isdisjoint(_dates), "Duplicate dates found in precomputed statistics"
 
             # filter dates
-            dates = set(_dates) & set(self.dates)
+            dates = set(_dates) & self._set_of_dates
 
             if not dates:
                 # dates have been completely filtered for this chunk
                 continue
 
             # filter data
-            bitmap = np.isin(_dates, self.dates)
+            bitmap = np.array([d in self._set_of_dates for d in _dates])
             for k in self.NAMES:
                 stats[k] = stats[k][bitmap]
 
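The filtering above now probes `self._set_of_dates`, built once in `__init__`, instead of rebuilding `set(self.dates)` and calling `np.isin` for every chunk. A rough sketch of the pattern (sizes and dates are illustrative):

    import numpy as np

    all_dates = [np.datetime64("2020-01-01") + np.timedelta64(6 * i, "h") for i in range(10_000)]
    wanted = set(all_dates[::2])  # built once, reused for every chunk

    chunk_dates = all_dates[100:200]
    bitmap = np.array([d in wanted for d in chunk_dates])  # one hash probe per date
    print(int(bitmap.sum()))  # 50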
@@ -323,11 +505,18 @@ class StatAggregator:
 
         for d in self.dates:
             assert d in found, f"Statistics for date {d} not precomputed."
-        assert len(self.dates) == len(found), "Not all dates found in precomputed statistics"
-        assert len(self.dates) == offset, "Not all dates found in precomputed statistics."
+        assert self._number_of_dates == len(found), "Not all dates found in precomputed statistics"
+        assert self._number_of_dates == offset, "Not all dates found in precomputed statistics."
         LOG.debug(f"Statistics for {len(found)} dates found.")
 
-    def aggregate(self):
+    def aggregate(self) -> Summary:
+        """Aggregate the statistics data.
+
+        Returns
+        -------
+        Summary
+            The aggregated statistics summary.
+        """
         minimum = np.nanmin(self.minimum, axis=0)
         maximum = np.nanmax(self.maximum, axis=0)
 
anemoi/datasets/create/statistics/summary.py

@@ -9,6 +9,7 @@
 
 import json
 from collections import defaultdict
+from typing import Any
 
 import numpy as np
 
@@ -28,15 +29,32 @@ class Summary(dict):
         "has_nans",
     ]  # order matter for __str__.
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: Any) -> None:
+        """Initialize the Summary object with given keyword arguments.
+
+        Parameters
+        ----------
+        **kwargs : Any
+            Arbitrary keyword arguments representing summary statistics.
+        """
         super().__init__(**kwargs)
         self.check()
 
     @property
-    def size(self):
+    def size(self) -> int:
+        """Get the size of the summary, which is the number of variables."""
         return len(self["variables_names"])
 
-    def check(self):
+    def check(self) -> None:
+        """Perform checks on the summary statistics to ensure they are valid.
+
+        Raises
+        ------
+        AssertionError
+            If any of the checks fail.
+        StatisticsValueError
+            If any of the statistical checks fail.
+        """
         for k, v in self.items():
             if k == "variables_names":
                 assert len(v) == self.size
@@ -63,7 +81,14 @@ class Summary(dict):
                 e.args += (i, name)
                 raise
 
-    def __str__(self):
+    def __str__(self) -> str:
+        """Return a string representation of the summary statistics.
+
+        Returns
+        -------
+        str
+            A formatted string of the summary statistics.
+        """
         header = ["Variables"] + self.STATS_NAMES
         out = [" ".join(header)]
 
@@ -73,7 +98,16 @@ class Summary(dict):
         ]
         return "\n".join(out)
 
-    def save(self, filename, **metadata):
+    def save(self, filename: str, **metadata: Any) -> None:
+        """Save the summary statistics to a JSON file.
+
+        Parameters
+        ----------
+        filename : str
+            The name of the file to save the summary statistics.
+        **metadata : Any
+            Additional metadata to include in the JSON file.
+        """
         assert filename.endswith(".json"), filename
         dic = {}
         for k in self.STATS_NAMES:
@@ -89,7 +123,19 @@ class Summary(dict):
         with open(filename, "w") as f:
             json.dump(out, f, indent=2)
 
-    def load(self, filename):
+    def load(self, filename: str) -> "Summary":
+        """Load the summary statistics from a JSON file.
+
+        Parameters
+        ----------
+        filename : str
+            The name of the file to load the summary statistics from.
+
+        Returns
+        -------
+        Summary
+            The loaded Summary object.
+        """
         assert filename.endswith(".json"), filename
         with open(filename) as f:
             dic = json.load(f)
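A hedged sketch of the typed save/load round trip (paths, dates and the `created_by` metadata key are illustrative; it assumes statistics chunks were already written to the temporary directory during loading):

    import datetime

    from anemoi.datasets.create.statistics import TmpStatistics
    from anemoi.datasets.create.statistics.summary import Summary

    tmp = TmpStatistics("/tmp/stats-example")
    dates = [datetime.datetime(2020, 1, 1) + datetime.timedelta(hours=6 * i) for i in range(4)]
    summary = tmp.get_aggregated(dates, ["2t", "msl"], allow_nans=False)

    summary.save("statistics.json", created_by="example-run")  # must end in ".json"
    reloaded = Summary().load("statistics.json")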
anemoi/datasets/create/testing.py (new file)

@@ -0,0 +1,76 @@
+# (C) Copyright 2025- Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import tempfile
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Union
+
+import yaml
+
+from anemoi.datasets.create import creator_factory
+
+
+class TestingContext:
+    pass
+
+
+def create_dataset(
+    *,
+    config: Union[str, Dict[str, Any]],
+    output: Optional[str],
+    delta: Optional[List[str]] = None,
+    is_test: bool = False,
+) -> str:
+    """Create a dataset based on the provided configuration.
+
+    Parameters
+    ----------
+    config : Union[str, Dict[str, Any]]
+        The configuration for the dataset. Can be a path to a YAML file or a dictionary.
+    output : Optional[str]
+        The output path for the dataset. If None, a temporary directory will be created.
+    delta : Optional[List[str]], optional
+        List of delta for secondary statistics, by default None.
+    is_test : bool, optional
+        Flag indicating if the dataset creation is for testing purposes, by default False.
+
+    Returns
+    -------
+    str
+        The path to the created dataset.
+    """
+    if isinstance(config, dict):
+        temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml")
+        yaml.dump(config, temp_file)
+        config = temp_file.name
+
+    if output is None:
+        output = tempfile.mkdtemp(suffix=".zarr")
+
+    creator_factory("init", config=config, path=output, overwrite=True, test=is_test).run()
+    creator_factory("load", path=output).run()
+    creator_factory("finalise", path=output).run()
+    creator_factory("patch", path=output).run()
+
+    if delta is not None:
+        creator_factory("init_additions", path=output, delta=delta).run()
+        creator_factory("run_additions", path=output, delta=delta).run()
+        creator_factory("finalise_additions", path=output, delta=delta).run()
+
+    creator_factory("cleanup", path=output).run()
+
+    if delta is not None:
+        creator_factory("cleanup", path=output, delta=delta).run()
+
+    creator_factory("verify", path=output).run()
+
+    return output
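A usage sketch of the new helper (the recipe path and config contents are illustrative; the signature is as in the diff):

    from anemoi.datasets.create.testing import create_dataset

    path = create_dataset(
        config="recipe.yaml",  # or a dict, which is dumped to a temporary YAML file
        output=None,           # None -> a temporary "*.zarr" directory is created
        is_test=True,
    )
    print(path)  # the finalised and verified dataset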
anemoi/datasets/create/typing.py (renamed from anemoi/datasets/create/functions/filters/noop.py)

@@ -1,4 +1,4 @@
-# (C) Copyright 2024 Anemoi contributors.
+# (C) Copyright 2025- Anemoi contributors.
 #
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
@@ -7,6 +7,9 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+import datetime
+from typing import List
 
-def execute(context, input, *args, **kwargs):
-    return input
+Date = datetime.datetime
+
+DateList = List[Date]
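A hypothetical consumer of the new aliases:

    from anemoi.datasets.create.typing import Date, DateList

    def date_span(dates: DateList) -> tuple[Date, Date]:
        """Return the earliest and latest date of a group."""
        return min(dates), max(dates)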