anemoi-datasets 0.5.26__py3-none-any.whl → 0.5.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/__init__.py +1 -2
- anemoi/datasets/_version.py +16 -3
- anemoi/datasets/commands/check.py +1 -1
- anemoi/datasets/commands/copy.py +1 -2
- anemoi/datasets/commands/create.py +1 -1
- anemoi/datasets/commands/inspect.py +27 -35
- anemoi/datasets/commands/recipe/__init__.py +93 -0
- anemoi/datasets/commands/recipe/format.py +55 -0
- anemoi/datasets/commands/recipe/migrate.py +555 -0
- anemoi/datasets/commands/validate.py +59 -0
- anemoi/datasets/compute/recentre.py +3 -6
- anemoi/datasets/create/__init__.py +64 -26
- anemoi/datasets/create/check.py +10 -12
- anemoi/datasets/create/chunks.py +1 -2
- anemoi/datasets/create/config.py +5 -6
- anemoi/datasets/create/input/__init__.py +44 -65
- anemoi/datasets/create/input/action.py +296 -238
- anemoi/datasets/create/input/context/__init__.py +71 -0
- anemoi/datasets/create/input/context/field.py +54 -0
- anemoi/datasets/create/input/data_sources.py +7 -9
- anemoi/datasets/create/input/misc.py +2 -75
- anemoi/datasets/create/input/repeated_dates.py +11 -130
- anemoi/datasets/{utils → create/input/result}/__init__.py +10 -1
- anemoi/datasets/create/input/{result.py → result/field.py} +36 -120
- anemoi/datasets/create/input/trace.py +1 -1
- anemoi/datasets/create/patch.py +1 -2
- anemoi/datasets/create/persistent.py +3 -5
- anemoi/datasets/create/size.py +1 -3
- anemoi/datasets/create/sources/accumulations.py +120 -145
- anemoi/datasets/create/sources/accumulations2.py +20 -53
- anemoi/datasets/create/sources/anemoi_dataset.py +46 -42
- anemoi/datasets/create/sources/constants.py +39 -40
- anemoi/datasets/create/sources/empty.py +22 -19
- anemoi/datasets/create/sources/fdb.py +133 -0
- anemoi/datasets/create/sources/forcings.py +29 -29
- anemoi/datasets/create/sources/grib.py +94 -78
- anemoi/datasets/create/sources/grib_index.py +57 -55
- anemoi/datasets/create/sources/hindcasts.py +57 -59
- anemoi/datasets/create/sources/legacy.py +10 -62
- anemoi/datasets/create/sources/mars.py +121 -149
- anemoi/datasets/create/sources/netcdf.py +28 -25
- anemoi/datasets/create/sources/opendap.py +28 -26
- anemoi/datasets/create/sources/patterns.py +4 -6
- anemoi/datasets/create/sources/recentre.py +46 -48
- anemoi/datasets/create/sources/repeated_dates.py +44 -0
- anemoi/datasets/create/sources/source.py +26 -51
- anemoi/datasets/create/sources/tendencies.py +68 -98
- anemoi/datasets/create/sources/xarray.py +4 -6
- anemoi/datasets/create/sources/xarray_support/__init__.py +40 -36
- anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -12
- anemoi/datasets/create/sources/xarray_support/field.py +20 -16
- anemoi/datasets/create/sources/xarray_support/fieldlist.py +11 -15
- anemoi/datasets/create/sources/xarray_support/flavour.py +42 -42
- anemoi/datasets/create/sources/xarray_support/grid.py +15 -9
- anemoi/datasets/create/sources/xarray_support/metadata.py +19 -128
- anemoi/datasets/create/sources/xarray_support/patch.py +4 -6
- anemoi/datasets/create/sources/xarray_support/time.py +10 -13
- anemoi/datasets/create/sources/xarray_support/variable.py +21 -21
- anemoi/datasets/create/sources/xarray_zarr.py +28 -25
- anemoi/datasets/create/sources/zenodo.py +43 -41
- anemoi/datasets/create/statistics/__init__.py +3 -6
- anemoi/datasets/create/testing.py +4 -0
- anemoi/datasets/create/typing.py +1 -2
- anemoi/datasets/create/utils.py +0 -43
- anemoi/datasets/create/zarr.py +7 -2
- anemoi/datasets/data/__init__.py +15 -6
- anemoi/datasets/data/complement.py +7 -12
- anemoi/datasets/data/concat.py +5 -8
- anemoi/datasets/data/dataset.py +48 -47
- anemoi/datasets/data/debug.py +7 -9
- anemoi/datasets/data/ensemble.py +4 -6
- anemoi/datasets/data/fill_missing.py +7 -10
- anemoi/datasets/data/forwards.py +22 -26
- anemoi/datasets/data/grids.py +12 -168
- anemoi/datasets/data/indexing.py +9 -12
- anemoi/datasets/data/interpolate.py +7 -15
- anemoi/datasets/data/join.py +8 -12
- anemoi/datasets/data/masked.py +6 -11
- anemoi/datasets/data/merge.py +5 -9
- anemoi/datasets/data/misc.py +41 -45
- anemoi/datasets/data/missing.py +11 -16
- anemoi/datasets/data/observations/__init__.py +8 -14
- anemoi/datasets/data/padded.py +3 -5
- anemoi/datasets/data/records/backends/__init__.py +2 -2
- anemoi/datasets/data/rescale.py +5 -12
- anemoi/datasets/data/rolling_average.py +141 -0
- anemoi/datasets/data/select.py +13 -16
- anemoi/datasets/data/statistics.py +4 -7
- anemoi/datasets/data/stores.py +22 -29
- anemoi/datasets/data/subset.py +8 -11
- anemoi/datasets/data/unchecked.py +7 -11
- anemoi/datasets/data/xy.py +25 -21
- anemoi/datasets/dates/__init__.py +15 -18
- anemoi/datasets/dates/groups.py +7 -10
- anemoi/datasets/dumper.py +76 -0
- anemoi/datasets/grids.py +4 -185
- anemoi/datasets/schemas/recipe.json +131 -0
- anemoi/datasets/testing.py +93 -7
- anemoi/datasets/validate.py +598 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/METADATA +7 -4
- anemoi_datasets-0.5.28.dist-info/RECORD +134 -0
- anemoi/datasets/create/filter.py +0 -48
- anemoi/datasets/create/input/concat.py +0 -164
- anemoi/datasets/create/input/context.py +0 -89
- anemoi/datasets/create/input/empty.py +0 -54
- anemoi/datasets/create/input/filter.py +0 -118
- anemoi/datasets/create/input/function.py +0 -233
- anemoi/datasets/create/input/join.py +0 -130
- anemoi/datasets/create/input/pipe.py +0 -66
- anemoi/datasets/create/input/step.py +0 -177
- anemoi/datasets/create/input/template.py +0 -162
- anemoi_datasets-0.5.26.dist-info/RECORD +0 -131
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/WHEEL +0 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/licenses/LICENSE +0 -0
- {anemoi_datasets-0.5.26.dist-info → anemoi_datasets-0.5.28.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# (C) Copyright 2024 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from anemoi.transform.fields import new_field_with_valid_datetime
|
|
15
|
+
from anemoi.transform.fields import new_fieldlist_from_list
|
|
16
|
+
|
|
17
|
+
from anemoi.datasets.create.input.repeated_dates import DateMapper
|
|
18
|
+
from anemoi.datasets.create.source import Source
|
|
19
|
+
from anemoi.datasets.create.sources import source_registry
|
|
20
|
+
|
|
21
|
+
LOG = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@source_registry.register("repeated_dates")
class RepeatedDatesSource(Source):
    """Source that retrieves fields for one group of dates and re-stamps them onto other dates.

    A ``DateMapper`` (chosen by ``mode``) maps each requested group of dates to
    pairs of ``(one_date_group, many_dates_group)``. The wrapped source is run for
    ``one_date_group``. Every resulting field is then duplicated once per date in
    ``many_dates_group``, with its valid datetime replaced by that date.
    """

    def __init__(self, context, source: Any, mode: str, **kwargs) -> None:
        """Initialise the source.

        Parameters
        ----------
        context : Any
            The build context, forwarded to the base ``Source``.
        source : Any
            Configuration of the wrapped source. It is instantiated lazily in ``execute``.
        mode : str
            Name of the date-mapping mode passed to ``DateMapper.from_mode``.
        **kwargs : Any
            Extra options, passed both to the base class and to ``DateMapper.from_mode``.
        """
        super().__init__(context, **kwargs)
        self.mapper = DateMapper.from_mode(mode, source, kwargs)
        self.source = source

    def execute(self, group_of_dates):
        """Run the wrapped source and re-date its fields.

        Parameters
        ----------
        group_of_dates : Any
            The group of dates requested. It is handed to ``self.mapper.transform``.

        Returns
        -------
        Any
            A field list built with ``new_fieldlist_from_list``, containing one copy of
            each retrieved field per target date.
        """
        # The id of this instance is used as a unique name for the nested source.
        source = self.context.create_source(self.source, "data_sources", str(id(self)))

        result = []
        for one_date_group, many_dates_group in self.mapper.transform(group_of_dates):
            LOG.debug("one_date_group: %s, many_dates_group: %s", one_date_group, many_dates_group)
            source_results = source(self.context, one_date_group)
            for field in source_results:
                for date in many_dates_group:
                    result.append(new_field_with_valid_datetime(field, date))

        return new_fieldlist_from_list(result)
|
|
@@ -9,64 +9,39 @@
|
|
|
9
9
|
|
|
10
10
|
from datetime import datetime
|
|
11
11
|
from typing import Any
|
|
12
|
-
from typing import Dict
|
|
13
|
-
from typing import List
|
|
14
|
-
from typing import Optional
|
|
15
12
|
|
|
16
13
|
from earthkit.data import from_source
|
|
17
14
|
|
|
18
|
-
from anemoi.datasets.create.
|
|
15
|
+
from anemoi.datasets.create.sources import source_registry
|
|
19
16
|
|
|
20
|
-
from .legacy import
|
|
17
|
+
from .legacy import LegacySource
|
|
21
18
|
|
|
22
19
|
|
|
23
|
-
@
|
|
24
|
-
|
|
25
|
-
"""Generates a source based on the provided context, dates, and additional keyword arguments.
|
|
20
|
+
@source_registry.register("source")
class GenericSource(LegacySource):
    """Source that forwards its configuration to ``earthkit.data.from_source``."""

    @staticmethod
    def _execute(context: Any | None, dates: list[datetime], **kwargs: Any) -> Any:
        """Generates a source based on the provided context, dates, and additional keyword arguments.

        Parameters
        ----------
        context : Any | None
            The context in which the source is generated. When it is not ``None``,
            its ``trace`` method is called with a description of the request.
        dates : list[datetime]
            A list of datetime objects representing the dates.
        **kwargs : Any
            Must contain ``name``, the earthkit source name. If ``date`` or ``time``
            equals the string ``"$from_dates"``, it is replaced by the sorted unique
            ``YYYYMMDD`` / ``HHMM`` values derived from ``dates``. All remaining keys
            are passed to ``from_source``.

        Returns
        -------
        Any
            The result of ``from_source(name, **kwargs)``.

        Raises
        ------
        KeyError
            If ``name`` is missing from ``kwargs``.
        """
        name = kwargs.pop("name")

        # ``context`` is documented as optional, so only trace when one is available.
        if context is not None:
            context.trace("✅", f"from_source({name}, {dates}, {kwargs})")

        # Expand the "$from_dates" placeholders. Use .get so a missing key is not an
        # error, and sort the unique values so the request is deterministic
        # (set iteration order is not).
        if kwargs.get("date") == "$from_dates":
            kwargs["date"] = sorted({d.strftime("%Y%m%d") for d in dates})
        if kwargs.get("time") == "$from_dates":
            kwargs["time"] = sorted({d.strftime("%H%M") for d in dates})

        return from_source(name, **kwargs)
|
|
@@ -10,16 +10,13 @@
|
|
|
10
10
|
import datetime
|
|
11
11
|
from collections import defaultdict
|
|
12
12
|
from typing import Any
|
|
13
|
-
from typing import Dict
|
|
14
|
-
from typing import List
|
|
15
|
-
from typing import Tuple
|
|
16
13
|
|
|
17
14
|
from earthkit.data.core.temporary import temp_file
|
|
18
15
|
from earthkit.data.readers.grib.output import new_grib_output
|
|
19
16
|
|
|
20
|
-
from anemoi.datasets.create.
|
|
17
|
+
from anemoi.datasets.create.sources import source_registry
|
|
21
18
|
|
|
22
|
-
from .legacy import
|
|
19
|
+
from .legacy import LegacySource
|
|
23
20
|
|
|
24
21
|
|
|
25
22
|
def _date_to_datetime(d: Any) -> Any:
|
|
@@ -63,7 +60,7 @@ def normalise_time_delta(t: Any) -> datetime.timedelta:
|
|
|
63
60
|
return t
|
|
64
61
|
|
|
65
62
|
|
|
66
|
-
def group_by_field(ds: Any) ->
|
|
63
|
+
def group_by_field(ds: Any) -> dict[tuple, list[Any]]:
|
|
67
64
|
"""Groups fields by their metadata excluding 'date', 'time', and 'step'.
|
|
68
65
|
|
|
69
66
|
Parameters
|
|
@@ -86,116 +83,89 @@ def group_by_field(ds: Any) -> Dict[Tuple, List[Any]]:
|
|
|
86
83
|
return d
|
|
87
84
|
|
|
88
85
|
|
|
89
|
-
@
|
|
90
|
-
|
|
91
|
-
"""Computes tendencies for the given dates and time increment.
|
|
86
|
+
@source_registry.register("tendencies")
class TendenciesSource(LegacySource):
    """Source returning, for each date, a field minus the same field one time increment earlier."""

    @staticmethod
    def _execute(context: Any, dates: list[datetime.datetime], time_increment: Any, **kwargs: Any) -> Any:
        """Computes tendencies for the given dates and time increment.

        Parameters
        ----------
        context : Any
            The build context. It is not used by this source.
            NOTE(review): ``context`` was added as the first parameter to match the
            ``_execute(context, dates, ...)`` signature of other ``LegacySource``
            subclasses. Confirm against ``LegacySource.execute`` in legacy.py.
        dates : list[datetime.datetime]
            The dates for which tendencies are computed.
        time_increment : Any
            A time increment string ending with 'h' or a datetime.timedelta object.
        **kwargs : Any
            Additional keyword arguments, forwarded to the ``mars`` retrieval.

        Returns
        -------
        Any
            An earthkit-data source reading a temporary GRIB file. For each field at
            each date, it holds ``field(date) - field(date - time_increment)``.

        Raises
        ------
        AssertionError
            If a needed date is missing from the retrieved data, or if the fields at
            the two sets of dates do not line up (keys, counts, metadata, shapes).
        """
        import logging

        log = logging.getLogger(__name__)
        log.debug("tendencies: %s", kwargs)

        time_increment = normalise_time_delta(time_increment)

        # Accept any iterable of dates; the concatenation below needs a list.
        dates = list(dates)
        shifted_dates = [d - time_increment for d in dates]
        all_dates = sorted(set(dates + shifted_dates))

        # NOTE(review): confirm `.mars` still exposes a `mars` callable accepting `dates=` in this release.
        from .mars import mars

        ds = mars(dates=all_dates, **kwargs)

        dates_in_data = ds.unique_values("valid_datetime", progress_bar=False)["valid_datetime"]
        for d in all_dates:
            assert d.isoformat() in dates_in_data, d

        ds1 = ds.sel(valid_datetime=[d.isoformat() for d in dates])
        ds2 = ds.sel(valid_datetime=[d.isoformat() for d in shifted_dates])

        assert len(ds1) == len(ds2), (len(ds1), len(ds2))

        group1 = group_by_field(ds1)
        group2 = group_by_field(ds2)

        assert group1.keys() == group2.keys(), (group1.keys(), group2.keys())

        # Prepare an output temporary file so we can read it back.
        tmp = temp_file()
        path = tmp.path
        out = new_grib_output(path)

        try:
            for k in group1:
                assert len(group1[k]) == len(group2[k]), k
                log.debug("Processing field group %s", k)

                for field, b_field in zip(group1[k], group2[k]):
                    # Use a distinct name so the outer group key `k` is not shadowed.
                    for key in ["param", "level", "number", "grid", "shape"]:
                        assert field.metadata(key) == b_field.metadata(key), (
                            key,
                            field.metadata(key),
                            b_field.metadata(key),
                        )

                    c = field.to_numpy()
                    b = b_field.to_numpy()
                    assert c.shape == b.shape, (c.shape, b.shape)

                    ################
                    # Actual computation happens here
                    x = c - b
                    ################

                    assert x.shape == c.shape, c.shape
                    log.debug("Computing data for %s=%s-%s", field.metadata("valid_datetime"), field, b_field)
                    out.write(x, template=field)
        finally:
            # Always close the GRIB output, even if a consistency check fails.
            out.close()

        from earthkit.data import from_source

        ds = from_source("file", path)
        # Keep a reference to the temporary file on the dataset so it is only
        # deleted once the dataset is no longer used.
        ds._tmp = tmp

        return ds
|
|
@@ -8,8 +8,6 @@
|
|
|
8
8
|
# nor does it submit to any jurisdiction.
|
|
9
9
|
|
|
10
10
|
from typing import Any
|
|
11
|
-
from typing import Dict
|
|
12
|
-
from typing import Optional
|
|
13
11
|
|
|
14
12
|
import earthkit.data as ekd
|
|
15
13
|
|
|
@@ -28,11 +26,11 @@ class XarraySourceBase(Source):
|
|
|
28
26
|
|
|
29
27
|
emoji = "✖️" # For tracing
|
|
30
28
|
|
|
31
|
-
options:
|
|
32
|
-
flavour:
|
|
33
|
-
patch:
|
|
29
|
+
options: dict[str, Any] | None = None
|
|
30
|
+
flavour: dict[str, Any] | None = None
|
|
31
|
+
patch: dict[str, Any] | None = None
|
|
34
32
|
|
|
35
|
-
path_or_url:
|
|
33
|
+
path_or_url: str | None = None
|
|
36
34
|
|
|
37
35
|
def __init__(self, context: Any, path: str = None, url: str = None, *args: Any, **kwargs: Any):
|
|
38
36
|
"""Initialise the source.
|
|
@@ -10,10 +10,6 @@
|
|
|
10
10
|
import datetime
|
|
11
11
|
import logging
|
|
12
12
|
from typing import Any
|
|
13
|
-
from typing import Dict
|
|
14
|
-
from typing import List
|
|
15
|
-
from typing import Optional
|
|
16
|
-
from typing import Union
|
|
17
13
|
|
|
18
14
|
import earthkit.data as ekd
|
|
19
15
|
import xarray as xr
|
|
@@ -21,13 +17,14 @@ from earthkit.data.core.fieldlist import MultiFieldList
|
|
|
21
17
|
|
|
22
18
|
from anemoi.datasets.create.sources.patterns import iterate_patterns
|
|
23
19
|
|
|
24
|
-
from ..
|
|
20
|
+
from .. import source_registry
|
|
21
|
+
from ..legacy import LegacySource
|
|
25
22
|
from .fieldlist import XarrayFieldList
|
|
26
23
|
|
|
27
24
|
LOG = logging.getLogger(__name__)
|
|
28
25
|
|
|
29
26
|
|
|
30
|
-
def check(what: str, ds: xr.Dataset, paths:
|
|
27
|
+
def check(what: str, ds: xr.Dataset, paths: list[str], **kwargs: Any) -> None:
|
|
31
28
|
"""Checks if the dataset has the expected number of fields.
|
|
32
29
|
|
|
33
30
|
Parameters
|
|
@@ -53,12 +50,12 @@ def check(what: str, ds: xr.Dataset, paths: List[str], **kwargs: Any) -> None:
|
|
|
53
50
|
def load_one(
|
|
54
51
|
emoji: str,
|
|
55
52
|
context: Any,
|
|
56
|
-
dates:
|
|
57
|
-
dataset:
|
|
53
|
+
dates: list[str],
|
|
54
|
+
dataset: str | xr.Dataset,
|
|
58
55
|
*,
|
|
59
|
-
options:
|
|
60
|
-
flavour:
|
|
61
|
-
patch:
|
|
56
|
+
options: dict[str, Any] | None = None,
|
|
57
|
+
flavour: str | None = None,
|
|
58
|
+
patch: Any | None = None,
|
|
62
59
|
**kwargs: Any,
|
|
63
60
|
) -> ekd.FieldList:
|
|
64
61
|
"""Loads a single dataset.
|
|
@@ -97,7 +94,10 @@ def load_one(
|
|
|
97
94
|
# If the dataset is a zarr store, we need to use the zarr engine
|
|
98
95
|
options["engine"] = "zarr"
|
|
99
96
|
|
|
100
|
-
|
|
97
|
+
if isinstance(dataset, xr.Dataset):
|
|
98
|
+
data = dataset
|
|
99
|
+
else:
|
|
100
|
+
data = xr.open_dataset(dataset, **options)
|
|
101
101
|
|
|
102
102
|
fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)
|
|
103
103
|
|
|
@@ -124,7 +124,7 @@ def load_one(
|
|
|
124
124
|
return result
|
|
125
125
|
|
|
126
126
|
|
|
127
|
-
def load_many(emoji: str, context: Any, dates:
|
|
127
|
+
def load_many(emoji: str, context: Any, dates: list[datetime.datetime], pattern: str, **kwargs: Any) -> ekd.FieldList:
|
|
128
128
|
"""Loads multiple datasets.
|
|
129
129
|
|
|
130
130
|
Parameters
|
|
@@ -153,26 +153,30 @@ def load_many(emoji: str, context: Any, dates: List[datetime.datetime], pattern:
|
|
|
153
153
|
return MultiFieldList(result)
|
|
154
154
|
|
|
155
155
|
|
|
156
|
-
@
|
|
157
|
-
|
|
158
|
-
""
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
156
|
+
@source_registry.register("xarray")
class LegacyXarraySource(LegacySource):
    """Registered ``xarray`` source that loads the datasets matching a URL pattern."""

    name = "xarray"

    @staticmethod
    def _execute(context: Any, dates: list[datetime.datetime], url: str, *args: Any, **kwargs: Any) -> ekd.FieldList:
        """Executes the loading of datasets.

        Parameters
        ----------
        context : Any
            Context object.
        dates : list[datetime.datetime]
            List of dates, passed to ``load_many``.
        url : str
            URL pattern for loading datasets.
        *args : Any
            Accepted for interface compatibility only. Must be empty.
        **kwargs : Any
            Additional keyword arguments, forwarded to ``load_many``.

        Returns
        -------
        ekd.FieldList
            The loaded datasets.

        Raises
        ------
        TypeError
            If extra positional arguments are given. ``load_many`` only accepts
            ``(emoji, context, dates, pattern, **kwargs)``.
        """
        # load_many has no *args parameter, so forwarding positionals would fail
        # with an obscure TypeError; fail early with a clear message instead.
        if args:
            raise TypeError(f"xarray source: unexpected positional arguments {args!r}; pass options as keywords")
        return load_many("🌐", context, dates, url, **kwargs)
|
|
@@ -13,10 +13,6 @@ from __future__ import annotations
|
|
|
13
13
|
import datetime
|
|
14
14
|
import logging
|
|
15
15
|
from typing import Any
|
|
16
|
-
from typing import Dict
|
|
17
|
-
from typing import Optional
|
|
18
|
-
from typing import Tuple
|
|
19
|
-
from typing import Union
|
|
20
16
|
|
|
21
17
|
import numpy as np
|
|
22
18
|
import xarray as xr
|
|
@@ -107,7 +103,7 @@ class Coordinate:
|
|
|
107
103
|
"""
|
|
108
104
|
self.variable = variable
|
|
109
105
|
self.scalar = is_scalar(variable)
|
|
110
|
-
self.kwargs:
|
|
106
|
+
self.kwargs: dict[str, Any] = {} # Used when creating a new coordinate (reduced method)
|
|
111
107
|
|
|
112
108
|
def __len__(self) -> int:
|
|
113
109
|
"""Get the length of the coordinate.
|
|
@@ -127,7 +123,7 @@ class Coordinate:
|
|
|
127
123
|
str
|
|
128
124
|
The string representation of the coordinate.
|
|
129
125
|
"""
|
|
130
|
-
return "
|
|
126
|
+
return "{}[name={},values={},shape={}]".format(
|
|
131
127
|
self.__class__.__name__,
|
|
132
128
|
self.variable.name,
|
|
133
129
|
self.variable.values if self.scalar else len(self),
|
|
@@ -152,7 +148,7 @@ class Coordinate:
|
|
|
152
148
|
**self.kwargs,
|
|
153
149
|
)
|
|
154
150
|
|
|
155
|
-
def index(self, value:
|
|
151
|
+
def index(self, value: Any | list | tuple) -> int | list | None:
|
|
156
152
|
"""Return the index of the value in the coordinate.
|
|
157
153
|
|
|
158
154
|
Parameters
|
|
@@ -172,7 +168,7 @@ class Coordinate:
|
|
|
172
168
|
return self._index_multiple(value)
|
|
173
169
|
return self._index_single(value)
|
|
174
170
|
|
|
175
|
-
def _index_single(self, value: Any) ->
|
|
171
|
+
def _index_single(self, value: Any) -> int | None:
|
|
176
172
|
"""Return the index of a single value in the coordinate.
|
|
177
173
|
|
|
178
174
|
Parameters
|
|
@@ -205,7 +201,7 @@ class Coordinate:
|
|
|
205
201
|
|
|
206
202
|
return None
|
|
207
203
|
|
|
208
|
-
def _index_multiple(self, value: list) ->
|
|
204
|
+
def _index_multiple(self, value: list) -> list | None:
|
|
209
205
|
"""Return the indices of multiple values in the coordinate.
|
|
210
206
|
|
|
211
207
|
Parameters
|
|
@@ -275,7 +271,7 @@ class TimeCoordinate(Coordinate):
|
|
|
275
271
|
is_time = True
|
|
276
272
|
mars_names = ("valid_datetime",)
|
|
277
273
|
|
|
278
|
-
def index(self, time: datetime.datetime) ->
|
|
274
|
+
def index(self, time: datetime.datetime) -> int | None:
|
|
279
275
|
"""Return the index of the time in the coordinate.
|
|
280
276
|
|
|
281
277
|
Parameters
|
|
@@ -297,7 +293,7 @@ class DateCoordinate(Coordinate):
|
|
|
297
293
|
is_date = True
|
|
298
294
|
mars_names = ("date",)
|
|
299
295
|
|
|
300
|
-
def index(self, date: datetime.datetime) ->
|
|
296
|
+
def index(self, date: datetime.datetime) -> int | None:
|
|
301
297
|
"""Return the index of the date in the coordinate.
|
|
302
298
|
|
|
303
299
|
Parameters
|
|
@@ -436,7 +432,7 @@ class ScalarCoordinate(Coordinate):
|
|
|
436
432
|
is_grid = False
|
|
437
433
|
|
|
438
434
|
@property
def mars_names(self) -> tuple[str, ...]:
    """Return a one-element tuple containing the name of the underlying variable."""
    variable_name = self.variable.name
    return tuple([variable_name])
|
|
442
438
|
|
|
@@ -12,9 +12,6 @@ import datetime
|
|
|
12
12
|
import logging
|
|
13
13
|
from functools import cached_property
|
|
14
14
|
from typing import Any
|
|
15
|
-
from typing import Dict
|
|
16
|
-
from typing import Optional
|
|
17
|
-
from typing import Tuple
|
|
18
15
|
|
|
19
16
|
from earthkit.data import Field
|
|
20
17
|
from earthkit.data.core.fieldlist import math
|
|
@@ -80,12 +77,21 @@ class XArrayField(Field):
|
|
|
80
77
|
# Copy the metadata from the owner
|
|
81
78
|
self._md = owner._metadata.copy()
|
|
82
79
|
|
|
80
|
+
aliases = {}
|
|
83
81
|
for coord_name, coord_value in self.selection.coords.items():
|
|
84
82
|
if is_scalar(coord_value):
|
|
85
83
|
# Extract the single value from the scalar dimension
|
|
86
84
|
# and store it in the metadata
|
|
87
85
|
coordinate = owner.by_name[coord_name]
|
|
88
|
-
|
|
86
|
+
normalised = coordinate.normalise(extract_single_value(coord_value))
|
|
87
|
+
self._md[coord_name] = normalised
|
|
88
|
+
for alias in coordinate.mars_names:
|
|
89
|
+
aliases[alias] = normalised
|
|
90
|
+
|
|
91
|
+
# Add metadata aliases (e.g. levelist == level) only if they are not already present
|
|
92
|
+
for alias, value in aliases.items():
|
|
93
|
+
if alias not in self._md:
|
|
94
|
+
self._md[alias] = value
|
|
89
95
|
|
|
90
96
|
# By now, the only dimensions should be latitude and longitude
|
|
91
97
|
self._shape = tuple(list(self.selection.shape)[-2:])
|
|
@@ -93,13 +99,11 @@ class XArrayField(Field):
|
|
|
93
99
|
raise ValueError(f"Invalid shape for selection {self._shape=}, {self.selection.shape=} {self.selection=}")
|
|
94
100
|
|
|
95
101
|
@property
def shape(self) -> tuple[int, int]:
    """Return the field's shape, as stored in ``self._shape``."""
    stored_shape = self._shape
    return stored_shape
|
|
99
105
|
|
|
100
|
-
def to_numpy(
|
|
101
|
-
self, flatten: bool = False, dtype: Optional[type] = None, index: Optional[int] = None
|
|
102
|
-
) -> NDArray[Any]:
|
|
106
|
+
def to_numpy(self, flatten: bool = False, dtype: type | None = None, index: int | None = None) -> NDArray[Any]:
|
|
103
107
|
"""Convert the selection to a numpy array.
|
|
104
108
|
|
|
105
109
|
Returns
|
|
@@ -117,16 +121,16 @@ class XArrayField(Field):
|
|
|
117
121
|
Index to select a specific element, by default None.
|
|
118
122
|
"""
|
|
119
123
|
if index is not None:
|
|
120
|
-
values = self.selection[index]
|
|
124
|
+
values = self.selection[index].values
|
|
121
125
|
else:
|
|
122
|
-
values = self.selection
|
|
126
|
+
values = self.selection.values
|
|
123
127
|
|
|
124
128
|
assert dtype is None
|
|
125
129
|
|
|
126
130
|
if flatten:
|
|
127
|
-
return values.
|
|
131
|
+
return values.flatten()
|
|
128
132
|
|
|
129
|
-
return values
|
|
133
|
+
return values
|
|
130
134
|
|
|
131
135
|
@cached_property
|
|
132
136
|
def _metadata(self) -> XArrayMetadata:
|
|
@@ -137,7 +141,7 @@ class XArrayField(Field):
|
|
|
137
141
|
"""Return the grid points of the field."""
|
|
138
142
|
return self.owner.grid_points()
|
|
139
143
|
|
|
140
|
-
def to_latlon(self, flatten: bool = True) ->
|
|
144
|
+
def to_latlon(self, flatten: bool = True) -> dict[str, Any]:
|
|
141
145
|
"""Convert the selection to latitude and longitude coordinates.
|
|
142
146
|
|
|
143
147
|
Returns
|
|
@@ -154,7 +158,7 @@ class XArrayField(Field):
|
|
|
154
158
|
return dict(lat=self.latitudes, lon=self.longitudes)
|
|
155
159
|
|
|
156
160
|
@property
|
|
157
|
-
def resolution(self) ->
|
|
161
|
+
def resolution(self) -> Any | None:
|
|
158
162
|
"""Return the resolution of the field."""
|
|
159
163
|
return None
|
|
160
164
|
|
|
@@ -185,9 +189,9 @@ class XArrayField(Field):
|
|
|
185
189
|
|
|
186
190
|
def __repr__(self) -> str:
    """Return ``XArrayField(<metadata>)``, embedding the string form of ``self._metadata``."""
    return "XArrayField({})".format(self._metadata)
|
|
189
193
|
|
|
190
|
-
def _values(self, dtype:
|
|
194
|
+
def _values(self, dtype: type | None = None) -> Any:
|
|
191
195
|
"""Return the values of the selection.
|
|
192
196
|
|
|
193
197
|
Returns
|