anemoi-datasets 0.5.27__py3-none-any.whl → 0.5.29__py3-none-any.whl
This diff shows the changes between two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package contents as published in the public registry.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/recipe/__init__.py +93 -0
- anemoi/datasets/commands/recipe/format.py +55 -0
- anemoi/datasets/commands/recipe/migrate.py +555 -0
- anemoi/datasets/create/__init__.py +46 -13
- anemoi/datasets/create/config.py +52 -53
- anemoi/datasets/create/input/__init__.py +43 -63
- anemoi/datasets/create/input/action.py +296 -236
- anemoi/datasets/create/input/context/__init__.py +71 -0
- anemoi/datasets/create/input/context/field.py +54 -0
- anemoi/datasets/create/input/data_sources.py +2 -1
- anemoi/datasets/create/input/misc.py +0 -71
- anemoi/datasets/create/input/repeated_dates.py +0 -114
- anemoi/datasets/create/input/result/__init__.py +17 -0
- anemoi/datasets/create/input/{result.py → result/field.py} +10 -92
- anemoi/datasets/create/sources/accumulate.py +517 -0
- anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
- anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
- anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +149 -0
- anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
- anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
- anemoi/datasets/create/sources/constants.py +39 -38
- anemoi/datasets/create/sources/empty.py +26 -22
- anemoi/datasets/create/sources/forcings.py +29 -28
- anemoi/datasets/create/sources/grib.py +92 -72
- anemoi/datasets/create/sources/grib_index.py +102 -54
- anemoi/datasets/create/sources/hindcasts.py +56 -55
- anemoi/datasets/create/sources/legacy.py +10 -62
- anemoi/datasets/create/sources/mars.py +159 -154
- anemoi/datasets/create/sources/netcdf.py +28 -24
- anemoi/datasets/create/sources/opendap.py +28 -24
- anemoi/datasets/create/sources/recentre.py +42 -41
- anemoi/datasets/create/sources/repeated_dates.py +44 -0
- anemoi/datasets/create/sources/source.py +26 -48
- anemoi/datasets/create/sources/xarray_support/__init__.py +30 -24
- anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
- anemoi/datasets/create/sources/xarray_support/field.py +4 -4
- anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
- anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
- anemoi/datasets/create/sources/xarray_zarr.py +28 -24
- anemoi/datasets/create/sources/zenodo.py +43 -39
- anemoi/datasets/create/utils.py +0 -42
- anemoi/datasets/data/complement.py +26 -17
- anemoi/datasets/data/dataset.py +12 -0
- anemoi/datasets/data/grids.py +0 -152
- anemoi/datasets/data/masked.py +74 -13
- anemoi/datasets/data/missing.py +5 -0
- anemoi/datasets/data/rolling_average.py +141 -0
- anemoi/datasets/data/stores.py +7 -9
- anemoi/datasets/dates/__init__.py +2 -0
- anemoi/datasets/dumper.py +76 -0
- anemoi/datasets/grids.py +1 -178
- anemoi/datasets/schemas/recipe.json +131 -0
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/METADATA +9 -6
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/RECORD +59 -57
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/filter.py +0 -47
- anemoi/datasets/create/input/concat.py +0 -161
- anemoi/datasets/create/input/context.py +0 -86
- anemoi/datasets/create/input/empty.py +0 -53
- anemoi/datasets/create/input/filter.py +0 -117
- anemoi/datasets/create/input/function.py +0 -232
- anemoi/datasets/create/input/join.py +0 -129
- anemoi/datasets/create/input/pipe.py +0 -66
- anemoi/datasets/create/input/step.py +0 -173
- anemoi/datasets/create/input/template.py +0 -161
- anemoi/datasets/create/sources/accumulations.py +0 -1062
- anemoi/datasets/create/sources/accumulations2.py +0 -647
- anemoi/datasets/create/sources/tendencies.py +0 -198
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/licenses/LICENSE +0 -0
- {anemoi_datasets-0.5.27.dist-info → anemoi_datasets-0.5.29.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/sources/repeated_dates.py (new file)

@@ -0,0 +1,44 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+from typing import Any
+
+from anemoi.transform.fields import new_field_with_valid_datetime
+from anemoi.transform.fields import new_fieldlist_from_list
+
+from anemoi.datasets.create.input.repeated_dates import DateMapper
+from anemoi.datasets.create.source import Source
+from anemoi.datasets.create.sources import source_registry
+
+LOG = logging.getLogger(__name__)
+
+
+@source_registry.register("repeated_dates")
+class RepeatedDatesSource(Source):
+
+    def __init__(self, context, source: Any, mode: str, **kwargs) -> None:
+        # assert False, (context, source, mode, kwargs)
+        super().__init__(context, **kwargs)
+        self.mapper = DateMapper.from_mode(mode, source, kwargs)
+        self.source = source
+
+    def execute(self, group_of_dates):
+        source = self.context.create_source(self.source, "data_sources", str(id(self)))
+
+        result = []
+        for one_date_group, many_dates_group in self.mapper.transform(group_of_dates):
+            print(f"one_date_group: {one_date_group}, many_dates_group: {many_dates_group}")
+            source_results = source(self.context, one_date_group)
+            for field in source_results:
+                for date in many_dates_group:
+                    result.append(new_field_with_valid_datetime(field, date))
+
+        return new_fieldlist_from_list(result)
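The new repeated_dates source fans the fields retrieved for one anchor date group out to every date in the matching target group, re-stamping each copy with a new valid datetime via new_field_with_valid_datetime. Below is a minimal standalone sketch of that fan-out, using plain dicts as stand-ins for real fields; the anchor and target dates are made-up values, not anything from the diff:

    from datetime import datetime, timedelta

    # stand-ins: one anchor date group and the dates it should be repeated onto
    anchor = datetime(2020, 1, 1)
    targets = [anchor + timedelta(hours=h) for h in (0, 6, 12, 18)]

    # stand-in for the fields returned by the wrapped source for the anchor group
    source_fields = [{"param": "z", "valid_datetime": anchor}]

    result = []
    for field in source_fields:
        for date in targets:
            # analogue of new_field_with_valid_datetime(field, date)
            result.append({**field, "valid_datetime": date})

    print(len(result))  # 4 copies of the field, one per target date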
anemoi/datasets/create/sources/source.py

@@ -12,58 +12,36 @@ from typing import Any
 
 from earthkit.data import from_source
 
-from anemoi.datasets.create.
+from anemoi.datasets.create.sources import source_registry
 
-from .legacy import
+from .legacy import LegacySource
 
 
-@
-
-    """Generates a source based on the provided context, dates, and additional keyword arguments.
+@source_registry.register("source")
+class GenericSource(LegacySource):
 
-
-
-
-        The context in which the source is generated.
-    dates : List[datetime]
-        A list of datetime objects representing the dates.
-    **kwargs : Any
-        Additional keyword arguments for the source generation.
+    @staticmethod
+    def _execute(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any:
+        """Generates a source based on the provided context, dates, and additional keyword arguments.
 
-
-
-
-
-
-
-
-
-    kwargs["date"] = list({d.strftime("%Y%m%d") for d in dates})
-    if kwargs["time"] == "$from_dates":
-        kwargs["time"] = list({d.strftime("%H%M") for d in dates})
-    return from_source(name, **kwargs)
+        Parameters
+        ----------
+        context : Optional[Any]
+            The context in which the source is generated.
+        dates : List[datetime]
+            A list of datetime objects representing the dates.
+        **kwargs : Any
+            Additional keyword arguments for the source generation.
 
-
-
-
-
-    import yaml
-
-    config: dict[str, Any] = yaml.safe_load(
+        Returns
+        -------
+        Any
+            The generated source.
         """
-
-
-
-
-
-
-
-    date: $from_dates
-    time: $from_dates
-    """
-    )
-    dates: list[str] = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]")
-    dates = to_datetime_list(dates)
-
-    for f in source(None, dates, **config):
-        print(f, f.to_numpy().mean())
+        name = kwargs.pop("name")
+        context.trace("✅", f"from_source({name}, {dates}, {kwargs}")
+        if kwargs["date"] == "$from_dates":
+            kwargs["date"] = list({d.strftime("%Y%m%d") for d in dates})
+        if kwargs["time"] == "$from_dates":
+            kwargs["time"] = list({d.strftime("%H%M") for d in dates})
+        return from_source(name, **kwargs)
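The "$from_dates" placeholders in GenericSource._execute are expanded into the distinct dates and times of the requested group before the call is handed to earthkit.data.from_source. A small self-contained illustration of that substitution (the kwargs dict here is an invented example, not a real recipe):

    from datetime import datetime

    dates = [datetime(2022, 12, 30, 18), datetime(2022, 12, 31, 0), datetime(2022, 12, 31, 6)]
    kwargs = {"date": "$from_dates", "time": "$from_dates", "param": "2t"}

    if kwargs["date"] == "$from_dates":
        kwargs["date"] = list({d.strftime("%Y%m%d") for d in dates})
    if kwargs["time"] == "$from_dates":
        kwargs["time"] = list({d.strftime("%H%M") for d in dates})

    print(kwargs["date"])  # e.g. ['20221230', '20221231'] (set order is arbitrary)
    print(kwargs["time"])  # e.g. ['1800', '0000', '0600']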
anemoi/datasets/create/sources/xarray_support/__init__.py

@@ -17,7 +17,8 @@ from earthkit.data.core.fieldlist import MultiFieldList
 
 from anemoi.datasets.create.sources.patterns import iterate_patterns
 
-from ..
+from .. import source_registry
+from ..legacy import LegacySource
 from .fieldlist import XarrayFieldList
 
 LOG = logging.getLogger(__name__)

@@ -96,6 +97,7 @@ def load_one(
     if isinstance(dataset, xr.Dataset):
         data = dataset
     else:
+        print(f"Opening dataset {dataset} with options {options}")
         data = xr.open_dataset(dataset, **options)
 
     fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)

@@ -152,26 +154,30 @@ def load_many(emoji: str, context: Any, dates: list[datetime.datetime], pattern:
     return MultiFieldList(result)
 
 
-@
-
-""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+@source_registry.register("xarray")
+class LegacyXarraySource(LegacySource):
+    name = "xarray"
+
+    @staticmethod
+    def _execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList:
+        """Executes the loading of datasets.
+
+        Parameters
+        ----------
+        context : Any
+            Context object.
+        dates : List[str]
+            List of dates.
+        url : str
+            URL pattern for loading datasets.
+        *args : Any
+            Additional arguments.
+        **kwargs : Any
+            Additional keyword arguments.
+
+        Returns
+        -------
+        ekd.FieldList
+            The loaded datasets.
+        """
+        return load_many("🌐", context, dates, url, *args, **kwargs)
anemoi/datasets/create/sources/xarray_support/coordinates.py

@@ -223,13 +223,10 @@ class Coordinate:
        # Assume the array is sorted
 
        index = np.searchsorted(values, value)
-
-
-        if np.all(values[index] == value):
+        if np.all(index < len(values)) and np.all(values[index] == value):
            return index
 
        # If not found, we need to check if the value is in the array
-
        index = np.where(np.isin(values, value))[0]
 
        # We could also return incomplete matches
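The added np.all(index < len(values)) guard matters because np.searchsorted returns len(values) when the requested value is greater than every element, so indexing values[index] without the check raises IndexError. A quick NumPy illustration of both branches:

    import numpy as np

    values = np.array([100, 200, 300])
    value = 500

    index = np.searchsorted(values, value)
    print(index)  # 3, i.e. one past the end of the array

    if np.all(index < len(values)) and np.all(values[index] == value):
        print("exact match at", index)
    else:
        # fall back to the slower membership lookup, as in the patched code
        index = np.where(np.isin(values, value))[0]
        print("fallback result:", index)  # empty array: value not present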
anemoi/datasets/create/sources/xarray_support/field.py

@@ -121,16 +121,16 @@ class XArrayField(Field):
            Index to select a specific element, by default None.
        """
        if index is not None:
-            values = self.selection[index]
+            values = self.selection[index].values
        else:
-            values = self.selection
+            values = self.selection.values
 
        assert dtype is None
 
        if flatten:
-            return values.
+            return values.flatten()
 
-        return values
+        return values
 
    @cached_property
    def _metadata(self) -> XArrayMetadata:
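The .values change matters because indexing an xarray.DataArray returns another DataArray, which has no flatten() method; .values yields the underlying NumPy array, on which flatten() is valid. A minimal illustration:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(6).reshape(2, 3), dims=("y", "x"))

    print(type(da))             # xarray.core.dataarray.DataArray
    print(type(da.values))      # numpy.ndarray
    print(da.values.flatten())  # [0 1 2 3 4 5]
    # da.flatten() would raise AttributeError: 'DataArray' object has no attribute 'flatten'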
anemoi/datasets/create/sources/xarray_support/flavour.py

@@ -557,10 +557,10 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
        super().__init__(ds)
 
    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> PointCoordinate | None:
-        if attributes.standard_name in ["cell", "station", "poi", "point"]:
+        if attributes.standard_name in ["location", "cell", "id", "station", "poi", "point"]:
            return PointCoordinate(c)
 
-        if attributes.name in ["cell", "station", "poi", "point"]:  # WeatherBench
+        if attributes.name in ["location", "cell", "id", "station", "poi", "point"]:  # WeatherBench
            return PointCoordinate(c)
 
        return None
anemoi/datasets/create/sources/xarray_support/patch.py

@@ -10,13 +10,14 @@
 
 import logging
 from typing import Any
+from typing import Literal
 
 import xarray as xr
 
 LOG = logging.getLogger(__name__)
 
 
-def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> Any:
+def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> xr.Dataset:
    """Patch the attributes of the dataset.
 
    Parameters

@@ -38,7 +39,7 @@ def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> A
    return ds
 
 
-def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> Any:
+def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> xr.Dataset:
    """Patch the coordinates of the dataset.
 
    Parameters

@@ -59,7 +60,7 @@ def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> Any:
    return ds
 
 
-def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
+def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> xr.Dataset:
    """Rename variables in the dataset.
 
    Parameters

@@ -77,7 +78,7 @@ def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
    return ds.rename(renames)
 
 
-def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> Any:
+def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> xr.Dataset:
    """Sort the coordinates of the dataset.
 
    Parameters
@@ -98,11 +99,175 @@ def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> Any:
    return ds
 
 
+def patch_subset_dataset(ds: xr.Dataset, selection: dict[str, Any]) -> xr.Dataset:
+    """Select a subset of the dataset using xarray's sel method.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    selection : dict[str, Any]
+        Dictionary mapping dimension names to selection criteria.
+        Keys must be existing dimension names in the dataset.
+        Values can be any type accepted by xarray's sel method, including:
+        - Single values (int, float, str, datetime)
+        - Lists or arrays of values
+        - Slices (using slice() objects)
+        - Boolean arrays
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset containing only the selected subset.
+
+    Examples
+    --------
+    >>> # Select specific time and pressure level
+    >>> patch_subset_dataset(ds, {
+    ...     'time': '2020-01-01',
+    ...     'pressure': 500
+    ... })
+
+    >>> # Select a range using slice
+    >>> patch_subset_dataset(ds, {
+    ...     'lat': slice(-90, 90),
+    ...     'lon': slice(0, 180)
+    ... })
+    """
+
+    ds = ds.sel(selection)
+
+    return ds
+
+
+def patch_analysis_lead_to_valid_time(
+    ds: xr.Dataset,
+    time_coord_names: dict[Literal["analysis_time_coordinate", "lead_time_coordinate", "valid_time_coordinate"], str],
+) -> xr.Dataset:
+    """Convert analysis time and lead time coordinates to valid time.
+
+    This function creates a new valid time coordinate by adding the analysis time
+    and lead time coordinates, then stacks and reorganizes the dataset to use
+    valid time as the primary time dimension.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    time_coord_names : dict[str, str]
+        Dictionary mapping required keys to coordinate names in the dataset:
+
+        - 'analysis_time_coordinate' : str
+            Name of the analysis/initialization time coordinate.
+        - 'lead_time_coordinate' : str
+            Name of the forecast lead time coordinate.
+        - 'valid_time_coordinate' : str
+            Name for the new valid time coordinate to create.
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset with valid time as the primary time coordinate.
+        The analysis and lead time coordinates are removed.
+
+    Examples
+    --------
+    >>> patch_analysis_lead_to_valid_time(ds, {
+    ...     'analysis_time_coordinate': 'forecast_reference_time',
+    ...     'lead_time_coordinate': 'step',
+    ...     'valid_time_coordinate': 'time'
+    ... })
+    """
+
+    assert time_coord_names.keys() == {
+        "analysis_time_coordinate",
+        "lead_time_coordinate",
+        "valid_time_coordinate",
+    }, "time_coord_names must contain exactly keys 'analysis_time_coordinate', 'lead_time_coordinate', and 'valid_time_coordinate'"
+
+    analysis_time_coordinate = time_coord_names["analysis_time_coordinate"]
+    lead_time_coordinate = time_coord_names["lead_time_coordinate"]
+    valid_time_coordinate = time_coord_names["valid_time_coordinate"]
+
+    valid_time = ds[analysis_time_coordinate] + ds[lead_time_coordinate]
+
+    ds = (
+        ds.assign_coords({valid_time_coordinate: valid_time})
+        .stack(time_index=[analysis_time_coordinate, lead_time_coordinate])
+        .set_index(time_index=valid_time_coordinate)
+        .rename(time_index=valid_time_coordinate)
+        .drop_vars([analysis_time_coordinate, lead_time_coordinate])
+    )
+
+    return ds
+
+
+def patch_rolling_operation(
+    ds: xr.Dataset, vars_operation_config: dict[Literal["dim", "steps", "vars", "operation"], str | int | list[str]]
+) -> xr.Dataset:
+    """Apply a rolling operation to specified variables in the dataset.
+
+    This function calculates a rolling operation over a specified dimension for selected
+    variables. The rolling window requires all periods to be present (min_periods=steps).
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    vars_operation_config: dict
+        Configuration for the rolling operation with the following keys:
+
+        - 'dim' : str
+            The dimension along which to apply the rolling operation (e.g., 'time').
+        - 'steps' : int
+            The number of steps in the rolling window.
+        - 'vars' : list[str]
+            List of variable names to apply the rolling operation to.
+        - 'operation' : str
+            The operation to apply ('sum', 'mean', 'min', 'max', 'std', etc.).
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset with rolling operations applied to the specified variables.
+
+    Examples
+    --------
+    >>> patch_rolling_operation(ds, {
+    ...     'dim': 'time',
+    ...     'steps': 3,
+    ...     'vars': ['precipitation', 'radiation'],
+    ...     'operation': 'sum'
+    ... })
+    """
+
+    assert vars_operation_config.keys() == {
+        "dim",
+        "steps",
+        "vars",
+        "operation",
+    }, "vars_operation_config must contain exactly keys 'dim', 'steps', 'vars', and 'operation'"
+
+    dim = vars_operation_config["dim"]
+    steps = vars_operation_config["steps"]
+    vars = vars_operation_config["vars"]
+    operation = vars_operation_config["operation"]
+
+    for var in vars:
+        rolling = ds[var].rolling(dim={dim: steps}, min_periods=steps)
+        ds[var] = getattr(rolling, operation)()
+
+    return ds
+
+
 PATCHES = {
    "attributes": patch_attributes,
    "coordinates": patch_coordinates,
    "rename": patch_rename,
    "sort_coordinates": patch_sort_coordinate,
+    "analysis_lead_to_valid_time": patch_analysis_lead_to_valid_time,
+    "rolling_operation": patch_rolling_operation,
+    "subset_dataset": patch_subset_dataset,
 }
 
 
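Both new patches are thin wrappers over standard xarray operations: subset_dataset forwards the user selection to Dataset.sel, and rolling_operation builds a rolling window with min_periods equal to the window length and applies the named reduction. The following standalone sketch reproduces that behaviour on a synthetic dataset (variable names and values are made up for illustration):

    import numpy as np
    import pandas as pd
    import xarray as xr

    times = pd.date_range("2020-01-01", periods=6, freq="6h")
    ds = xr.Dataset({"precipitation": ("time", np.arange(6.0))}, coords={"time": times})

    # subset_dataset: plain .sel() with the user-supplied selection dict
    subset = ds.sel({"time": slice("2020-01-01T06", "2020-01-01T18")})
    print(subset.sizes["time"])  # 3

    # rolling_operation: window of 3 steps, all periods required, then the named op
    rolling = ds["precipitation"].rolling(dim={"time": 3}, min_periods=3)
    ds["precipitation"] = getattr(rolling, "sum")()
    print(ds["precipitation"].values)  # [nan nan 3. 6. 9. 12.]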
@@ -122,7 +287,15 @@ def patch_dataset(ds: xr.Dataset, patch: dict[str, dict[str, Any]]) -> Any:
        The patched dataset.
    """
 
-    ORDER = [
+    ORDER = [
+        "coordinates",
+        "attributes",
+        "rename",
+        "sort_coordinates",
+        "subset_dataset",
+        "analysis_lead_to_valid_time",
+        "rolling_operation",
+    ]
    for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])):
        if what not in PATCHES:
            raise ValueError(f"Unknown patch type {what!r}")
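With the expanded ORDER list, patches are always applied in a fixed sequence regardless of the order in which they appear in the configuration, because patch_dataset sorts the requested patch types by their position in ORDER. A small illustration (hypothetical patch keys with empty configs):

    ORDER = ["coordinates", "attributes", "rename", "sort_coordinates",
             "subset_dataset", "analysis_lead_to_valid_time", "rolling_operation"]

    patch = {"rolling_operation": {}, "rename": {}, "subset_dataset": {}}

    print([what for what, _ in sorted(patch.items(), key=lambda x: ORDER.index(x[0]))])
    # ['rename', 'subset_dataset', 'rolling_operation']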
anemoi/datasets/create/sources/xarray_zarr.py

@@ -11,30 +11,34 @@ from typing import Any
 
 import earthkit.data as ekd
 
-from .
+from . import source_registry
+from .legacy import LegacySource
 from .xarray import load_many
 
 
-@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+@source_registry.register("xarray_zarr")
+class XarrayZarrSource(LegacySource):
+
+    @staticmethod
+    def _execute(context: Any, dates: list[str], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList:
+        """Execute the data loading process.
+
+        Parameters
+        ----------
+        context : Any
+            The context in which the execution occurs.
+        dates : List[str]
+            List of dates for which data is to be loaded.
+        url : str
+            The URL from which data is to be loaded.
+        *args : tuple
+            Additional positional arguments.
+        **kwargs : dict
+            Additional keyword arguments.
+
+        Returns
+        -------
+        ekd.FieldList
+            The loaded data.
+        """
+        return load_many("🇿", context, dates, url, *args, **kwargs)
anemoi/datasets/create/sources/zenodo.py

@@ -14,54 +14,58 @@ import earthkit.data as ekd
 from earthkit.data.core.fieldlist import MultiFieldList
 from earthkit.data.sources.url import download_and_cache
 
-from .
+from . import source_registry
+from .legacy import LegacySource
 from .patterns import iterate_patterns
 from .xarray import load_one
 
 
-@
-    """Executes the download and processing of files from Zenodo.
-
+@source_registry.register("zenodo")
+class ZenodoSource(LegacySource):
 
-
-
-
-        The context in which the function is executed.
-    dates : Any
-        The dates for which the data is required.
-    record_id : str
-        The Zenodo record ID.
-    file_key : str
-        The key to identify the file.
-    *args : Any
-        Additional arguments.
-    **kwargs : Any
-        Additional keyword arguments.
+    @staticmethod
+    def _execute(context: Any, dates: Any, record_id: str, file_key: str, *args: Any, **kwargs: Any) -> ekd.FieldList:
+        """Executes the download and processing of files from Zenodo.
 
-
-
-
-
-
-
+        Parameters
+        ----------
+        context : Any
+            The context in which the function is executed.
+        dates : Any
+            The dates for which the data is required.
+        record_id : str
+            The Zenodo record ID.
+        file_key : str
+            The key to identify the file.
+        *args : Any
+            Additional arguments.
+        **kwargs : Any
+            Additional keyword arguments.
 
-
+        Returns
+        -------
+        MultiFieldList
+            A list of fields loaded from the downloaded files.
+        """
+        import requests
 
-
-    url = URLPATTERN.format(record_id=record_id)
-    r = requests.get(url)
-    r.raise_for_status()
-    record: dict[str, Any] = r.json()
+        result: list[Any] = []
 
-
-
-
+        URLPATTERN = "https://zenodo.org/api/records/{record_id}"
+        url = URLPATTERN.format(record_id=record_id)
+        r = requests.get(url)
+        r.raise_for_status()
+        record: dict[str, Any] = r.json()
 
-
-
-
+        urls: dict[str, str] = {}
+        for file in record["files"]:
+            urls[file["key"]] = file["links"]["self"]
 
-
-
+        for url, dates in iterate_patterns(file_key, dates, **kwargs):
+            if url not in urls:
+                continue
 
-
+            path = download_and_cache(urls[url])
+            result.append(load_one("?", context, dates, path, options={}, flavour=None, **kwargs))
+
+        return MultiFieldList(result)
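The rewritten ZenodoSource._execute resolves download URLs through the public Zenodo records API: each entry in a record's "files" list carries a "key" (file name) and a "links.self" download URL, which the code maps before matching file_key patterns. A standalone sketch of just that lookup (the record id below is a placeholder, not a value taken from the package):

    import requests

    record_id = "1234567"  # hypothetical Zenodo record id

    r = requests.get(f"https://zenodo.org/api/records/{record_id}")
    r.raise_for_status()
    record = r.json()

    # map file name -> direct download URL, as _execute does before pattern matching
    urls = {f["key"]: f["links"]["self"] for f in record["files"]}
    print(sorted(urls))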