anemoi-datasets 0.5.27__py3-none-any.whl → 0.5.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/recipe/__init__.py +93 -0
  3. anemoi/datasets/commands/recipe/format.py +55 -0
  4. anemoi/datasets/commands/recipe/migrate.py +555 -0
  5. anemoi/datasets/create/__init__.py +46 -13
  6. anemoi/datasets/create/config.py +52 -53
  7. anemoi/datasets/create/input/__init__.py +43 -63
  8. anemoi/datasets/create/input/action.py +296 -236
  9. anemoi/datasets/create/input/context/__init__.py +71 -0
  10. anemoi/datasets/create/input/context/field.py +54 -0
  11. anemoi/datasets/create/input/data_sources.py +2 -1
  12. anemoi/datasets/create/input/misc.py +0 -71
  13. anemoi/datasets/create/input/repeated_dates.py +0 -114
  14. anemoi/datasets/create/input/result/__init__.py +17 -0
  15. anemoi/datasets/create/input/{result.py → result/field.py} +10 -92
  16. anemoi/datasets/create/sources/accumulate.py +517 -0
  17. anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
  18. anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
  19. anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +149 -0
  20. anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
  21. anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
  22. anemoi/datasets/create/sources/constants.py +39 -38
  23. anemoi/datasets/create/sources/empty.py +26 -22
  24. anemoi/datasets/create/sources/forcings.py +29 -28
  25. anemoi/datasets/create/sources/grib.py +92 -72
  26. anemoi/datasets/create/sources/grib_index.py +102 -54
  27. anemoi/datasets/create/sources/hindcasts.py +56 -55
  28. anemoi/datasets/create/sources/legacy.py +10 -62
  29. anemoi/datasets/create/sources/mars.py +159 -154
  30. anemoi/datasets/create/sources/netcdf.py +28 -24
  31. anemoi/datasets/create/sources/opendap.py +28 -24
  32. anemoi/datasets/create/sources/recentre.py +42 -41
  33. anemoi/datasets/create/sources/repeated_dates.py +44 -0
  34. anemoi/datasets/create/sources/source.py +26 -48
  35. anemoi/datasets/create/sources/xarray_support/__init__.py +30 -24
  36. anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
  37. anemoi/datasets/create/sources/xarray_support/field.py +4 -4
  38. anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
  39. anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
  40. anemoi/datasets/create/sources/xarray_zarr.py +28 -24
  41. anemoi/datasets/create/sources/zenodo.py +43 -39
  42. anemoi/datasets/create/utils.py +0 -42
  43. anemoi/datasets/data/complement.py +26 -17
  44. anemoi/datasets/data/dataset.py +12 -0
  45. anemoi/datasets/data/grids.py +0 -152
  46. anemoi/datasets/data/masked.py +74 -13
  47. anemoi/datasets/data/missing.py +5 -0
  48. anemoi/datasets/data/rolling_average.py +141 -0
  49. anemoi/datasets/data/stores.py +7 -9
  50. anemoi/datasets/dates/__init__.py +2 -0
  51. anemoi/datasets/dumper.py +76 -0
  52. anemoi/datasets/grids.py +1 -178
  53. anemoi/datasets/schemas/recipe.json +131 -0
  54. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/METADATA +9 -6
  55. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/RECORD +59 -57
  56. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/WHEEL +1 -1
  57. anemoi/datasets/create/filter.py +0 -47
  58. anemoi/datasets/create/input/concat.py +0 -161
  59. anemoi/datasets/create/input/context.py +0 -86
  60. anemoi/datasets/create/input/empty.py +0 -53
  61. anemoi/datasets/create/input/filter.py +0 -117
  62. anemoi/datasets/create/input/function.py +0 -232
  63. anemoi/datasets/create/input/join.py +0 -129
  64. anemoi/datasets/create/input/pipe.py +0 -66
  65. anemoi/datasets/create/input/step.py +0 -173
  66. anemoi/datasets/create/input/template.py +0 -161
  67. anemoi/datasets/create/sources/accumulations.py +0 -1062
  68. anemoi/datasets/create/sources/accumulations2.py +0 -647
  69. anemoi/datasets/create/sources/tendencies.py +0 -198
  70. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/entry_points.txt +0 -0
  71. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/licenses/LICENSE +0 -0
  72. {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/top_level.txt +0 -0
@@ -96,48 +96,6 @@ def to_datetime(*args: Any, **kwargs: Any) -> datetime.datetime:
      return to_datetime_(*args, **kwargs)
 
 
- def make_list_int(value: str | list | tuple | int) -> list[int]:
-     """Convert a string, list, tuple, or integer to a list of integers.
-
-     Parameters
-     ----------
-     value : str or list or tuple or int
-         The value to convert.
-
-     Returns
-     -------
-     list[int]
-         A list of integers.
-
-     Raises
-     ------
-     ValueError
-         If the value cannot be converted to a list of integers.
-     """
-     # Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers.
-     # Moved to anemoi.utils.humanize
-     # replace with from anemoi.utils.humanize import make_list_int
-     # when anemoi-utils is released and pyproject.toml is updated
-     if isinstance(value, str):
-         if "/" not in value:
-             return [value]
-         bits = value.split("/")
-         if len(bits) == 3 and bits[1].lower() == "to":
-             value = list(range(int(bits[0]), int(bits[2]) + 1, 1))
-
-         elif len(bits) == 5 and bits[1].lower() == "to" and bits[3].lower() == "by":
-             value = list(range(int(bits[0]), int(bits[2]) + int(bits[4]), int(bits[4])))
-
-     if isinstance(value, list):
-         return value
-     if isinstance(value, tuple):
-         return value
-     if isinstance(value, int):
-         return [value]
-
-     raise ValueError(f"Cannot make list from {value}")
-
-
  def normalize_and_check_dates(
      dates: list[datetime.datetime],
      start: datetime.datetime,
@@ -293,21 +293,29 @@ class ComplementNearest(Complement):
          index, previous = update_tuple(index, variable_index, slice(None))
          source_index = [self._source.name_to_index[x] for x in self.variables[previous]]
          source_data = self._source[index[0], source_index, index[2], ...]
-         target_data = source_data[..., self._nearest_grid_points]
-
-         epsilon = 1e-8  # prevent division by zero
-         weights = 1.0 / (self._distances + epsilon)
-         weights = weights.astype(target_data.dtype)
-         weights /= weights.sum(axis=1, keepdims=True)  # normalize
-
-         # Reshape weights to broadcast correctly
-         # Add leading singleton dimensions so it matches target_data shape
-         while weights.ndim < target_data.ndim:
-             weights = np.expand_dims(weights, axis=0)
-
-         # Compute weighted average along the last dimension
-         final_point = np.sum(target_data * weights, axis=-1)
-         result = final_point[..., index[3]]
+         if any(self._nearest_grid_points >= source_data.shape[-1]):
+             target_shape = source_data.shape[:-1] + self._target.shape[-1:]
+             target_data = np.full(target_shape, np.nan, dtype=self._target.dtype)
+             cond = self._nearest_grid_points < source_data.shape[-1]
+             reachable = np.where(cond)[0]
+             nearest_reachable = self._nearest_grid_points[cond]
+             target_data[..., reachable] = source_data[..., nearest_reachable]
+             result = target_data[..., index[3]]
+         else:
+             target_data = source_data[..., self._nearest_grid_points]
+             epsilon = 1e-8  # prevent division by zero
+             weights = 1.0 / (self._distances + epsilon)
+             weights = weights.astype(target_data.dtype)
+             weights /= weights.sum(axis=1, keepdims=True)  # normalize
+
+             # Reshape weights to broadcast correctly
+             # Add leading singleton dimensions so it matches target_data shape
+             while weights.ndim < target_data.ndim:
+                 weights = np.expand_dims(weights, axis=0)
+
+             # Compute weighted average along the last dimension
+             final_point = np.sum(target_data * weights, axis=-1)
+             result = final_point[..., index[3]]
 
          return apply_index_to_slices_changes(result, changes)
 
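For readers skimming the hunk above: a self-contained numpy sketch (toy arrays, not anemoi's internals) of what the new branch does when some precomputed nearest-neighbour indices fall outside the source grid.

import numpy as np

# Toy illustration only: 3 target points, a source with 4 grid points, and one
# precomputed nearest-neighbour index (5) that falls outside the source.
source_data = np.arange(4, dtype=float)
nearest_grid_points = np.array([0, 3, 5])

if any(nearest_grid_points >= source_data.shape[-1]):
    target_data = np.full((3,), np.nan)
    cond = nearest_grid_points < source_data.shape[-1]
    target_data[np.where(cond)[0]] = source_data[..., nearest_grid_points[cond]]
    # target_data is now [0., 3., nan]: unreachable target points are left as NaN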
@@ -353,8 +361,9 @@ def complement_factory(args: tuple, kwargs: dict) -> Dataset:
      }[interpolation]
 
      if interpolation == "nearest":
-         k = kwargs.pop("k", "1")
-         complement = Class(target=target, source=source, k=k)._subset(**kwargs)
+         k = kwargs.pop("k", 1)
+         max_distance = kwargs.pop("max_distance", None)
+         complement = Class(target=target, source=source, k=k, max_distance=max_distance)._subset(**kwargs)
 
      else:
          complement = Class(target=target, source=source)._subset(**kwargs)
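A hedged usage sketch of the two options handled above; it assumes the `complement`/`source` keywords of `open_dataset` are otherwise unchanged, and the dataset names are placeholders.

from anemoi.datasets import open_dataset

ds = open_dataset(
    complement="lam-dataset",   # placeholder name
    source="global-dataset",    # placeholder name
    interpolation="nearest",
    k=4,                   # default is now the integer 1 instead of the string "1"
    max_distance=100_000,  # new optional argument forwarded to ComplementNearest (default None)
)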
@@ -245,6 +245,12 @@ class Dataset(ABC, Sized):
 
              return Statistics(self, open_dataset(statistics))._subset(**kwargs).mutate()
 
+         if "mask" in kwargs:
+             from .masked import Masking
+
+             mask_file = kwargs.pop("mask")
+             return Masking(self, mask_file)._subset(**kwargs).mutate()
+
          # Note: trim_edge should go before thinning
          if "trim_edge" in kwargs:
              from .masked import TrimEdge
@@ -293,6 +299,12 @@ class Dataset(ABC, Sized):
          if skip_missing_dates:
              return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate()
 
+         if "rolling_average" in kwargs:
+             from .rolling_average import RollingAverage
+
+             rolling_average = kwargs.pop("rolling_average")
+             return RollingAverage(self, rolling_average)._subset(**kwargs).mutate()
+
 
          if "interpolate_frequency" in kwargs:
 
@@ -21,167 +21,15 @@ from .dataset import FullIndex
  from .dataset import Shape
  from .dataset import TupleIndex
  from .debug import Node
- from .debug import debug_indexing
- from .forwards import Combined
  from .forwards import GivenAxis
  from .indexing import apply_index_to_slices_changes
- from .indexing import expand_list_indexing
  from .indexing import index_to_slices
- from .indexing import length_to_slices
- from .indexing import update_tuple
  from .misc import _auto_adjust
  from .misc import _open
 
  LOG = logging.getLogger(__name__)
 
 
- class Concat(Combined):
-     """A class to represent concatenated datasets."""
-
-     def __len__(self) -> int:
-         """Returns the total length of the concatenated datasets.
-
-         Returns
-         -------
-         int
-             Total length of the concatenated datasets.
-         """
-         return sum(len(i) for i in self.datasets)
-
-     @debug_indexing
-     @expand_list_indexing
-     def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
-         """Retrieves a tuple of data from the concatenated datasets based on the given index.
-
-         Parameters
-         ----------
-         index : TupleIndex
-             Index specifying the data to retrieve.
-
-         Returns
-         -------
-         NDArray[Any]
-             Concatenated data array from the specified index.
-         """
-         index, changes = index_to_slices(index, self.shape)
-         # print(index, changes)
-         lengths = [d.shape[0] for d in self.datasets]
-         slices = length_to_slices(index[0], lengths)
-         # print("slies", slices)
-         result = [d[update_tuple(index, 0, i)[0]] for (d, i) in zip(self.datasets, slices) if i is not None]
-         result = np.concatenate(result, axis=0)
-         return apply_index_to_slices_changes(result, changes)
-
-     @debug_indexing
-     def __getitem__(self, n: FullIndex) -> NDArray[Any]:
-         """Retrieves data from the concatenated datasets based on the given index.
-
-         Parameters
-         ----------
-         n : FullIndex
-             Index specifying the data to retrieve.
-
-         Returns
-         -------
-         NDArray[Any]
-             Data array from the concatenated datasets based on the index.
-         """
-         if isinstance(n, tuple):
-             return self._get_tuple(n)
-
-         if isinstance(n, slice):
-             return self._get_slice(n)
-
-         # TODO: optimize
-         k = 0
-         while n >= self.datasets[k]._len:
-             n -= self.datasets[k]._len
-             k += 1
-         return self.datasets[k][n]
-
-     @debug_indexing
-     def _get_slice(self, s: slice) -> NDArray[Any]:
-         """Retrieves a slice of data from the concatenated datasets.
-
-         Parameters
-         ----------
-         s : slice
-             Slice object specifying the range of data to retrieve.
-
-         Returns
-         -------
-         NDArray[Any]
-             Concatenated data array from the specified slice.
-         """
-         result = []
-
-         lengths = [d.shape[0] for d in self.datasets]
-         slices = length_to_slices(s, lengths)
-
-         result = [d[i] for (d, i) in zip(self.datasets, slices) if i is not None]
-
-         return np.concatenate(result)
-
-     def check_compatibility(self, d1: Dataset, d2: Dataset) -> None:
-         """Check the compatibility of two datasets for concatenation.
-
-         Parameters
-         ----------
-         d1 : Dataset
-             The first dataset.
-         d2 : Dataset
-             The second dataset.
-         """
-         super().check_compatibility(d1, d2)
-         self.check_same_sub_shapes(d1, d2, drop_axis=0)
-
-     def check_same_lengths(self, d1: Dataset, d2: Dataset) -> None:
-         """Check if the lengths of two datasets are the same.
-
-         Parameters
-         ----------
-         d1 : Dataset
-             The first dataset.
-         d2 : Dataset
-             The second dataset.
-         """
-         # Turned off because we are concatenating along the first axis
-         pass
-
-     def check_same_dates(self, d1: Dataset, d2: Dataset) -> None:
-         """Check if the dates of two datasets are the same.
-
-         Parameters
-         ----------
-         d1 : Dataset
-             The first dataset.
-         d2 : Dataset
-             The second dataset.
-         """
-         # Turned off because we are concatenating along the dates axis
-         pass
-
-     @property
-     def dates(self) -> NDArray[np.datetime64]:
-         """Returns the concatenated dates of all datasets."""
-         return np.concatenate([d.dates for d in self.datasets])
-
-     @property
-     def shape(self) -> Shape:
-         """Returns the shape of the concatenated datasets."""
-         return (len(self),) + self.datasets[0].shape[1:]
-
-     def tree(self) -> Node:
-         """Generates a hierarchical tree structure for the concatenated datasets.
-
-         Returns
-         -------
-         Node
-             A Node object representing the concatenated datasets.
-         """
-         return Node(self, [d.tree() for d in self.datasets])
-
-
  class GridsBase(GivenAxis):
      """A base class for handling grids in datasets."""
 
@@ -10,6 +10,7 @@
 
  import logging
  from functools import cached_property
+ from pathlib import Path
  from typing import Any
 
  import numpy as np
@@ -66,6 +67,12 @@ class Masked(Forwards):
          """Get the masked longitudes."""
          return self.forward.longitudes[self.mask]
 
+     @property
+     def grids(self) -> TupleIndex:
+         """Returns the number of grid points after masking"""
+         grids = np.sum(self.mask)
+         return (grids,)
+
      @debug_indexing
      def __getitem__(self, index: FullIndex) -> NDArray[Any]:
          """Get the masked data at the specified index.
@@ -150,19 +157,9 @@ class Thinning(Masked):
              if len(shape) != 2:
                  raise ValueError("Thinning only works latitude/longitude fields")
 
-             # Make a copy, so we read the data only once from zarr
-             forward_latitudes = forward.latitudes.copy()
-             forward_longitudes = forward.longitudes.copy()
-
-             latitudes = forward_latitudes.reshape(shape)
-             longitudes = forward_longitudes.reshape(shape)
-             latitudes = latitudes[::thinning, ::thinning].flatten()
-             longitudes = longitudes[::thinning, ::thinning].flatten()
-
-             # TODO: This is not very efficient
-
-             mask = [lat in latitudes and lon in longitudes for lat, lon in zip(forward_latitudes, forward_longitudes)]
-             mask = np.array(mask, dtype=bool)
+             mask = np.full(shape, False, dtype=bool)
+             mask[::thinning, ::thinning] = True
+             mask = mask.flatten()
          else:
              mask = None
 
@@ -200,6 +197,70 @@ class Thinning(Masked):
          """
          return dict(thinning=self.thinning, method=self.method)
 
+     @property
+     def field_shape(self) -> Shape:
+         """Returns the field shape of the dataset."""
+         if self.thinning is None:
+             return self.forward.field_shape
+         x, y = self.forward.field_shape
+         x = (x + self.thinning - 1) // self.thinning
+         y = (y + self.thinning - 1) // self.thinning
+         return x, y
+
+
+ class Masking(Masked):
+     """A class that applies a precomputed boolean mask from a .npy file."""
+
+     def __init__(self, forward: Dataset, mask_file: str) -> None:
+         """Initialize the Masking class.
+
+         Parameters
+         ----------
+         forward : Dataset
+             The dataset to be masked.
+         mask_file : str
+             Path to a .npy file containing a boolean mask of same shape as fields.
+         """
+         self.mask_file = mask_file
+
+         # Check path
+         if not Path(self.mask_file).exists():
+             raise FileNotFoundError(f"Mask file not found: {self.mask_file}")
+         # Load mask
+         try:
+             mask = np.load(self.mask_file)
+         except Exception as e:
+             raise ValueError(f"Could not load data from {mask_file}: {e}")
+
+         if mask.dtype != bool:
+             raise ValueError(f"Mask file {mask_file} does not contain boolean values.")
+         if mask.shape != forward.field_shape:
+             raise ValueError(f"Mask length {mask.shape} does not match field size {forward.field_shape}.")
+         if sum(mask) == 0:
+             LOG.warning(f"Mask in {mask_file} eliminates all points in field.")
+
+         super().__init__(forward, mask)
+
+     def tree(self) -> Node:
+         """Get the tree representation of the dataset.
+
+         Returns
+         -------
+         Node
+             The tree representation of the dataset.
+         """
+         return Node(self, [self.forward.tree()], mask_file=self.mask_file)
+
+     def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
+         """Get the metadata specific to the Masking subclass.
+
+         Returns
+         -------
+         Dict[str, Any]
+             The metadata specific to the Masking subclass.
+         """
+         return dict(mask_file=self.mask_file)
+
 
  class Cropping(Masked):
      """A class to represent a cropped dataset."""
@@ -440,3 +440,8 @@ class MissingDataset(Forwards):
              Metadata specific to the subclass.
          """
          return {"start": self.start, "end": self.end}
+
+     @property
+     def shape(self) -> tuple[int, ...]:
+         """Return the shape of the dataset."""
+         return (len(self),) + self.forward.shape[1:]
@@ -0,0 +1,141 @@
+ # (C) Copyright 2025 Anemoi contributors.
+ #
+ # This software is licensed under the terms of the Apache Licence Version 2.0
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ #
+ # In applying this licence, ECMWF does not waive the privileges and immunities
+ # granted to it by virtue of its status as an intergovernmental organisation
+ # nor does it submit to any jurisdiction.
+
+
+ import logging
+ from functools import cached_property
+ from typing import Any
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from anemoi.datasets.data.indexing import expand_list_indexing
+
+ from .dataset import Dataset
+ from .dataset import FullIndex
+ from .debug import Node
+ from .debug import debug_indexing
+ from .forwards import Forwards
+
+ LOG = logging.getLogger(__name__)
+
+
+ class RollingAverage(Forwards):
+     """A class to represent a dataset with interpolated frequency."""
+
+     def __init__(self, dataset: Dataset, window: str | tuple[int, int, str]) -> None:
+         """Initialize the RollingAverage class.
+
+         Parameters
+         ----------
+         dataset : Dataset
+             The dataset to be averaged with a rolling window.
+         window : (int, int, str)
+             The rolling average window (start, end, 'freq').
+             'freq' means the window is in number of time steps in the dataset.
+             Both start and end are inclusive, i.e. window = (-2, 2, 'freq') means a window of 5 time steps.
+             For now, only 'freq' is supported, in the future other units may be supported.
+             Windows such as "[-2h, +2h]" are not supported yet.
+         """
+         super().__init__(dataset)
+         if not (isinstance(window, (list, tuple)) and len(window) == 3):
+             raise ValueError(f"Window must be (int, int, str), got {window}")
+         if not isinstance(window[0], int) or not isinstance(window[1], int) or not isinstance(window[2], str):
+             raise ValueError(f"Window must be (int, int, str), got {window}")
+         if window[2] not in ["freq", "frequency"]:
+             raise NotImplementedError(f"Window must be (int, int, 'freq'), got {window}")
+
+         # window = (0, 0, 'freq') means no change
+         self.i_start = -window[0]
+         self.i_end = window[1] + 1
+         if self.i_start <= 0:
+             raise ValueError(f"Window start must be negative, got {window}")
+         if self.i_end <= 0:
+             raise ValueError(f"Window end must be positive, got {window}")
+
+         self.window_str = f"-{self.i_start}-to-{self.i_end}"
+
+     @property
+     def shape(self):
+         shape = list(self.forward.shape)
+         shape[0] = len(self)
+         return tuple(shape)
+
+     @debug_indexing
+     @expand_list_indexing
+     def __getitem__(self, n: FullIndex) -> NDArray[Any]:
+         def f(array):
+             return np.nanmean(array, axis=0)
+
+         if isinstance(n, slice):
+             n = (n,)
+
+         if isinstance(n, tuple):
+             first = n[0]
+             if len(n) > 1:
+                 rest = n[1:]
+             else:
+                 rest = ()
+
+             if isinstance(first, int):
+                 slice_ = slice(first, first + self.i_start + self.i_end)
+                 data = self.forward[(slice_,) + rest]
+                 return f(data)
+
+             if isinstance(first, slice):
+                 first = list(range(first.start or 0, first.stop or len(self), first.step or 1))
+
+             if isinstance(first, (list, tuple)):
+                 first = [i if i >= 0 else len(self) + i for i in first]
+                 if any(i >= len(self) for i in first):
+                     raise IndexError(f"Index out of range: {first}")
+                 slices = [slice(i, i + self.i_start + self.i_end) for i in first]
+                 data = [self.forward[(slice_,) + rest] for slice_ in slices]
+                 res = [f(d) for d in data]
+                 return np.array(res)
+
+             assert False, f"Expected int, slice, list or tuple as first element of tuple, got {type(first)}"
+
+         assert isinstance(n, int), f"Expected int, slice, tuple, got {type(n)}"
+
+         if n < 0:
+             n = len(self) + n
+         if n >= len(self):
+             raise IndexError(f"Index out of range: {n}")
+
+         slice_ = slice(n, n + self.i_start + self.i_end)
+         data = self.forward[slice_]
+         return f(data)
+
+     def __len__(self) -> int:
+         return len(self.forward) - (self.i_end + self.i_start - 1)
+
+     @cached_property
+     def dates(self) -> NDArray[np.datetime64]:
+         """Get the interpolated dates."""
+         dates = self.forward.dates
+         return dates[self.i_start : len(dates) - self.i_end + 1]
+
+     def tree(self) -> Node:
+         return Node(self, [self.forward.tree()], window=self.window_str)
+
+     @cached_property
+     def missing(self) -> set[int]:
+         """Get the missing data indices."""
+         result = []
+
+         for i in self.forward.missing:
+             for j in range(0, self.i_end + self.i_start):
+                 result.append(i + j)
+
+         result = {x for x in result if x < self._len}
+         return result
+
+     def forwards_subclass_metadata_specific(self) -> dict[str, Any]:
+         return {}
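A hedged usage sketch for the new option (it is wired into `Dataset._subset` in the dataset.py hunk earlier; the dataset name is a placeholder):

from anemoi.datasets import open_dataset

# window=(-2, 2, "freq") averages each sample with the two preceding and the two
# following time steps (five steps in total, NaN-aware via nanmean), so the result
# is four samples shorter than the source and the dates are trimmed accordingly.
ds = open_dataset("dataset-name", rolling_average=(-2, 2, "freq"))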
@@ -85,22 +85,20 @@ class S3Store(ReadOnlyStore):
      options using the anemoi configs.
      """
 
-     def __init__(self, url: str, region: str | None = None) -> None:
-         """Initialize the S3Store with a URL and optional region."""
-         from anemoi.utils.remote.s3 import s3_client
+     def __init__(self, url: str) -> None:
+         """Initialize the S3Store with a URL."""
 
-         _, _, self.bucket, self.key = url.split("/", 3)
-         self.s3 = s3_client(self.bucket, region=region)
+         self.url = url
 
      def __getitem__(self, key: str) -> bytes:
          """Retrieve an item from the store."""
+         from anemoi.utils.remote.s3 import get_object
+
          try:
-             response = self.s3.get_object(Bucket=self.bucket, Key=self.key + "/" + key)
-         except self.s3.exceptions.NoSuchKey:
+             return get_object(os.path.join(self.url, key))
+         except FileNotFoundError:
              raise KeyError(key)
 
-         return response["Body"].read()
-
 
  class DebugStore(ReadOnlyStore):
      """A store to debug the zarr loading."""
@@ -58,6 +58,8 @@ def extend(x: str | list[Any] | tuple[Any, ...]) -> Iterator[datetime.datetime]:
  class DatesProvider:
      """Base class for date generation.
 
+     Examples
+     --------
      >>> DatesProvider.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-02 00:00", "frequency": "1d"}).values
      [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 2, 0, 0)]
 
@@ -0,0 +1,76 @@
+ # (C) Copyright 2025 Anemoi contributors.
+ #
+ # This software is licensed under the terms of the Apache Licence Version 2.0
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ #
+ # In applying this licence, ECMWF does not waive the privileges and immunities
+ # granted to it by virtue of its status as an intergovernmental organisation
+ # nor does it submit to any jurisdiction.
+
+ import datetime
+ import io
+ import logging
+
+ import ruamel.yaml
+
+ LOG = logging.getLogger(__name__)
+
+
+ def represent_date(dumper, data):
+
+     if isinstance(data, datetime.datetime):
+         if data.tzinfo is None:
+             data = data.replace(tzinfo=datetime.timezone.utc)
+         data = data.astimezone(datetime.timezone.utc)
+         iso_str = data.replace(tzinfo=None).isoformat(timespec="seconds") + "Z"
+     else:
+         iso_str = data.isoformat()
+
+     return dumper.represent_scalar("tag:yaml.org,2002:timestamp", iso_str)
+
+
+ # --- Represent multiline strings with | style ---
+ def represent_multiline_str(dumper, data):
+     if "\n" in data:
+         return dumper.represent_scalar("tag:yaml.org,2002:str", data.strip(), style="|")
+     return dumper.represent_scalar("tag:yaml.org,2002:str", data)
+
+
+ # --- Represent short lists inline (flow style) ---
+ def represent_inline_list(dumper, data):
+
+     if not all(isinstance(i, (str, int, float, bool, type(None))) for i in data):
+         return dumper.represent_sequence("tag:yaml.org,2002:seq", data)
+
+     return dumper.represent_sequence("tag:yaml.org,2002:seq", data, flow_style=True)
+
+
+ def yaml_dump(obj, order=None, stream=None, **kwargs):
+
+     if order:
+
+         def _ordering(k):
+             return order.index(k) if k in order else len(order)
+
+         obj = {k: v for k, v in sorted(obj.items(), key=lambda item: _ordering(item[0]))}
+
+     yaml = ruamel.yaml.YAML()
+     yaml.width = 120  # wrap long flow sequences
+
+     yaml.Representer.add_representer(datetime.date, represent_date)
+     yaml.Representer.add_representer(datetime.datetime, represent_date)
+     yaml.Representer.add_representer(str, represent_multiline_str)
+     yaml.Representer.add_representer(list, represent_inline_list)
+
+     data = ruamel.yaml.comments.CommentedMap()
+     for i, (k, v) in enumerate(obj.items()):
+         data[k] = v
+         if i > 0:
+             data.yaml_set_comment_before_after_key(key=k, before="\n")
+
+     if stream:
+         yaml.dump(data, stream=stream, **kwargs)
+
+     stream = io.StringIO()
+     yaml.dump(data, stream=stream, **kwargs)
+     return stream.getvalue()
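A short usage sketch of the new helper; the recipe dict is made up, only the `yaml_dump` signature and the representers come from the file above.

import datetime

from anemoi.datasets.dumper import yaml_dump

recipe = {
    "input": {"mars": {"param": ["2t", "msl"]}},
    "name": "example",
    "dates": {"start": datetime.datetime(2023, 1, 1), "end": datetime.datetime(2023, 1, 2)},
}
print(yaml_dump(recipe, order=["name", "dates", "input"]))
# Keys come out in the requested order, datetimes as UTC timestamps ending in "Z",
# and short scalar lists such as ["2t", "msl"] in inline (flow) style.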