anemoi-datasets 0.5.24__py3-none-any.whl → 0.5.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/finalise-additions.py +2 -1
- anemoi/datasets/commands/finalise.py +2 -1
- anemoi/datasets/commands/grib-index.py +1 -1
- anemoi/datasets/commands/init-additions.py +2 -1
- anemoi/datasets/commands/load-additions.py +2 -1
- anemoi/datasets/commands/load.py +2 -1
- anemoi/datasets/create/__init__.py +24 -33
- anemoi/datasets/create/filter.py +22 -24
- anemoi/datasets/create/input/__init__.py +0 -20
- anemoi/datasets/create/input/step.py +2 -16
- anemoi/datasets/create/sources/accumulations.py +7 -6
- anemoi/datasets/create/sources/planetary_computer.py +44 -0
- anemoi/datasets/create/sources/xarray_support/__init__.py +6 -22
- anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -0
- anemoi/datasets/create/sources/xarray_support/field.py +1 -4
- anemoi/datasets/create/sources/xarray_support/flavour.py +44 -6
- anemoi/datasets/create/sources/xarray_support/patch.py +44 -1
- anemoi/datasets/create/sources/xarray_support/variable.py +6 -2
- anemoi/datasets/data/complement.py +44 -10
- anemoi/datasets/data/dataset.py +29 -0
- anemoi/datasets/data/forwards.py +8 -2
- anemoi/datasets/data/misc.py +74 -16
- anemoi/datasets/data/observations/__init__.py +316 -0
- anemoi/datasets/data/observations/legacy_obs_dataset.py +200 -0
- anemoi/datasets/data/observations/multi.py +64 -0
- anemoi/datasets/data/padded.py +227 -0
- anemoi/datasets/data/records/__init__.py +442 -0
- anemoi/datasets/data/records/backends/__init__.py +157 -0
- anemoi/datasets/data/stores.py +7 -56
- anemoi/datasets/data/subset.py +5 -0
- anemoi/datasets/grids.py +6 -3
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/METADATA +3 -2
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/RECORD +38 -51
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/filters/__init__.py +0 -33
- anemoi/datasets/create/filters/empty.py +0 -37
- anemoi/datasets/create/filters/legacy.py +0 -93
- anemoi/datasets/create/filters/noop.py +0 -37
- anemoi/datasets/create/filters/orog_to_z.py +0 -58
- anemoi/datasets/create/filters/pressure_level_relative_humidity_to_specific_humidity.py +0 -83
- anemoi/datasets/create/filters/pressure_level_specific_humidity_to_relative_humidity.py +0 -84
- anemoi/datasets/create/filters/rename.py +0 -205
- anemoi/datasets/create/filters/rotate_winds.py +0 -105
- anemoi/datasets/create/filters/single_level_dewpoint_to_relative_humidity.py +0 -78
- anemoi/datasets/create/filters/single_level_relative_humidity_to_dewpoint.py +0 -84
- anemoi/datasets/create/filters/single_level_relative_humidity_to_specific_humidity.py +0 -163
- anemoi/datasets/create/filters/single_level_specific_humidity_to_relative_humidity.py +0 -451
- anemoi/datasets/create/filters/speeddir_to_uv.py +0 -95
- anemoi/datasets/create/filters/sum.py +0 -68
- anemoi/datasets/create/filters/transform.py +0 -51
- anemoi/datasets/create/filters/unrotate_winds.py +0 -105
- anemoi/datasets/create/filters/uv_to_speeddir.py +0 -94
- anemoi/datasets/create/filters/wz_to_w.py +0 -98
- anemoi/datasets/create/testing.py +0 -76
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/licenses/LICENSE +0 -0
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts.
|
|
2
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
3
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
4
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
5
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
6
|
+
# nor does it submit to any jurisdiction.
|
|
7
|
+
|
|
8
|
+
import datetime
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
from functools import cached_property
|
|
12
|
+
from typing import Any
|
|
13
|
+
from typing import Dict
|
|
14
|
+
from typing import Tuple
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
from anemoi.utils.dates import frequency_to_timedelta
|
|
18
|
+
|
|
19
|
+
from anemoi.datasets.data.dataset import Dataset
|
|
20
|
+
|
|
21
|
+
from ..debug import Node
|
|
22
|
+
|
|
23
|
+
LOG = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def round_datetime(dt, frequency, up=True):
    """Round *dt* to a multiple of *frequency* hours.

    Minutes, seconds and microseconds are always truncated first, so a
    datetime already on an aligned hour is returned unchanged regardless of
    its sub-hour components (e.g. 06:30 with frequency 6 yields 06:00).

    Parameters
    ----------
    dt : datetime.datetime
        Datetime to round.
    frequency : int
        Rounding step, in hours.
    up : bool, optional
        When True (default) an unaligned hour is rounded up to the next
        multiple of *frequency*; when False it is rounded down.
        (Fix: this flag was previously accepted but ignored.)

    Returns
    -------
    datetime.datetime
        The rounded datetime.
    """
    dt = dt.replace(minute=0, second=0, microsecond=0)
    hour = dt.hour
    if hour % frequency != 0:
        # Snap down to the previous multiple of `frequency` hours...
        dt = dt.replace(hour=(hour // frequency) * frequency)
        # ...then step forward one period only when rounding up.
        if up:
            dt = dt + datetime.timedelta(hours=frequency)
    return dt
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def make_dates(start, end, frequency):
    """Build the inclusive list of dates from *start* to *end*.

    Parameters
    ----------
    start, end : datetime.datetime or np.datetime64
        Range boundaries; ``np.datetime64`` values are converted to
        ``datetime.datetime`` first.
    frequency : datetime.timedelta
        Spacing between consecutive dates.

    Returns
    -------
    np.ndarray
        Array of dtype ``datetime64[s]`` (empty when ``start > end``).
    """
    if isinstance(start, np.datetime64):
        start = start.astype(datetime.datetime)
    if isinstance(end, np.datetime64):
        end = end.astype(datetime.datetime)

    steps = []
    current = start
    while current <= end:
        # Convert each step to second precision as it is produced.
        steps.append(np.datetime64(current, "s"))
        current = current + frequency

    return np.array(steps, dtype="datetime64[s]")
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ObservationsBase(Dataset):
    """Common behaviour for observation datasets.

    Unlike gridded datasets, the number of observations varies per date, so
    the last axis of ``shape`` is the string ``"dynamic"``.  Subclasses must
    provide ``_dates``, ``_frequency``, ``variables`` and ``getitem``.
    """

    # Observations have no fixed spatial resolution.
    resolution = None

    @cached_property
    def shape(self):
        # Last axis is "dynamic": the observation count changes per item.
        return (len(self.dates), len(self.variables), "dynamic")

    def empty_item(self):
        # A (n_variables, 0) float32 array, returned when a window has no data.
        return np.full(self.shape[1:-1] + (0,), 0.0, dtype=np.float32)

    def metadata(self):
        # Placeholder: observation datasets do not carry metadata yet.
        return dict(observations_datasets="obs datasets currenty have no metadata")

    def _check(self):
        # No consistency checks for observation datasets.
        pass

    def __len__(self):
        return len(self.dates)

    def tree(self):
        # Leaf node in the open_dataset debug tree.
        return Node(self)

    def __getitem__(self, i):
        """Return the data for date index *i*; only integer indexing is supported."""
        if isinstance(i, int):
            return self.getitem(i)

        # The following might work but is likely to change in the future:
        # if isinstance(i, slice):
        #     return [self.getitem(j) for j in range(int(slice.start), int(slice.stop))]
        # if isinstance(i, list):
        #     return [self.getitem(j) for j in i]

        raise ValueError(
            (
                f"Expected int, got {i} of type {type(i)}. Only int is supported to index "
                "observations datasets. Please use a second [] to select part of the data [i][a,b,c]"
            )
        )

    @property
    def variables(self):
        # Must be provided by subclasses.
        raise NotImplementedError()

    def collect_input_sources(self):
        # Best-effort stub: logs and returns no sources.
        LOG.warning("collect_input_sources method is not implemented")
        return []

    def constant_fields(self):
        # Best-effort stub: logs and returns no constant fields.
        LOG.warning("constant_fields method is not implemented")
        return []

    @property
    def dates(self):
        # Set by subclasses (e.g. in ObservationsZarr.__init__).
        return self._dates

    @property
    def dtype(self):
        return np.float32

    @property
    def field_shape(self):
        # Shape of one item: (n_variables, "dynamic").
        return self.shape[1:]

    @property
    def frequency(self):
        assert isinstance(self._frequency, datetime.timedelta), f"Expected timedelta, got {type(self._frequency)}"
        return self._frequency

    @property
    def latitudes(self):
        # Latitudes are per-item for observations; see subclass get_aux().
        raise NotImplementedError("latitudes property is not implemented")

    @property
    def longitudes(self):
        # Longitudes are per-item for observations; see subclass get_aux().
        raise NotImplementedError("longitudes property is not implemented")

    @property
    def missing(self):
        # Observation datasets report no missing dates.
        return []

    def statistics_tendencies(self):
        raise NotImplementedError("statistics_tendencies method is not implemented")

    def variables_metadata(self):
        raise NotImplementedError("variables_metadata method is not implemented")
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class ObservationsZarr(ObservationsBase):
    """Observations dataset backed by a legacy observations zarr store.

    Wraps :class:`ObsDataset` so each item is a fixed-length time window of
    observations ending at the corresponding date.

    Parameters
    ----------
    dataset : str or zarr.hierarchy.Group
        Path/name of the zarr store (resolved through ``zarr_lookup``) or an
        already-open zarr group.
    frequency : str or None
        Window frequency (e.g. ``"6h"``); defaults to the store attribute,
        then to ``"6h"``. Must be a whole number of hours.
    window : tuple or None
        Time window relative to each date; only ``(-frequency, 0)`` is
        currently supported.
    """

    def __init__(self, dataset, frequency=None, window=None):
        import zarr

        # Accept an already-open zarr group and recover its on-disk path.
        if isinstance(dataset, zarr.hierarchy.Group):
            dataset = dataset._store.path

        from ..stores import zarr_lookup

        dataset = zarr_lookup(dataset)
        self.path = dataset
        assert self._probe_attributes["is_observations"], f"Expected observations dataset, got {self.path}"

        if frequency is None:
            frequency = self._probe_attributes.get("frequency")
            # LOG.warning(f"Frequency not provided, using the one from the dataset: {frequency}")
            if frequency is None:
                frequency = "6h"
                # LOG.warning(f"Frequency not provided in the dataset, using the default : {frequency}")
        self._frequency = frequency_to_timedelta(frequency)
        assert self.frequency.total_seconds() % 3600 == 0, f"Expected multiple of 3600, got {self.frequency}"
        if self.frequency.total_seconds() != 6 * 3600:
            LOG.warning("Frequency is not 6h, this has not been tested, behaviour is unknown")

        frequency_hours = int(self.frequency.total_seconds() // 3600)
        assert isinstance(frequency_hours, int), f"Expected int, got {type(frequency_hours)}"

        if window is None:
            window = (-frequency_hours, 0)
        if window != (-frequency_hours, 0):
            raise ValueError("For now, only window = (- frequency, 0) are supported")

        self.window = window

        # Align the store's date range on the frequency grid.
        start, end = self._probe_attributes["start_date"], self._probe_attributes["end_date"]
        start, end = datetime.datetime.fromisoformat(start), datetime.datetime.fromisoformat(end)
        start, end = round_datetime(start, frequency_hours), round_datetime(end, frequency_hours)

        # The first date is one frequency after `start`: each item's window
        # ends at its date, so a window must fit before the first date.
        self._dates = make_dates(start + self.frequency, end, self.frequency)

        first_window_begin = int(start.strftime("%Y%m%d%H%M%S"))
        # last_window_end must be the end of the time window of the last item
        last_window_end = int(end.strftime("%Y%m%d%H%M%S"))

        from .legacy_obs_dataset import ObsDataset

        args = [self.path, first_window_begin, last_window_end]
        kwargs = dict(
            len_hrs=frequency_hours,  # length the time windows, i.e. the time span of one item
            step_hrs=frequency_hours,  # frequency of the dataset, i.e. the time shift between two items
        )
        self.forward = ObsDataset(*args, **kwargs)

        # Fix: each message now reports the attribute the assertion actually
        # checks (previously step_hrs/len_hrs were swapped in the f-strings).
        assert frequency_hours == self.forward.step_hrs, f"Expected {frequency_hours}, got {self.forward.step_hrs}"
        assert frequency_hours == self.forward.len_hrs, f"Expected {frequency_hours}, got {self.forward.len_hrs}"

        if len(self.forward) != len(self.dates):
            raise ValueError(
                (
                    f"Dates are not consistent with the number of items in the dataset. "
                    f"The dataset contains {len(self.forward)} time windows. "
                    f"This is not compatible with the "
                    # Fix: separator added between the frequency and the date list
                    # (the fragments previously concatenated with no space).
                    f"{len(self.dates)} requested dates with frequency={frequency_hours}: "
                    f"{self.dates[0]}, {self.dates[1]}, ..., {self.dates[-2]}, {self.dates[-1]} "
                )
            )

    @property
    def source(self):
        """Path of the underlying zarr store."""
        return self.path

    def get_dataset_names(self):
        """Return the store's base name, stripped of a trailing ``.zarr``."""
        name = os.path.basename(self.path)
        if name.endswith(".zarr"):
            name = name[:-5]
        return [name]

    @cached_property
    def _probe_attributes(self):
        """Attributes of the store's ``data`` array, read once and cached."""
        import zarr

        z = zarr.open(self.path, mode="r")
        return dict(z.data.attrs)

    def get_aux(self, i):
        """Return (latitudes, longitudes, timedeltas) for the observations of item *i*.

        Timedeltas are offsets of each observation time from the item's date.
        """
        data = self.forward[i]

        latitudes = data[:, self.name_to_index["__latitudes"]].numpy()
        longitudes = data[:, self.name_to_index["__longitudes"]].numpy()

        reference = self.dates[i]
        times = self.forward.get_dates(i)
        if str(times.dtype) != "datetime64[s]":
            LOG.warning(f"Expected np.datetime64[s], got {times.dtype}. ")
            times = times.astype("datetime64[s]")
        assert str(reference.dtype) == "datetime64[s]", f"Expected np.datetime64[s], got {type(reference)}"
        timedeltas = times - reference

        assert latitudes.shape == longitudes.shape, f"Expected {latitudes.shape}, got {longitudes.shape}"
        assert timedeltas.shape == latitudes.shape, f"Expected {timedeltas.shape}, got {latitudes.shape}"

        return latitudes, longitudes, timedeltas

    def getitem(self, i):
        """Return item *i* as a (n_variables, n_observations) float32 array."""
        data = self.forward[i]

        data = data.numpy().astype(np.float32)
        assert len(data.shape) == 2, f"Expected 2D array, got {data.shape}"
        # ObsDataset returns (rows, columns); transpose to (variables, observations).
        data = data.T

        if not data.size:
            data = self.empty_item()
        assert (
            data.shape[0] == self.shape[1]
        ), f"Data shape {data.shape} does not match {self.shape} : {data.shape[0]} != {self.shape[1]}"
        return data

    @cached_property
    def variables(self):
        """Column names with the ``obsvalue_`` prefix stripped.

        Latitude/longitude columns are renamed ``__latitudes``/``__longitudes``
        so they can be located by :meth:`get_aux`.
        """
        colnames = self.forward.colnames
        variables = []
        for n in colnames:
            if n.startswith("obsvalue_"):
                n = n.replace("obsvalue_", "")
            if n == "latitude" or n == "lat":
                assert "latitudes" not in variables, f"Duplicate latitudes found in {variables}"
                variables.append("__latitudes")
                continue
            if n == "longitude" or n == "lon":
                assert "longitudes" not in variables, f"Duplicate longitudes found in {variables}"
                variables.append("__longitudes")
                continue
            assert not n.startswith("__"), f"Invalid name {n} found in {colnames}"
            variables.append(n)
        return variables

    @property
    def name_to_index(self):
        """Mapping of variable name to column position."""
        return {n: i for i, n in enumerate(self.variables)}

    @property
    def statistics(self):
        """Per-column statistics read from the store attributes.

        Returns a dict with ``mean``, ``stdev`` (sqrt of the stored variances),
        ``minimum`` and ``maximum``, all float32 arrays.
        """
        mean = self.forward.properties["means"]
        mean = np.array(mean, dtype=np.float32)

        var = self.forward.properties["vars"]
        var = np.array(var, dtype=np.float32)
        stdev = np.sqrt(var)

        minimum = np.array(self.forward.z.data.attrs["mins"], dtype=np.float32)
        maximum = np.array(self.forward.z.data.attrs["maxs"], dtype=np.float32)

        assert isinstance(mean, np.ndarray), f"Expected np.ndarray, got {type(mean)}"
        assert isinstance(stdev, np.ndarray), f"Expected np.ndarray, got {type(stdev)}"
        assert isinstance(minimum, np.ndarray), f"Expected np.ndarray, got {type(minimum)}"
        assert isinstance(maximum, np.ndarray), f"Expected np.ndarray, got {type(maximum)}"
        return dict(mean=mean, stdev=stdev, minimum=minimum, maximum=maximum)

    def tree(self):
        """Debug-tree node carrying the store path and frequency."""
        return Node(
            self,
            [],
            path=self.path,
            frequency=self.frequency,
        )

    def __repr__(self):
        return f"Observations({os.path.basename(self.path)}, {self.dates[0]};{self.dates[-1]}, {len(self)})"
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def observations_factory(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> ObservationsBase:
    """Create an observations dataset from the ``observations`` entry of *kwargs*.

    The entry may be a plain store path/name or a dict of
    :class:`ObservationsZarr` keyword arguments.  Remaining *kwargs* are
    forwarded to ``_subset`` (select, start/end, etc.).
    """
    spec = kwargs.pop("observations")

    if isinstance(spec, dict):
        options = spec
    else:
        options = dict(dataset=spec)

    return ObservationsZarr(**options)._subset(**kwargs)
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
# (C) Copyright 2025 European Centre for Medium-Range Weather Forecasts.
|
|
2
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
3
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
4
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
5
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
6
|
+
# nor does it submit to any jurisdiction.
|
|
7
|
+
|
|
8
|
+
import datetime
|
|
9
|
+
import logging
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
import torch
|
|
14
|
+
import zarr
|
|
15
|
+
from torch.utils.data import Dataset
|
|
16
|
+
|
|
17
|
+
LOG = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ObsDataset(Dataset):
    """Torch ``Dataset`` exposing a legacy observations zarr store as
    fixed-length time windows ("samples").

    The store is expected to contain:
      - ``data``: a 2D array (observation rows x columns) with attributes
        ``colnames``, ``means``, ``vars``, ``data_idxs`` and ``obs_id``;
      - ``dates``: per-row datetimes;
      - ``idx_197001010000_1``: an index mapping hours since
        1970-01-01 00:00 to the first data row of that hour.
    """

    def __init__(
        self,
        filename: str,
        start: int,
        end: int,
        len_hrs: int,
        step_hrs: int = None,
        select: list[str] = None,
        drop: list[str] = None,
    ) -> None:
        """Open *filename* and build the per-sample row index.

        Parameters
        ----------
        filename : str
            Path to the zarr store.
        start, end : int
            Dataset time range as ``yyyymmddhhmmss`` integers.
        len_hrs : int
            Length of one sample's time window, in hours.
        step_hrs : int, optional
            Shift between consecutive samples; defaults to *len_hrs*.
        select, drop : list[str], optional
            Restrict or remove the columns returned by the get functions.
        """
        self.filename = filename
        self.z = zarr.open(filename, mode="r")
        self.data = self.z["data"]
        self.dt = self.z["dates"]  # datetime only
        self.hrly_index = self.z["idx_197001010000_1"]
        self.colnames = self.data.attrs["colnames"]
        self.selected_colnames = self.colnames
        self.selected_cols_idx = np.arange(len(self.colnames))
        self.len_hrs = len_hrs
        # NOTE: a falsy step_hrs (None or 0) falls back to len_hrs.
        self.step_hrs = step_hrs if step_hrs else len_hrs

        # Create index for samples
        self._setup_sample_index(start, end, self.len_hrs, self.step_hrs)

        self._load_properties()

        if select:
            self.select(select)

        if drop:
            self.drop(drop)

    def __getitem__(
        self,
        idx: int,
    ) -> torch.tensor:
        """Return the selected columns of all rows in sample *idx* as a tensor."""

        start_row = self.indices_start[idx]
        end_row = self.indices_end[idx]

        # Orthogonal indexing: row slice x selected column indices.
        data = self.data.oindex[start_row:end_row, self.selected_cols_idx]

        return torch.from_numpy(data)

    def __len__(self) -> int:
        """Number of samples (time windows)."""

        return len(self.indices_start)

    def get_dates(
        self,
        idx: int,
    ) -> np.ndarray:
        """Return the datetime of every row in sample *idx* as ``datetime64[s]``."""

        start_row = self.indices_start[idx]
        end_row = self.indices_end[idx]
        dates = self.dt.oindex[start_row:end_row]

        # The dates array is 2D with one column; keep that single column.
        assert len(dates.shape) == 2, dates.shape
        dates = dates[:, 0]

        if len(dates) and dates[0].dtype != np.dtype("datetime64[s]"):
            dates = dates.astype("datetime64[s]")

        return dates

    def get_df(self, idx: int) -> pd.DataFrame:
        """Convenience function to return data for sample idx packaged in a pandas DataFrame"""

        d = self.__getitem__(idx)

        df = pd.DataFrame(data=d, columns=[self.colnames[i] for i in self.selected_cols_idx])

        start_row = self.indices_start[idx]
        end_row = self.indices_end[idx]

        dts = self.dt[start_row:end_row, :]
        df["datetime"] = dts

        return df

    def select(self, cols_list: list[str]) -> None:
        """Allow user to specify which columns they want to access.
        Get functions only returned for these specified columns.
        """
        self.selected_colnames = cols_list
        self.selected_cols_idx = np.array([self.colnames.index(item) for item in cols_list])

    def drop(self, cols_to_drop: list[str]) -> None:
        """Allow user to drop specific columns from the dataset.
        Get functions no longer return data for these columns after being set.
        """
        # Mask is built over the current selection, so drop() composes with select().
        mask = [name not in cols_to_drop for name in self.selected_colnames]

        self.selected_colnames = [name for name, keep in zip(self.selected_colnames, mask) if keep]
        self.selected_cols_idx = self.selected_cols_idx[mask]

    def time_window(self, idx: int) -> tuple[np.datetime64, np.datetime64]:
        """Returns a tuple of datetime objects describing the start and end times of the sample at position idx."""

        if idx < 0:
            idx = len(self) + idx

        # Start is offset by one second, presumably so consecutive windows do
        # not share their boundary instant — TODO confirm against writer side.
        time_start = self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs), seconds=1)
        time_end = min(
            self.start_dt + datetime.timedelta(hours=(idx * self.step_hrs + self.len_hrs)),
            self.end_dt,
        )

        return (np.datetime64(time_start), np.datetime64(time_end))

    def first_sample_with_data(self) -> int:
        """Returns the position of the first sample which contains data."""
        return int(np.nonzero(self.indices_end)[0][0]) if self.indices_end.max() > 0 else None

    def last_sample_with_data(self) -> int:
        """Returns the position of the last sample which contains data."""
        if self.indices_end.max() == 0:
            last_sample = None
        else:
            # Last position where the end index still increases, i.e. the last
            # window that adds at least one data row.
            last_sample = int(np.where(np.diff(np.append(self.indices_end, self.indices_end[-1])) > 0)[0][-1] + 1)

        return last_sample

    def _setup_sample_index(self, start: int, end: int, len_hrs: int, step_hrs: int) -> None:
        """Dataset is divided into samples;
        - each n_hours long
        - sample 0 starts at start (yyyymmddhhmm)
        - index array has one entry for each sample; contains the index of the first row
        containing data for that sample
        """

        try:
            from obsdata.config import config

            assert config.base_index_yyyymmddhhmm == 197001010000, "base_index_yyyymmddhhmm must be 197001010000"
        except ImportError:
            # obsdata is optional; the hard-coded base below matches the
            # store's index array name (idx_197001010000_1).
            pass
        base_yyyymmddhhmm = 197001010000

        assert start > base_yyyymmddhhmm, (
            f"Abort: ObsDataset sample start (yyyymmddhhmm) must be greater than {base_yyyymmddhhmm}\n"
            f" Current value: {start}"
        )

        # NOTE(review): base_yyyymmddhhmm has 12 digits but the format below
        # expects seconds (14 digits); strptime of the base may raise — confirm
        # against the released package.
        format_str = "%Y%m%d%H%M%S"
        base_dt = datetime.datetime.strptime(str(base_yyyymmddhhmm), format_str)
        self.start_dt = datetime.datetime.strptime(str(start), format_str)
        self.end_dt = datetime.datetime.strptime(str(end), format_str)

        # Calculate hours since the base date for the requested dataset ranges
        diff_in_hours_start = int((self.start_dt - base_dt).total_seconds() // 3600)
        diff_in_hours_end = int((self.end_dt - base_dt).total_seconds() // 3600)

        # Find elements that need to be extracted from the hourly index
        # + ensuring that the dataset respects the requested end-hour even if it is mid-way through a sample
        sample_starts = np.arange(diff_in_hours_start, diff_in_hours_end, step_hrs)
        sample_ends = np.minimum(sample_starts + len_hrs, diff_in_hours_end)

        # Initialize local index arrays
        self.indices_start = np.zeros(sample_starts.shape, dtype=int)
        self.indices_end = np.zeros(self.indices_start.shape, dtype=int)

        # Samples outside the hourly index keep the zero defaults (no data).
        max_hrly_index = self.hrly_index.shape[0] - 1
        valid_start_mask = sample_starts <= max_hrly_index
        valid_end_mask = (sample_ends > 0) & (sample_ends <= max_hrly_index)

        # Copy elements from the hrly_index into the local index
        self.indices_start[valid_start_mask] = self.hrly_index[sample_starts[valid_start_mask]]
        self.indices_end[valid_end_mask] = np.maximum(self.hrly_index[sample_ends[valid_end_mask]], 0)

    def _load_properties(self) -> None:
        """Cache dataset-level attributes (statistics and ids) from ``data``."""

        self.properties = {}

        self.properties["means"] = self.data.attrs["means"]
        self.properties["vars"] = self.data.attrs["vars"]
        self.properties["data_idxs"] = self.data.attrs["data_idxs"]
        self.properties["obs_id"] = self.data.attrs["obs_id"]
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# (C) Copyright 2024 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
|
|
13
|
+
from anemoi.datasets.data import open_dataset
|
|
14
|
+
|
|
15
|
+
LOG = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LegacyDatasets:
    """Open several datasets (observation or gridded) over one common period.

    All datasets must end up with identical start/end dates and frequency;
    indexing returns a dict mapping each dataset's base name to its item.
    """

    def __init__(self, paths, start=None, end=None, **kwargs):
        self.paths = paths

        if not start or not end:
            print(
                "❌❌ Warning: start and end not provided, using the minima first and maximal last dates of the datasets"
            )
            # Probe each dataset once just to discover the overall date span.
            probed = [self._open_dataset(path, **kwargs) for path in paths]
            start = min(d.dates[0] for d in probed)
            end = max(d.dates[-1] for d in probed)

        # Re-open every dataset on the common period, padded with empty items.
        self._datasets = {}
        for path in paths:
            key = os.path.basename(path).split(".")[0]
            self._datasets[key] = self._open_dataset(path, start=start, end=end, padding="empty")

        reference = list(self._datasets.values())[0]
        for name, dataset in self._datasets.items():
            dates_differ = dataset.dates[0] != reference.dates[0] or dataset.dates[-1] != reference.dates[-1]
            if dates_differ:
                raise ValueError("Datasets have different start and end times")
            if dataset.frequency != reference.frequency:
                raise ValueError("Datasets have different frequencies")

        self._keys = self._datasets.keys

        self._first = list(self._datasets.values())[0]

    def _open_dataset(self, p, **kwargs):
        """Open *p*, routing observation datasets through the observations factory."""
        if not p.startswith("observations-"):
            print("❗ Opening non-observations dataset:", p)
            return open_dataset(p, **kwargs)
        return open_dataset(observations=p, **kwargs)

    def items(self):
        """Iterate over (name, dataset) pairs."""
        return self._datasets.items()

    @property
    def dates(self):
        """Shared dates array (identical across all member datasets)."""
        return self._first.dates

    def __len__(self):
        return len(self._first)

    def __getitem__(self, i):
        """Return {name: dataset[i]} across all member datasets."""
        return {k: d[i] for k, d in self._datasets.items()}
|