anemoi-datasets 0.5.24__py3-none-any.whl → 0.5.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/finalise-additions.py +2 -1
  3. anemoi/datasets/commands/finalise.py +2 -1
  4. anemoi/datasets/commands/grib-index.py +1 -1
  5. anemoi/datasets/commands/init-additions.py +2 -1
  6. anemoi/datasets/commands/load-additions.py +2 -1
  7. anemoi/datasets/commands/load.py +2 -1
  8. anemoi/datasets/create/__init__.py +24 -33
  9. anemoi/datasets/create/filter.py +22 -24
  10. anemoi/datasets/create/input/__init__.py +0 -20
  11. anemoi/datasets/create/input/step.py +2 -16
  12. anemoi/datasets/create/sources/accumulations.py +7 -6
  13. anemoi/datasets/create/sources/planetary_computer.py +44 -0
  14. anemoi/datasets/create/sources/xarray_support/__init__.py +6 -22
  15. anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -0
  16. anemoi/datasets/create/sources/xarray_support/field.py +1 -4
  17. anemoi/datasets/create/sources/xarray_support/flavour.py +44 -6
  18. anemoi/datasets/create/sources/xarray_support/patch.py +44 -1
  19. anemoi/datasets/create/sources/xarray_support/variable.py +6 -2
  20. anemoi/datasets/data/complement.py +44 -10
  21. anemoi/datasets/data/dataset.py +29 -0
  22. anemoi/datasets/data/forwards.py +8 -2
  23. anemoi/datasets/data/misc.py +74 -16
  24. anemoi/datasets/data/observations/__init__.py +316 -0
  25. anemoi/datasets/data/observations/legacy_obs_dataset.py +200 -0
  26. anemoi/datasets/data/observations/multi.py +64 -0
  27. anemoi/datasets/data/padded.py +227 -0
  28. anemoi/datasets/data/records/__init__.py +442 -0
  29. anemoi/datasets/data/records/backends/__init__.py +157 -0
  30. anemoi/datasets/data/stores.py +7 -56
  31. anemoi/datasets/data/subset.py +5 -0
  32. anemoi/datasets/grids.py +6 -3
  33. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/METADATA +3 -2
  34. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/RECORD +38 -51
  35. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/WHEEL +1 -1
  36. anemoi/datasets/create/filters/__init__.py +0 -33
  37. anemoi/datasets/create/filters/empty.py +0 -37
  38. anemoi/datasets/create/filters/legacy.py +0 -93
  39. anemoi/datasets/create/filters/noop.py +0 -37
  40. anemoi/datasets/create/filters/orog_to_z.py +0 -58
  41. anemoi/datasets/create/filters/pressure_level_relative_humidity_to_specific_humidity.py +0 -83
  42. anemoi/datasets/create/filters/pressure_level_specific_humidity_to_relative_humidity.py +0 -84
  43. anemoi/datasets/create/filters/rename.py +0 -205
  44. anemoi/datasets/create/filters/rotate_winds.py +0 -105
  45. anemoi/datasets/create/filters/single_level_dewpoint_to_relative_humidity.py +0 -78
  46. anemoi/datasets/create/filters/single_level_relative_humidity_to_dewpoint.py +0 -84
  47. anemoi/datasets/create/filters/single_level_relative_humidity_to_specific_humidity.py +0 -163
  48. anemoi/datasets/create/filters/single_level_specific_humidity_to_relative_humidity.py +0 -451
  49. anemoi/datasets/create/filters/speeddir_to_uv.py +0 -95
  50. anemoi/datasets/create/filters/sum.py +0 -68
  51. anemoi/datasets/create/filters/transform.py +0 -51
  52. anemoi/datasets/create/filters/unrotate_winds.py +0 -105
  53. anemoi/datasets/create/filters/uv_to_speeddir.py +0 -94
  54. anemoi/datasets/create/filters/wz_to_w.py +0 -98
  55. anemoi/datasets/create/testing.py +0 -76
  56. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/entry_points.txt +0 -0
  57. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/licenses/LICENSE +0 -0
  58. {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,316 @@
1
+ # (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+ import datetime
9
+ import logging
10
+ import os
11
+ from functools import cached_property
12
+ from typing import Any
13
+ from typing import Dict
14
+ from typing import Tuple
15
+
16
+ import numpy as np
17
+ from anemoi.utils.dates import frequency_to_timedelta
18
+
19
+ from anemoi.datasets.data.dataset import Dataset
20
+
21
+ from ..debug import Node
22
+
23
+ LOG = logging.getLogger(__name__)
24
+
25
+
26
def round_datetime(dt, frequency, up=True):
    """Round *dt* to a whole multiple of *frequency* hours.

    Parameters
    ----------
    dt : datetime.datetime
        Datetime to round. Minutes, seconds and microseconds are discarded.
    frequency : int
        Rounding step in hours (e.g. 6 snaps to 00, 06, 12, 18).
    up : bool, optional
        If True (default) round up to the next multiple when *dt* is not
        already aligned; if False round down. (Previously this parameter
        was accepted but ignored and the function always rounded up.)

    Returns
    -------
    datetime.datetime
        The rounded datetime; unchanged (apart from zeroed sub-hour fields)
        when the hour is already a multiple of *frequency*.
    """
    dt = dt.replace(minute=0, second=0, microsecond=0)
    hour = dt.hour
    if hour % frequency != 0:
        # Snap down to the previous multiple of `frequency`...
        dt = dt.replace(hour=(hour // frequency) * frequency)
        if up:
            # ...then step forward once to round up.
            dt = dt + datetime.timedelta(hours=frequency)
    return dt
33
+
34
+
35
def make_dates(start, end, frequency):
    """Build the inclusive list of dates from *start* to *end*, every *frequency*.

    Returns a numpy array of dtype ``datetime64[s]``.
    """
    # Normalise numpy datetimes to plain datetime objects so the timedelta
    # arithmetic below behaves as expected.
    start = start.astype(datetime.datetime) if isinstance(start, np.datetime64) else start
    end = end.astype(datetime.datetime) if isinstance(end, np.datetime64) else end

    result = []
    step = start
    while step <= end:
        result.append(np.datetime64(step, "s"))
        step = step + frequency

    return np.array(result, dtype="datetime64[s]")
50
+
51
+
52
class ObservationsBase(Dataset):
    """Base class for observation datasets.

    Observation data is ragged: the number of observations changes from one
    time window to the next, so the trailing axis of ``shape`` is the string
    ``"dynamic"`` rather than a fixed size. Items are indexed by date only
    (one item per date); selecting within an item is done with a second
    subscript, e.g. ``ds[i][a, b, c]``.
    """

    # Observations are point data; there is no grid resolution.
    resolution = None

    @cached_property
    def shape(self):
        # The last axis is "dynamic": the number of observations differs
        # between items.
        return (len(self.dates), len(self.variables), "dynamic")

    def empty_item(self):
        """Return an empty ``(n_variables, 0)`` float32 array (no observations)."""
        return np.full(self.shape[1:-1] + (0,), 0.0, dtype=np.float32)

    def metadata(self):
        # Placeholder: observation datasets do not carry metadata yet.
        # (Fixed typo "currenty" -> "currently" in the message.)
        return dict(observations_datasets="obs datasets currently have no metadata")

    def _check(self):
        pass

    def __len__(self):
        return len(self.dates)

    def tree(self):
        return Node(self)

    def __getitem__(self, i):
        """Return the data for date index *i*. Only ``int`` indices are supported."""
        if isinstance(i, int):
            return self.getitem(i)

        # The following would probably work but is likely to change in the future
        # if isinstance(i, slice):
        #     return [self.getitem(j) for j in range(int(slice.start), int(slice.stop))]
        # if isinstance(i, list):
        #     return [self.getitem(j) for j in i]

        raise ValueError(
            (
                f"Expected int, got {i} of type {type(i)}. Only int is supported to index "
                "observations datasets. Please use a second [] to select part of the data [i][a,b,c]"
            )
        )

    @property
    def variables(self):
        # Subclasses must provide the list of variable names.
        raise NotImplementedError()

    def collect_input_sources(self):
        LOG.warning("collect_input_sources method is not implemented")
        return []

    def constant_fields(self):
        LOG.warning("constant_fields method is not implemented")
        return []

    @property
    def dates(self):
        # Set by subclasses: a numpy array of datetime64[s].
        return self._dates

    @property
    def dtype(self):
        return np.float32

    @property
    def field_shape(self):
        return self.shape[1:]

    @property
    def frequency(self):
        assert isinstance(self._frequency, datetime.timedelta), f"Expected timedelta, got {type(self._frequency)}"
        return self._frequency

    @property
    def latitudes(self):
        raise NotImplementedError("latitudes property is not implemented")

    @property
    def longitudes(self):
        raise NotImplementedError("longitudes property is not implemented")

    @property
    def missing(self):
        # No missing-dates tracking for observations.
        return []

    def statistics_tendencies(self):
        raise NotImplementedError("statistics_tendencies method is not implemented")

    def variables_metadata(self):
        raise NotImplementedError("variables_metadata method is not implemented")
137
+
138
+
139
class ObservationsZarr(ObservationsBase):
    """Observations dataset backed by a zarr store (legacy obs layout).

    Wraps :class:`ObsDataset`, exposing one item per date; each item holds
    the observations of the time window ending at that date.
    """

    def __init__(self, dataset, frequency=None, window=None):
        """Open a zarr observations dataset.

        Parameters
        ----------
        dataset : str or zarr.hierarchy.Group
            Path or name of the zarr store (resolved through ``zarr_lookup``).
        frequency : str, optional
            Time frequency (e.g. ``"6h"``); defaults to the dataset's
            ``frequency`` attribute, then to ``"6h"``.
        window : tuple, optional
            Window in hours relative to each date; only ``(-frequency, 0)``
            is supported for now.
        """
        import zarr

        if isinstance(dataset, zarr.hierarchy.Group):
            dataset = dataset._store.path

        from ..stores import zarr_lookup

        dataset = zarr_lookup(dataset)
        self.path = dataset
        assert self._probe_attributes["is_observations"], f"Expected observations dataset, got {self.path}"

        if frequency is None:
            frequency = self._probe_attributes.get("frequency")
            # LOG.warning(f"Frequency not provided, using the one from the dataset: {frequency}")
        if frequency is None:
            frequency = "6h"
            # LOG.warning(f"Frequency not provided in the dataset, using the default : {frequency}")
        self._frequency = frequency_to_timedelta(frequency)
        assert self.frequency.total_seconds() % 3600 == 0, f"Expected multiple of 3600, got {self.frequency}"
        if self.frequency.total_seconds() != 6 * 3600:
            LOG.warning("Frequency is not 6h, this has not been tested, behaviour is unknown")

        frequency_hours = int(self.frequency.total_seconds() // 3600)
        assert isinstance(frequency_hours, int), f"Expected int, got {type(frequency_hours)}"

        if window is None:
            window = (-frequency_hours, 0)
        if window != (-frequency_hours, 0):
            raise ValueError("For now, only window = (- frequency, 0) are supported")

        self.window = window

        start, end = self._probe_attributes["start_date"], self._probe_attributes["end_date"]
        start, end = datetime.datetime.fromisoformat(start), datetime.datetime.fromisoformat(end)
        start, end = round_datetime(start, frequency_hours), round_datetime(end, frequency_hours)

        # The first date is one frequency after the (rounded) start: each item
        # covers the window that ends at its date.
        self._dates = make_dates(start + self.frequency, end, self.frequency)

        first_window_begin = start.strftime("%Y%m%d%H%M%S")
        first_window_begin = int(first_window_begin)
        # last_window_end must be the end of the time window of the last item
        last_window_end = int(end.strftime("%Y%m%d%H%M%S"))

        from .legacy_obs_dataset import ObsDataset

        args = [self.path, first_window_begin, last_window_end]
        kwargs = dict(
            len_hrs=frequency_hours,  # length the time windows, i.e. the time span of one item
            step_hrs=frequency_hours,  # frequency of the dataset, i.e. the time shift between two items
        )
        self.forward = ObsDataset(*args, **kwargs)

        # BUGFIX: these two assertion messages previously reported the wrong
        # attribute (len_hrs and step_hrs were swapped in the f-strings).
        assert frequency_hours == self.forward.step_hrs, f"Expected {frequency_hours}, got {self.forward.step_hrs}"
        assert frequency_hours == self.forward.len_hrs, f"Expected {frequency_hours}, got {self.forward.len_hrs}"

        if len(self.forward) != len(self.dates):
            raise ValueError(
                (
                    f"Dates are not consistent with the number of items in the dataset. "
                    f"The dataset contains {len(self.forward)} time windows. "
                    f"This is not compatible with the "
                    f"{len(self.dates)} requested dates with frequency={frequency_hours}"
                    f"{self.dates[0]}, {self.dates[1]}, ..., {self.dates[-2]}, {self.dates[-1]} "
                )
            )

    @property
    def source(self):
        return self.path

    def get_dataset_names(self):
        """Return the dataset base name, without any ``.zarr`` extension."""
        name = os.path.basename(self.path)
        if name.endswith(".zarr"):
            name = name[:-5]
        return [name]

    @cached_property
    def _probe_attributes(self):
        # Read the attributes of the "data" array; cached so the probe
        # happens only once.
        import zarr

        z = zarr.open(self.path, mode="r")
        return dict(z.data.attrs)

    def get_aux(self, i):
        """Return ``(latitudes, longitudes, timedeltas)`` for item *i*.

        ``timedeltas`` are the offsets of each observation time from the
        item's reference date.
        """
        data = self.forward[i]

        latitudes = data[:, self.name_to_index["__latitudes"]].numpy()
        longitudes = data[:, self.name_to_index["__longitudes"]].numpy()

        reference = self.dates[i]
        times = self.forward.get_dates(i)
        if str(times.dtype) != "datetime64[s]":
            LOG.warning(f"Expected np.datetime64[s], got {times.dtype}. ")
            times = times.astype("datetime64[s]")
        assert str(reference.dtype) == "datetime64[s]", f"Expected np.datetime64[s], got {type(reference)}"
        timedeltas = times - reference

        assert latitudes.shape == longitudes.shape, f"Expected {latitudes.shape}, got {longitudes.shape}"
        assert timedeltas.shape == latitudes.shape, f"Expected {timedeltas.shape}, got {latitudes.shape}"

        return latitudes, longitudes, timedeltas

    def getitem(self, i):
        """Return item *i* as a float32 array of shape ``(n_variables, n_obs)``."""
        data = self.forward[i]

        data = data.numpy().astype(np.float32)
        assert len(data.shape) == 2, f"Expected 2D array, got {data.shape}"
        data = data.T

        if not data.size:
            data = self.empty_item()
        assert (
            data.shape[0] == self.shape[1]
        ), f"Data shape {data.shape} does not match {self.shape} : {data.shape[0]} != {self.shape[1]}"
        return data

    @cached_property
    def variables(self):
        """Variable names; lat/lon columns become ``__latitudes``/``__longitudes``."""
        colnames = self.forward.colnames
        variables = []
        for n in colnames:
            if n.startswith("obsvalue_"):
                n = n.replace("obsvalue_", "")
            if n == "latitude" or n == "lat":
                # BUGFIX: check for the name actually appended ("__latitudes");
                # the previous check ("latitudes") could never trigger.
                assert "__latitudes" not in variables, f"Duplicate latitudes found in {variables}"
                variables.append("__latitudes")
                continue
            if n == "longitude" or n == "lon":
                # BUGFIX: same as above for "__longitudes".
                assert "__longitudes" not in variables, f"Duplicate longitudes found in {variables}"
                variables.append("__longitudes")
                continue
            assert not n.startswith("__"), f"Invalid name {n} found in {colnames}"
            variables.append(n)
        return variables

    @property
    def name_to_index(self):
        return {n: i for i, n in enumerate(self.variables)}

    @property
    def statistics(self):
        """Per-variable statistics read from the underlying store attributes."""
        mean = self.forward.properties["means"]
        mean = np.array(mean, dtype=np.float32)

        var = self.forward.properties["vars"]
        var = np.array(var, dtype=np.float32)
        stdev = np.sqrt(var)

        minimum = np.array(self.forward.z.data.attrs["mins"], dtype=np.float32)
        maximum = np.array(self.forward.z.data.attrs["maxs"], dtype=np.float32)

        assert isinstance(mean, np.ndarray), f"Expected np.ndarray, got {type(mean)}"
        assert isinstance(stdev, np.ndarray), f"Expected np.ndarray, got {type(stdev)}"
        assert isinstance(minimum, np.ndarray), f"Expected np.ndarray, got {type(minimum)}"
        assert isinstance(maximum, np.ndarray), f"Expected np.ndarray, got {type(maximum)}"
        return dict(mean=mean, stdev=stdev, minimum=minimum, maximum=maximum)

    def tree(self):
        return Node(
            self,
            [],
            path=self.path,
            frequency=self.frequency,
        )

    def __repr__(self):
        return f"Observations({os.path.basename(self.path)}, {self.dates[0]};{self.dates[-1]}, {len(self)})"
308
+
309
+
310
def observations_factory(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> ObservationsBase:
    """Create an observations dataset from ``open_dataset``-style arguments.

    The ``observations`` entry of *kwargs* is either a dataset specification
    (path or name) or a dict of keyword arguments for
    :class:`ObservationsZarr`. The remaining *kwargs* are applied through
    ``_subset`` (e.g. start/end selection).
    """
    spec = kwargs.pop("observations")

    if isinstance(spec, dict):
        dataset = ObservationsZarr(**spec)
    else:
        dataset = ObservationsZarr(dataset=spec)
    return dataset._subset(**kwargs)
@@ -0,0 +1,200 @@
1
+ # (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+ import datetime
9
+ import logging
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ import zarr
15
+ from torch.utils.data import Dataset
16
+
17
+ LOG = logging.getLogger(__name__)
18
+
19
+
20
class ObsDataset(Dataset):
    """Torch ``Dataset`` over a legacy observations zarr store.

    The store contains a 2D ``data`` array (rows are individual observations,
    columns are variables), a parallel ``dates`` array, and an hourly index
    array (``idx_197001010000_1``) giving, for each hour since 1970-01-01,
    a row position in ``data``. Samples are time windows of ``len_hrs``
    hours, spaced ``step_hrs`` hours apart.
    """

    def __init__(
        self,
        filename: str,
        start: int,
        end: int,
        len_hrs: int,
        step_hrs: int = None,
        select: list[str] = None,
        drop: list[str] = None,
    ) -> None:
        """Open the store and build the per-sample row index.

        Parameters
        ----------
        filename : str
            Path of the zarr store.
        start, end : int
            Requested time range as integer timestamps, parsed with the
            "%Y%m%d%H%M%S" format. NOTE(review): the assertions in
            ``_setup_sample_index`` describe them as "yyyymmddhhmm" —
            confirm the expected number of digits with callers.
        len_hrs : int
            Length of one sample window, in hours.
        step_hrs : int, optional
            Shift between consecutive windows; defaults to ``len_hrs``.
        select, drop : list[str], optional
            Columns to keep / remove (applied in that order).
        """

        self.filename = filename
        self.z = zarr.open(filename, mode="r")
        self.data = self.z["data"]
        self.dt = self.z["dates"]  # datetime only
        self.hrly_index = self.z["idx_197001010000_1"]
        self.colnames = self.data.attrs["colnames"]
        # By default every column is selected.
        self.selected_colnames = self.colnames
        self.selected_cols_idx = np.arange(len(self.colnames))
        self.len_hrs = len_hrs
        # A missing/zero step defaults to the window length (contiguous windows).
        self.step_hrs = step_hrs if step_hrs else len_hrs

        # Create index for samples
        self._setup_sample_index(start, end, self.len_hrs, self.step_hrs)

        self._load_properties()

        if select:
            self.select(select)

        if drop:
            self.drop(drop)

    def __getitem__(
        self,
        idx: int,
    ) -> torch.tensor:
        """Return the rows of sample *idx*, restricted to the selected columns."""

        start_row = self.indices_start[idx]
        end_row = self.indices_end[idx]

        # Orthogonal indexing: row range x selected column indices.
        data = self.data.oindex[start_row:end_row, self.selected_cols_idx]

        return torch.from_numpy(data)

    def __len__(self) -> int:
        """Number of samples (time windows)."""

        return len(self.indices_start)

    def get_dates(
        self,
        idx: int,
    ) -> np.ndarray:
        """Return the observation datetimes of sample *idx* as datetime64[s]."""

        start_row = self.indices_start[idx]
        end_row = self.indices_end[idx]
        dates = self.dt.oindex[start_row:end_row]

        # The dates array is 2D; only the first column is the datetime.
        assert len(dates.shape) == 2, dates.shape
        dates = dates[:, 0]

        if len(dates) and dates[0].dtype != np.dtype("datetime64[s]"):
            dates = dates.astype("datetime64[s]")

        return dates

    def get_df(self, idx: int) -> pd.DataFrame:
        """Convenience function to return data for sample idx packaged in a pandas DataFrame"""

        d = self.__getitem__(idx)

        df = pd.DataFrame(data=d, columns=[self.colnames[i] for i in self.selected_cols_idx])

        start_row = self.indices_start[idx]
        end_row = self.indices_end[idx]

        dts = self.dt[start_row:end_row, :]
        df["datetime"] = dts

        return df

    def select(self, cols_list: list[str]) -> None:
        """Allow user to specify which columns they want to access.
        Get functions only returned for these specified columns.
        """
        self.selected_colnames = cols_list
        self.selected_cols_idx = np.array([self.colnames.index(item) for item in cols_list])

    def drop(self, cols_to_drop: list[str]) -> None:
        """Allow user to drop specific columns from the dataset.
        Get functions no longer return data for these columns after being set.
        """
        mask = [name not in cols_to_drop for name in self.selected_colnames]

        self.selected_colnames = [name for name, keep in zip(self.selected_colnames, mask) if keep]
        self.selected_cols_idx = self.selected_cols_idx[mask]

    def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]:
        """Returns a tuple of datetime objects describing the start and end times of the sample at position idx."""

        # Support negative indices, as for regular sequences.
        if idx < 0:
            idx = len(self) + idx

        # Window start is exclusive (hence the +1 second); end is clipped to
        # the requested dataset end.
        time_start = self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs), seconds=1)
        time_end = min(
            self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs + self.len_hrs)),
            self.end_dt,
        )

        return (np.datetime64(time_start), np.datetime64(time_end))

    def first_sample_with_data(self) -> int:
        """Returns the position of the first sample which contains data."""
        return int(np.nonzero(self.indices_end)[0][0]) if self.indices_end.max() > 0 else None

    def last_sample_with_data(self) -> int:
        """Returns the position of the last sample which contains data."""
        if self.indices_end.max() == 0:
            last_sample = None
        else:
            # Last position where the cumulative end index still increases.
            last_sample = int(np.where(np.diff(np.append(self.indices_end, self.indices_end[-1])) > 0)[0][-1] + 1)

        return last_sample

    def _setup_sample_index(self, start: int, end: int, len_hrs: int, step_hrs: int) -> None:
        """Dataset is divided into samples;
        - each n_hours long
        - sample 0 starts at start (yyyymmddhhmm)
        - index array has one entry for each sample; contains the index of the first row
          containing data for that sample
        """

        # If the optional obsdata package is present, sanity-check that its
        # configured index base matches the hard-coded one used here.
        try:
            from obsdata.config import config

            assert config.base_index_yyyymmddhhmm == 197001010000, "base_index_yyyymmddhhmm must be 197001010000"
        except ImportError:
            pass
        base_yyyymmddhhmm = 197001010000

        assert start > base_yyyymmddhhmm, (
            f"Abort: ObsDataset sample start (yyyymmddhhmm) must be greater than {base_yyyymmddhhmm}\n"
            f"       Current value: {start}"
        )

        format_str = "%Y%m%d%H%M%S"
        base_dt = datetime.datetime.strptime(str(base_yyyymmddhhmm), format_str)
        self.start_dt = datetime.datetime.strptime(str(start), format_str)
        self.end_dt = datetime.datetime.strptime(str(end), format_str)

        # Calculate hours since the base date for the requested dataset ranges
        diff_in_hours_start = int((self.start_dt - base_dt).total_seconds() // 3600)
        diff_in_hours_end = int((self.end_dt - base_dt).total_seconds() // 3600)

        # Find elements that need to be extracted from the hourly index
        # + ensuring that the dataset respects the requested end-hour even if it is mid-way through a sample
        sample_starts = np.arange(diff_in_hours_start, diff_in_hours_end, step_hrs)
        sample_ends = np.minimum(sample_starts + len_hrs, diff_in_hours_end)

        # Initialize local index arrays
        self.indices_start = np.zeros(sample_starts.shape, dtype=int)
        self.indices_end = np.zeros(self.indices_start.shape, dtype=int)

        # Windows beyond the hourly index keep the zero default (no data).
        max_hrly_index = self.hrly_index.shape[0] - 1
        valid_start_mask = sample_starts <= max_hrly_index
        valid_end_mask = (sample_ends > 0) & (sample_ends <= max_hrly_index)

        # Copy elements from the hrly_index into the local index
        self.indices_start[valid_start_mask] = self.hrly_index[sample_starts[valid_start_mask]]
        self.indices_end[valid_end_mask] = np.maximum(self.hrly_index[sample_ends[valid_end_mask]], 0)

    def _load_properties(self) -> None:
        """Cache store-level attributes (statistics and identifiers)."""

        self.properties = {}

        self.properties["means"] = self.data.attrs["means"]
        self.properties["vars"] = self.data.attrs["vars"]
        self.properties["data_idxs"] = self.data.attrs["data_idxs"]
        self.properties["obs_id"] = self.data.attrs["obs_id"]
@@ -0,0 +1,64 @@
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+ import logging
11
+ import os
12
+
13
+ from anemoi.datasets.data import open_dataset
14
+
15
+ LOG = logging.getLogger(__name__)
16
+
17
+
18
class LegacyDatasets:
    """A named collection of datasets opened over one common date range.

    Observation datasets (paths starting with ``observations-``) are routed
    through the observations loader; everything else is opened directly.
    All member datasets must share the same dates and frequency.
    """

    def __init__(self, paths, start=None, end=None, **kwargs):
        self.paths = paths

        if not start or not end:
            # NOTE(review): kwargs are only used for this probing pass; they
            # are not forwarded when the datasets are opened below — confirm
            # this is intended.
            print(
                "❌❌ Warning: start and end not provided, using the minima first and maximal last dates of the datasets"
            )
            probed = [self._open_dataset(path, **kwargs) for path in paths]
            start = min([d.dates[0] for d in probed])
            end = max([d.dates[-1] for d in probed])

        self._datasets = {}
        for path in paths:
            key = os.path.basename(path).split(".")[0]
            self._datasets[key] = self._open_dataset(path, start=start, end=end, padding="empty")

        members = list(self._datasets.values())
        reference = members[0]
        for dataset in members:
            same_range = dataset.dates[0] == reference.dates[0] and dataset.dates[-1] == reference.dates[-1]
            if not same_range:
                raise ValueError("Datasets have different start and end times")
            if dataset.frequency != reference.frequency:
                raise ValueError("Datasets have different frequencies")

        self._keys = self._datasets.keys

        self._first = reference

    def _open_dataset(self, p, **kwargs):
        """Open *p*, dispatching observation datasets to the observations loader."""
        if not p.startswith("observations-"):
            print("❗ Opening non-observations dataset:", p)
            return open_dataset(p, **kwargs)
        return open_dataset(observations=p, **kwargs)

    def items(self):
        """Iterate over (name, dataset) pairs."""
        return self._datasets.items()

    @property
    def dates(self):
        # All members share the same dates; expose the first one's.
        return self._first.dates

    def __len__(self):
        return len(self._first)

    def __getitem__(self, i):
        """Return item *i* of every member, keyed by dataset name."""
        return {name: dataset[i] for name, dataset in self._datasets.items()}