anemoi-datasets 0.5.26__py3-none-any.whl → 0.5.28__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registries. It is provided for informational purposes only.
Files changed (116):
  1. anemoi/datasets/__init__.py +1 -2
  2. anemoi/datasets/_version.py +16 -3
  3. anemoi/datasets/commands/check.py +1 -1
  4. anemoi/datasets/commands/copy.py +1 -2
  5. anemoi/datasets/commands/create.py +1 -1
  6. anemoi/datasets/commands/inspect.py +27 -35
  7. anemoi/datasets/commands/recipe/__init__.py +93 -0
  8. anemoi/datasets/commands/recipe/format.py +55 -0
  9. anemoi/datasets/commands/recipe/migrate.py +555 -0
  10. anemoi/datasets/commands/validate.py +59 -0
  11. anemoi/datasets/compute/recentre.py +3 -6
  12. anemoi/datasets/create/__init__.py +64 -26
  13. anemoi/datasets/create/check.py +10 -12
  14. anemoi/datasets/create/chunks.py +1 -2
  15. anemoi/datasets/create/config.py +5 -6
  16. anemoi/datasets/create/input/__init__.py +44 -65
  17. anemoi/datasets/create/input/action.py +296 -238
  18. anemoi/datasets/create/input/context/__init__.py +71 -0
  19. anemoi/datasets/create/input/context/field.py +54 -0
  20. anemoi/datasets/create/input/data_sources.py +7 -9
  21. anemoi/datasets/create/input/misc.py +2 -75
  22. anemoi/datasets/create/input/repeated_dates.py +11 -130
  23. anemoi/datasets/{utils → create/input/result}/__init__.py +10 -1
  24. anemoi/datasets/create/input/{result.py → result/field.py} +36 -120
  25. anemoi/datasets/create/input/trace.py +1 -1
  26. anemoi/datasets/create/patch.py +1 -2
  27. anemoi/datasets/create/persistent.py +3 -5
  28. anemoi/datasets/create/size.py +1 -3
  29. anemoi/datasets/create/sources/accumulations.py +120 -145
  30. anemoi/datasets/create/sources/accumulations2.py +20 -53
  31. anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
  32. anemoi/datasets/create/sources/constants.py +39 -40
  33. anemoi/datasets/create/sources/empty.py +22 -19
  34. anemoi/datasets/create/sources/fdb.py +133 -0
  35. anemoi/datasets/create/sources/forcings.py +29 -29
  36. anemoi/datasets/create/sources/grib.py +94 -78
  37. anemoi/datasets/create/sources/grib_index.py +57 -55
  38. anemoi/datasets/create/sources/hindcasts.py +57 -59
  39. anemoi/datasets/create/sources/legacy.py +10 -62
  40. anemoi/datasets/create/sources/mars.py +121 -149
  41. anemoi/datasets/create/sources/netcdf.py +28 -25
  42. anemoi/datasets/create/sources/opendap.py +28 -26
  43. anemoi/datasets/create/sources/patterns.py +4 -6
  44. anemoi/datasets/create/sources/recentre.py +46 -48
  45. anemoi/datasets/create/sources/repeated_dates.py +44 -0
  46. anemoi/datasets/create/sources/source.py +26 -51
  47. anemoi/datasets/create/sources/tendencies.py +68 -98
  48. anemoi/datasets/create/sources/xarray.py +4 -6
  49. anemoi/datasets/create/sources/xarray_support/__init__.py +40 -36
  50. anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -12
  51. anemoi/datasets/create/sources/xarray_support/field.py +20 -16
  52. anemoi/datasets/create/sources/xarray_support/fieldlist.py +11 -15
  53. anemoi/datasets/create/sources/xarray_support/flavour.py +42 -42
  54. anemoi/datasets/create/sources/xarray_support/grid.py +15 -9
  55. anemoi/datasets/create/sources/xarray_support/metadata.py +19 -128
  56. anemoi/datasets/create/sources/xarray_support/patch.py +4 -6
  57. anemoi/datasets/create/sources/xarray_support/time.py +10 -13
  58. anemoi/datasets/create/sources/xarray_support/variable.py +21 -21
  59. anemoi/datasets/create/sources/xarray_zarr.py +28 -25
  60. anemoi/datasets/create/sources/zenodo.py +43 -41
  61. anemoi/datasets/create/statistics/__init__.py +3 -6
  62. anemoi/datasets/create/testing.py +4 -0
  63. anemoi/datasets/create/typing.py +1 -2
  64. anemoi/datasets/create/utils.py +0 -43
  65. anemoi/datasets/create/zarr.py +7 -2
  66. anemoi/datasets/data/__init__.py +15 -6
  67. anemoi/datasets/data/complement.py +7 -12
  68. anemoi/datasets/data/concat.py +5 -8
  69. anemoi/datasets/data/dataset.py +48 -47
  70. anemoi/datasets/data/debug.py +7 -9
  71. anemoi/datasets/data/ensemble.py +4 -6
  72. anemoi/datasets/data/fill_missing.py +7 -10
  73. anemoi/datasets/data/forwards.py +22 -26
  74. anemoi/datasets/data/grids.py +12 -168
  75. anemoi/datasets/data/indexing.py +9 -12
  76. anemoi/datasets/data/interpolate.py +7 -15
  77. anemoi/datasets/data/join.py +8 -12
  78. anemoi/datasets/data/masked.py +6 -11
  79. anemoi/datasets/data/merge.py +5 -9
  80. anemoi/datasets/data/misc.py +41 -45
  81. anemoi/datasets/data/missing.py +11 -16
  82. anemoi/datasets/data/observations/__init__.py +8 -14
  83. anemoi/datasets/data/padded.py +3 -5
  84. anemoi/datasets/data/records/backends/__init__.py +2 -2
  85. anemoi/datasets/data/rescale.py +5 -12
  86. anemoi/datasets/data/rolling_average.py +141 -0
  87. anemoi/datasets/data/select.py +13 -16
  88. anemoi/datasets/data/statistics.py +4 -7
  89. anemoi/datasets/data/stores.py +22 -29
  90. anemoi/datasets/data/subset.py +8 -11
  91. anemoi/datasets/data/unchecked.py +7 -11
  92. anemoi/datasets/data/xy.py +25 -21
  93. anemoi/datasets/dates/__init__.py +15 -18
  94. anemoi/datasets/dates/groups.py +7 -10
  95. anemoi/datasets/dumper.py +76 -0
  96. anemoi/datasets/grids.py +4 -185
  97. anemoi/datasets/schemas/recipe.json +131 -0
  98. anemoi/datasets/testing.py +93 -7
  99. anemoi/datasets/validate.py +598 -0
  100. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/METADATA +7 -4
  101. anemoi_datasets-0.5.28.dist-info/RECORD +134 -0
  102. anemoi/datasets/create/filter.py +0 -48
  103. anemoi/datasets/create/input/concat.py +0 -164
  104. anemoi/datasets/create/input/context.py +0 -89
  105. anemoi/datasets/create/input/empty.py +0 -54
  106. anemoi/datasets/create/input/filter.py +0 -118
  107. anemoi/datasets/create/input/function.py +0 -233
  108. anemoi/datasets/create/input/join.py +0 -130
  109. anemoi/datasets/create/input/pipe.py +0 -66
  110. anemoi/datasets/create/input/step.py +0 -177
  111. anemoi/datasets/create/input/template.py +0 -162
  112. anemoi_datasets-0.5.26.dist-info/RECORD +0 -131
  113. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/WHEEL +0 -0
  114. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/entry_points.txt +0 -0
  115. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/licenses/LICENSE +0 -0
  116. {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/top_level.txt +0 -0
anemoi/datasets/data/indexing.py

@@ -8,12 +8,9 @@
 # nor does it submit to any jurisdiction.


+from collections.abc import Callable
 from functools import wraps
 from typing import Any
-from typing import Callable
-from typing import List
-from typing import Tuple
-from typing import Union

 import numpy as np
 from numpy.typing import NDArray
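The import hunk above sets the pattern for the whole release: typing.Callable, List, Tuple and Union give way to collections.abc.Callable, PEP 585 built-in generics and PEP 604 unions. A minimal sketch of the new style (the function is invented for illustration; bare X | Y annotations need Python 3.10, or from __future__ import annotations on 3.9):

    from collections.abc import Callable  # was: from typing import Callable


    def first_even(xs: list[int]) -> int | None:  # was: List[int] and Optional[int]
        """Return the first even element, or None if there is none."""
        return next((x for x in xs if x % 2 == 0), None)


    Reducer = Callable[[list[int]], tuple[int, int]]  # was: Callable[[List[int]], Tuple[int, int]]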
@@ -23,7 +20,7 @@ from .dataset import Shape
 from .dataset import TupleIndex


-def _tuple_with_slices(t: TupleIndex, shape: Shape) -> Tuple[TupleIndex, Tuple[int, ...]]:
+def _tuple_with_slices(t: TupleIndex, shape: Shape) -> tuple[TupleIndex, tuple[int, ...]]:
     """Replace all integers in a tuple with slices, so we preserve the dimensionality.

     Parameters:
@@ -87,7 +84,7 @@ def _index_to_tuple(index: FullIndex, shape: Shape) -> TupleIndex:
     raise ValueError(f"Invalid index: {index}")


-def index_to_slices(index: Union[int, slice, Tuple], shape: Shape) -> Tuple[TupleIndex, Tuple[int, ...]]:
+def index_to_slices(index: int | slice | tuple, shape: Shape) -> tuple[TupleIndex, tuple[int, ...]]:
     """Convert an index to a tuple of slices, with the same dimensionality as the shape.

     Parameters:
@@ -100,7 +97,7 @@ def index_to_slices(index: Union[int, slice, Tuple], shape: Shape) -> Tuple[Tupl
     return _tuple_with_slices(_index_to_tuple(index, shape), shape)


-def apply_index_to_slices_changes(result: NDArray[Any], changes: Tuple[int, ...]) -> NDArray[Any]:
+def apply_index_to_slices_changes(result: NDArray[Any], changes: tuple[int, ...]) -> NDArray[Any]:
     """Apply changes to the result array based on the slices.

     Parameters:
@@ -118,7 +115,7 @@ def apply_index_to_slices_changes(result: NDArray[Any], changes: Tuple[int, ...]
     return result


-def update_tuple(t: Tuple, index: int, value: Any) -> Tuple[Tuple, Any]:
+def update_tuple(t: tuple, index: int, value: Any) -> tuple[tuple, Any]:
     """Replace the elements of a tuple at the given index with a new value.

     Parameters:
@@ -135,7 +132,7 @@ def update_tuple(t: Tuple, index: int, value: Any) -> Tuple[Tuple, Any]:
     return tuple(t), prev


-def length_to_slices(index: slice, lengths: List[int]) -> List[Union[slice, None]]:
+def length_to_slices(index: slice, lengths: list[int]) -> list[slice | None]:
     """Convert an index to a list of slices, given the lengths of the dimensions.

     Parameters:
@@ -174,7 +171,7 @@ def length_to_slices(index: slice, lengths: List[int]) -> List[Union[slice, None
     return result


-def _as_tuples(index: Tuple) -> Tuple:
+def _as_tuples(index: tuple) -> tuple:
     """Convert elements of the index to tuples if they are lists or arrays.

     Parameters:
@@ -219,7 +216,7 @@ def expand_list_indexing(method: Callable[..., NDArray[Any]]) -> Callable[..., N
         if not any(isinstance(i, (list, tuple)) for i in index):
             return method(self, index)

-        which: List[int] = []
+        which: list[int] = []
         for i, idx in enumerate(index):
             if isinstance(idx, (list, tuple)):
                 which.append(i)
@@ -241,7 +238,7 @@ def expand_list_indexing(method: Callable[..., NDArray[Any]]) -> Callable[..., N
     return wrapper


-def make_slice_or_index_from_list_or_tuple(indices: List[int]) -> Union[List[int], slice]:
+def make_slice_or_index_from_list_or_tuple(indices: list[int]) -> list[int] | slice:
     """Convert a list or tuple of indices to a slice or an index, if possible.

     Parameters:
anemoi/datasets/data/interpolate.py

@@ -12,12 +12,6 @@ import datetime
 import logging
 from functools import cached_property
 from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Set
-from typing import Tuple
-from typing import Union

 import numpy as np
 from anemoi.utils.dates import frequency_to_timedelta
@@ -193,7 +187,7 @@ class InterpolateFrequency(Forwards):
         return Node(self, [self.forward.tree()], frequency=self.frequency)

     @cached_property
-    def missing(self) -> Set[int]:
+    def missing(self) -> set[int]:
         """Get the missing data indices."""
         result = []
         j = 0
@@ -204,10 +198,10 @@ class InterpolateFrequency(Forwards):
                 result.append(j)
                 j += 1

-        result = set(x for x in result if x < self._len)
+        result = {x for x in result if x < self._len}
         return result

-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Get the metadata specific to the InterpolateFrequency subclass.

         Returns
@@ -221,9 +215,7 @@ class InterpolateFrequency(Forwards):


 class InterpolateNearest(Forwards):
-    def __init__(
-        self, dataset: Dataset, interpolate_variables: List[str], max_distance: Optional[float] = None
-    ) -> None:
+    def __init__(self, dataset: Dataset, interpolate_variables: list[str], max_distance: float | None = None) -> None:
         """Initialize the InterpolateNearest class.

         Parameters
@@ -262,7 +254,7 @@ class InterpolateNearest(Forwards):
         return self.forward.shape

     @property
-    def metadata(self) -> Dict[str, Any]:
+    def metadata(self) -> dict[str, Any]:
         return self.forward.metadata()

     @staticmethod
@@ -281,12 +273,12 @@ class InterpolateNearest(Forwards):
         result = target_data[(slice(None),) + index[1:]]
         return apply_index_to_slices_changes(result, changes)

-    def __getitem__(self, index: Union[int, slice, Tuple[Union[int, slice], ...]]) -> NDArray[Any]:
+    def __getitem__(self, index: int | slice | tuple[int | slice, ...]) -> NDArray[Any]:
         if isinstance(index, (int, slice)):
             index = (index, slice(None), slice(None), slice(None))
         return self._get_tuple(index)

-    def subclass_metadata_specific(self) -> Dict[str, Any]:
+    def subclass_metadata_specific(self) -> dict[str, Any]:
         return {
             "interpolate_variables": self.vars,
         }
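Besides the annotation changes, InterpolateFrequency.missing now builds its result with a set comprehension instead of passing a generator to set(). A toy demonstration of the equivalence:

    result = [1, 5, 9, 12]
    _len = 10

    old = set(x for x in result if x < _len)  # generator fed to set()
    new = {x for x in result if x < _len}     # set comprehension, as in the diff
    assert old == new == {1, 5, 9}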
anemoi/datasets/data/join.py

@@ -12,10 +12,6 @@ import datetime
 import logging
 from functools import cached_property
 from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Set

 import numpy as np
 from numpy.typing import NDArray
@@ -182,10 +178,10 @@ class Join(Combined):
         return Select(self, indices, {"overlay": variables})

     @cached_property
-    def variables(self) -> List[str]:
+    def variables(self) -> list[str]:
         """Get the variables of the joined dataset."""
         seen = set()
-        result: List[str] = []
+        result: list[str] = []
         for d in reversed(self.datasets):
             for v in reversed(d.variables):
                 while v in seen:
@@ -196,7 +192,7 @@ class Join(Combined):
         return result

     @property
-    def variables_metadata(self) -> Dict[str, Any]:
+    def variables_metadata(self) -> dict[str, Any]:
         """Get the metadata of the variables."""
         result = {}
         variables = [v for v in self.variables if not (v.startswith("(") and v.endswith(")"))]
@@ -216,18 +212,18 @@ class Join(Combined):
         return result

     @cached_property
-    def name_to_index(self) -> Dict[str, int]:
+    def name_to_index(self) -> dict[str, int]:
         """Get the mapping of variable names to indices."""
         return {k: i for i, k in enumerate(self.variables)}

     @property
-    def statistics(self) -> Dict[str, NDArray[Any]]:
+    def statistics(self) -> dict[str, NDArray[Any]]:
         """Get the statistics of the joined dataset."""
         return {
             k: np.concatenate([d.statistics[k] for d in self.datasets], axis=0) for k in self.datasets[0].statistics
         }

-    def statistics_tendencies(self, delta: Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]:
+    def statistics_tendencies(self, delta: datetime.timedelta | None = None) -> dict[str, NDArray[Any]]:
         """Get the statistics tendencies of the joined dataset.

         Parameters
@@ -268,9 +264,9 @@ class Join(Combined):
         assert False

     @cached_property
-    def missing(self) -> Set[int]:
+    def missing(self) -> set[int]:
         """Get the missing data indices."""
-        result: Set[int] = set()
+        result: set[int] = set()
         for d in self.datasets:
             result = result | d.missing
         return result
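The reversed double loop in Join.variables is what lets later datasets shadow earlier ones while every column keeps a distinct name. A standalone sketch of that walk; the f"({v})" renaming of shadowed names is an assumption inferred from the variables_metadata filter above, which skips names wrapped in parentheses:

    def join_variables(datasets: list[list[str]]) -> list[str]:
        seen: set[str] = set()
        result: list[str] = []
        for variables in reversed(datasets):   # later datasets win
            for v in reversed(variables):
                while v in seen:
                    v = f"({v})"               # assumed marker for shadowed names
                seen.add(v)
                result.insert(0, v)
        return result


    print(join_variables([["t2m", "msl"], ["t2m", "tp"]]))
    # ['(t2m)', 'msl', 't2m', 'tp']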
anemoi/datasets/data/masked.py

@@ -11,11 +11,6 @@
 import logging
 from functools import cached_property
 from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union

 import numpy as np
 from numpy.typing import NDArray
@@ -117,7 +112,7 @@ class Masked(Forwards):
         result = apply_index_to_slices_changes(result, changes)
         return result

-    def collect_supporting_arrays(self, collected: List[Tuple], *path: Any) -> None:
+    def collect_supporting_arrays(self, collected: list[tuple], *path: Any) -> None:
         """Collect supporting arrays.

         Parameters
@@ -134,7 +129,7 @@ class Masked(Forwards):
 class Thinning(Masked):
     """A class to represent a thinned dataset."""

-    def __init__(self, forward: Dataset, thinning: Optional[int], method: str) -> None:
+    def __init__(self, forward: Dataset, thinning: int | None, method: str) -> None:
         """Initialize the Thinning class.

         Parameters
@@ -195,7 +190,7 @@ class Thinning(Masked):
         """
         return Node(self, [self.forward.tree()], thinning=self.thinning, method=self.method)

-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Get the metadata specific to the Thinning subclass.

         Returns
@@ -209,7 +204,7 @@ class Thinning(Masked):
 class Cropping(Masked):
     """A class to represent a cropped dataset."""

-    def __init__(self, forward: Dataset, area: Union[Dataset, Tuple[float, float, float, float]]) -> None:
+    def __init__(self, forward: Dataset, area: Dataset | tuple[float, float, float, float]) -> None:
         """Initialize the Cropping class.

         Parameters
@@ -245,7 +240,7 @@ class Cropping(Masked):
         """
         return Node(self, [self.forward.tree()], area=self.area)

-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Get the metadata specific to the Cropping subclass.

         Returns
@@ -314,7 +309,7 @@ class TrimEdge(Masked):
         """
         return Node(self, [self.forward.tree()], edge=self.edge)

-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Get the metadata specific to the TrimEdge subclass.

         Returns
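Cropping's area argument is now annotated Dataset | tuple[float, float, float, float]: either an explicit bounding box or another dataset whose domain defines the cut-out. A hedged usage sketch (paths are placeholders; the (north, west, south, east) ordering follows the anemoi-datasets documentation):

    from anemoi.datasets import open_dataset

    # Crop to an explicit bounding box: (north, west, south, east).
    europe = open_dataset("dataset.zarr", area=(75.0, -15.0, 30.0, 45.0))

    # Or crop to the domain covered by another dataset.
    reference = open_dataset("lam-dataset.zarr")
    matched = open_dataset("dataset.zarr", area=reference)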
anemoi/datasets/data/merge.py

@@ -12,10 +12,6 @@ import datetime
 import logging
 from functools import cached_property
 from typing import Any
-from typing import Dict
-from typing import List
-from typing import Set
-from typing import Tuple

 import numpy as np
 from numpy.typing import NDArray
@@ -40,7 +36,7 @@ LOG = logging.getLogger(__name__)
 class Merge(Combined):
     """A class to merge multiple datasets along the dates axis, handling gaps in dates if allowed."""

-    def __init__(self, datasets: List[Dataset], allow_gaps_in_dates: bool = False) -> None:
+    def __init__(self, datasets: list[Dataset], allow_gaps_in_dates: bool = False) -> None:
         """Initialize the Merge object.

         Parameters
@@ -128,10 +124,10 @@ class Merge(Combined):
         return self._frequency

     @cached_property
-    def missing(self) -> Set[int]:
+    def missing(self) -> set[int]:
         """Get the indices of missing dates in the merged dataset."""
         # TODO: optimize
-        result: Set[int] = set()
+        result: set[int] = set()

         for i, (dataset, row) in enumerate(self._indices):
             if dataset == self._missing_index:
@@ -192,7 +188,7 @@ class Merge(Combined):
         """
         return Node(self, [d.tree() for d in self.datasets], allow_gaps_in_dates=self.allow_gaps_in_dates)

-    def metadata_specific(self) -> Dict[str, Any]:
+    def metadata_specific(self) -> dict[str, Any]:
         """Get the specific metadata for the merged dataset.

         Returns
@@ -265,7 +261,7 @@ class Merge(Combined):
         return np.stack([self[i] for i in range(*s.indices(self._len))])


-def merge_factory(args: Tuple, kwargs: Dict[str, Any]) -> Dataset:
+def merge_factory(args: tuple, kwargs: dict[str, Any]) -> Dataset:
     """Factory function to create a merged dataset.

     Parameters
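merge_factory receives the raw (args, kwargs) pair from open_dataset, so merging is driven from the call site. A hedged sketch, assuming merge is accepted as a combining keyword and that allow_gaps_in_dates travels through kwargs as the Merge constructor above suggests (paths are placeholders):

    from anemoi.datasets import open_dataset

    # Interleave two datasets along the dates axis, tolerating gaps.
    ds = open_dataset(
        merge=["odd-hours.zarr", "even-hours.zarr"],
        allow_gaps_in_dates=True,
    )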
anemoi/datasets/data/misc.py

@@ -15,11 +15,6 @@ import os
 from pathlib import PurePath
 from typing import TYPE_CHECKING
 from typing import Any
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union

 import numpy as np
 import zarr
@@ -33,7 +28,7 @@ if TYPE_CHECKING:
 LOG = logging.getLogger(__name__)


-def load_config() -> Dict[str, Any]:
+def load_config() -> dict[str, Any]:
     """Load the configuration settings.
     Returns
     -------
@@ -110,10 +105,10 @@ def round_datetime(d: np.datetime64, dates: NDArray[np.datetime64], up: bool) ->


 def _as_date(
-    d: Union[int, str, np.datetime64, datetime.date],
+    d: int | str | np.datetime64 | datetime.date,
     dates: NDArray[np.datetime64],
     last: bool,
-    frequency: Optional[datetime.timedelta] = None,
+    frequency: datetime.timedelta | None = None,
 ) -> np.datetime64:
     """Convert a date to a numpy datetime64 object, rounding to the nearest date in a list of dates.

@@ -221,8 +216,8 @@

     if "-" in d and ":" in d:
         date, time = d.replace(" ", "T").split("T")
-        year, month, day = [int(_) for _ in date.split("-")]
-        hour, minute, second = [int(_) for _ in time.split(":")]
+        year, month, day = (int(_) for _ in date.split("-"))
+        hour, minute, second = (int(_) for _ in time.split(":"))
         return _as_date(
             np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}"),
             dates,
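The two replaced list comprehensions fed straight into tuple unpacking, and unpacking consumes any iterable, so a generator expression does the same job without materialising an intermediate list:

    date, time = "2020-01-02 03:04:05".replace(" ", "T").split("T")
    year, month, day = (int(_) for _ in date.split("-"))
    hour, minute, second = (int(_) for _ in time.split(":"))
    assert (year, month, day, hour, minute, second) == (2020, 1, 2, 3, 4, 5)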
@@ -258,9 +253,9 @@


 def as_first_date(
-    d: Union[int, str, np.datetime64, datetime.date],
+    d: int | str | np.datetime64 | datetime.date,
     dates: NDArray[np.datetime64],
-    frequency: Optional[datetime.timedelta] = None,
+    frequency: datetime.timedelta | None = None,
 ) -> np.datetime64:
     """Convert a date to the first date in a list of dates.

@@ -282,9 +277,9 @@ def as_first_date(


 def as_last_date(
-    d: Union[int, str, np.datetime64, datetime.date],
+    d: int | str | np.datetime64 | datetime.date,
     dates: NDArray[np.datetime64],
-    frequency: Optional[datetime.timedelta] = None,
+    frequency: datetime.timedelta | None = None,
 ) -> np.datetime64:
     """Convert a date to the last date in a list of dates.

@@ -305,7 +300,7 @@ def as_last_date(
     return _as_date(d, dates, last=True, frequency=frequency)


-def _concat_or_join(datasets: List["Dataset"], kwargs: Dict[str, Any]) -> Tuple["Dataset", Dict[str, Any]]:
+def _concat_or_join(datasets: list["Dataset"], kwargs: dict[str, Any]) -> tuple["Dataset", dict[str, Any]]:
     """Concatenate or join datasets based on their date ranges.

     Parameters
@@ -317,7 +312,7 @@ def _concat_or_join(datasets: List["Dataset"], kwargs: Dict[str, Any]) -> Tuple[

     Returns
     -------
-    Tuple[Dataset, Dict[str, Any]]
+    tuple[Dataset, Dict[str, Any]]
         The concatenated or joined dataset and remaining arguments.
     """
     if "adjust" in kwargs:
@@ -339,12 +334,12 @@ def _concat_or_join(datasets: List["Dataset"], kwargs: Dict[str, Any]) -> Tuple[
     return Concat(datasets), kwargs


-def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) -> "Dataset":
+def _open(a: str | PurePath | dict[str, Any] | list[Any] | tuple[Any, ...]) -> "Dataset":
     """Open a dataset from various input types.

     Parameters
     ----------
-    a : Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]
+    a : Union[str, PurePath, Dict[str, Any], List[Any], tuple[Any, ...]]
         The input to open.

     Returns
@@ -390,10 +385,10 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) -


 def _auto_adjust(
-    datasets: List["Dataset"],
-    kwargs: Dict[str, Any],
-    exclude: Optional[List[str]] = None,
-) -> Tuple[List["Dataset"], Dict[str, Any]]:
+    datasets: list["Dataset"],
+    kwargs: dict[str, Any],
+    exclude: list[str] | None = None,
+) -> tuple[list["Dataset"], dict[str, Any]]:
     """Automatically adjust datasets based on specified criteria.

     Parameters
@@ -407,7 +402,7 @@ def _auto_adjust(

     Returns
     -------
-    Tuple[List[Dataset], Dict[str, Any]]
+    tuple[List[Dataset], Dict[str, Any]]
         The adjusted datasets and remaining arguments.
     """
     if "adjust" not in kwargs:
@@ -620,7 +615,7 @@ def append_to_zarr(new_data: np.ndarray, new_dates: np.ndarray, zarr_path: str)
     # Re-open the zarr store to avoid root object accumulating memory.
     root = zarr.open(zarr_path, mode="a")
     # Convert new dates to strings (using str) regardless of input dtype.
-    new_dates = np.array(new_dates, dtype="datetime64[ns]")
+    new_dates = np.array(new_dates, dtype="datetime64[s]")
     dates_ds = root["dates"]
     old_len = dates_ds.shape[0]
     dates_ds.resize((old_len + len(new_dates),))
@@ -633,19 +628,19 @@
     data_ds[old_shape[0] :] = new_data


-def process_date(date: Any, big_dataset: Any) -> Tuple[np.ndarray, np.ndarray]:
+def process_date(date: Any, big_dataset: "Dataset") -> tuple[np.ndarray, np.ndarray]:
     """Open the subset corresponding to the given date and return (date, subset).

     Parameters
     ----------
     date : Any
         The date to process.
-    big_dataset : Any
+    big_dataset : Dataset
         The dataset to process.

     Returns
     -------
-    Tuple[np.ndarray, np.ndarray]
+    tuple[np.ndarray, np.ndarray]
         The subset and the date.
     """
     print("Processing:", date, flush=True)
@@ -655,26 +650,24 @@ def process_date(date: Any, big_dataset: Any) -> Tuple[np.ndarray, np.ndarray]:
     return s, date


-def initialize_zarr_store(root: Any, big_dataset: Any, recipe: Dict[str, Any]) -> None:
+def initialize_zarr_store(root: Any, big_dataset: "Dataset") -> None:
     """Initialize the Zarr store with the given dataset and recipe.

     Parameters
     ----------
     root : Any
-        The root of the Zarr store.
-    big_dataset : Any
+        The root Zarr store.
+    big_dataset : Dataset
         The dataset to initialize the store with.
-    recipe : Dict[str, Any]
-        The recipe for initializing the store.
     """
-    ensembles = big_dataset.shape[1]
+    ensembles = big_dataset.shape[2]
     # Create or append to "dates" dataset.
     if "dates" not in root:
         full_length = len(big_dataset.dates)
         root.create_dataset("dates", data=np.array([], dtype="datetime64[s]"), chunks=(full_length,))

     if "data" not in root:
-        dims = (1, len(big_dataset.variables), ensembles, big_dataset.grids[0])
+        dims = (1, len(big_dataset.variables), ensembles, big_dataset.shape[-1])
         root.create_dataset(
             "data",
             shape=dims,
@@ -694,25 +687,28 @@ def initialize_zarr_store(root: Any, big_dataset: Any, recipe: Dict[str, Any]) -
     if "latitudes" not in root or "longitudes" not in root:
         root.create_dataset("latitudes", data=big_dataset.latitudes, compressor=None)
         root.create_dataset("longitudes", data=big_dataset.longitudes, compressor=None)
-
+    for k, v in big_dataset.metadata().items():
+        if k not in root.attrs:
+            root.attrs[k] = v
     # Set store-wide attributes if not already set.
-    if "frequency" not in root.attrs:
-        root.attrs["frequency"] = "10m"
-        root.attrs["resolution"] = "1km"
+    if "first_date" not in root.attrs:
+        root.attrs["first_date"] = big_dataset.metadata()["start_date"]
+        root.attrs["last_date"] = big_dataset.metadata()["end_date"]
+        root.attrs["resolution"] = big_dataset.resolution
         root.attrs["name_to_index"] = {k: i for i, k in enumerate(big_dataset.variables)}
-        root.attrs["ensemble_dimension"] = 1
+        root.attrs["ensemble_dimension"] = 2
         root.attrs["field_shape"] = big_dataset.field_shape
         root.attrs["flatten_grid"] = True
-        root.attrs["recipe"] = recipe
+        root.attrs["recipe"] = {}


-def _save_dataset(recipe: Dict[str, Any], zarr_path: str, n_workers: int = 1) -> None:
+def _save_dataset(dataset: "Dataset", zarr_path: str, n_workers: int = 1) -> None:
     """Incrementally create (or update) a Zarr store from an Anemoi dataset.

     Parameters
     ----------
-    recipe : Dict[str, Any]
-        The recipe for creating the dataset.
+    dataset : Dataset
+        anemoi-dataset opened from python to save to Zarr store
     zarr_path : str
         The path to the Zarr store.
     n_workers : int, optional
@@ -728,13 +724,13 @@ def _save_dataset(recipe: Dict[str, Any], zarr_path: str, n_workers: int = 1) ->
     """
     from concurrent.futures import ProcessPoolExecutor

-    full_ds = _open_dataset(recipe).mutate()
+    full_ds = dataset
     print("Opened full dataset.", flush=True)

     # Use ProcessPoolExecutor for parallel data extraction.
     # Workers return (date, subset) tuples.
     root = zarr.open(zarr_path, mode="a")
-    initialize_zarr_store(root, full_ds, recipe)
+    initialize_zarr_store(root, full_ds)
     print("Zarr store initialised.", flush=True)

     existing_dates = np.array(sorted(root["dates"]), dtype="datetime64[s]")
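The zarr-facing helpers change in substance, not just style: dates are stored as datetime64[s] rather than datetime64[ns] (matching the dtype the dates array is created with), initialize_zarr_store no longer takes a recipe and copies the dataset's own metadata into the store attributes instead, and _save_dataset now takes an already-opened dataset rather than a recipe. A self-contained sketch of the resize-and-append pattern used on the dates array, with an illustrative store path (zarr v2-style API):

    import numpy as np
    import zarr

    root = zarr.open("example.zarr", mode="a")  # create or open the store
    if "dates" not in root:
        root.create_dataset("dates", data=np.array([], dtype="datetime64[s]"), chunks=(1024,))

    new_dates = np.array(["2020-01-01T00:00:00", "2020-01-01T06:00:00"], dtype="datetime64[s]")

    dates_ds = root["dates"]
    old_len = dates_ds.shape[0]
    dates_ds.resize((old_len + len(new_dates),))  # grow along the time axis
    dates_ds[old_len:] = new_dates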
anemoi/datasets/data/missing.py

@@ -12,11 +12,6 @@ import datetime
 import logging
 from functools import cached_property
 from typing import Any
-from typing import Dict
-from typing import List
-from typing import Set
-from typing import Tuple
-from typing import Union

 import numpy as np
 from numpy.typing import NDArray
@@ -49,7 +44,7 @@ class MissingDates(Forwards):
         List of missing dates.
     """

-    def __init__(self, dataset: Dataset, missing_dates: List[Union[int, str]]) -> None:
+    def __init__(self, dataset: Dataset, missing_dates: list[int | str]) -> None:
         """Initializes the MissingDates class.

         Parameters
@@ -80,13 +75,13 @@ class MissingDates(Forwards):
                 self.missing_dates.append(date)

         n = self.forward._len
-        self._missing = set(i for i in self._missing if 0 <= i < n)
+        self._missing = {i for i in self._missing if 0 <= i < n}
         self.missing_dates = sorted(to_datetime(x) for x in self.missing_dates)

         assert len(self._missing), "No dates to force missing"

     @cached_property
-    def missing(self) -> Set[int]:
+    def missing(self) -> set[int]:
         """Returns the set of missing indices."""
         return self._missing.union(self.forward.missing)

@@ -148,7 +143,7 @@ class MissingDates(Forwards):
         raise MissingDateError(f"Date {self.forward.dates[n]} is missing (index={n})")

     @property
-    def reason(self) -> Dict[str, Any]:
+    def reason(self) -> dict[str, Any]:
         """Provides the reason for missing dates."""
         return {"missing_dates": self.missing_dates}

@@ -162,7 +157,7 @@ class MissingDates(Forwards):
         """
         return Node(self, [self.forward.tree()], **self.reason)

-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Provides metadata specific to the subclass.

         Returns
@@ -184,7 +179,7 @@ class SkipMissingDates(Forwards):
         The expected access pattern.
     """

-    def __init__(self, dataset: Dataset, expected_access: Union[int, slice]) -> None:
+    def __init__(self, dataset: Dataset, expected_access: int | slice) -> None:
         """Initializes the SkipMissingDates class.

         Parameters
@@ -285,7 +280,7 @@ class SkipMissingDates(Forwards):
         return tuple(np.stack(_) for _ in result)

     @debug_indexing
-    def _get_slice(self, s: slice) -> Tuple[NDArray[Any], ...]:
+    def _get_slice(self, s: slice) -> tuple[NDArray[Any], ...]:
         """Retrieves a slice of items.

         Parameters
@@ -303,7 +298,7 @@ class SkipMissingDates(Forwards):
         return tuple(np.stack(_) for _ in result)

     @debug_indexing
-    def __getitem__(self, n: FullIndex) -> Tuple[NDArray[Any], ...]:
+    def __getitem__(self, n: FullIndex) -> tuple[NDArray[Any], ...]:
         """Retrieves the item at the given index.

         Parameters
@@ -339,7 +334,7 @@ class SkipMissingDates(Forwards):
         """
         return Node(self, [self.forward.tree()], expected_access=self.expected_access)

-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Provides metadata specific to the subclass.

         Returns
@@ -404,7 +399,7 @@ class MissingDataset(Forwards):
         return self._dates

     @property
-    def missing(self) -> Set[int]:
+    def missing(self) -> set[int]:
         """Returns the set of missing indices."""
         return self._missing

@@ -436,7 +431,7 @@ class MissingDataset(Forwards):
         """
         return Node(self, [self.forward.tree()], start=self.start, end=self.end)

-    def forwards_subclass_metadata_specific(self) -> Dict[str, Any]:
+    def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
         """Provides metadata specific to the subclass.

         Returns
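MissingDates keeps only the forced indices that fall inside the wrapped dataset (now via a set comprehension) and unions them with whatever that dataset already reports as missing. A toy illustration of the combination:

    forced = {-3, 2, 5, 11}                      # user-supplied indices, some out of range
    n = 10                                       # length of the wrapped dataset
    forced = {i for i in forced if 0 <= i < n}   # as in the diff
    forward_missing = {4, 5}                     # already missing upstream
    assert forced.union(forward_missing) == {2, 4, 5}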