anemoi-datasets 0.5.12__py3-none-any.whl → 0.5.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/create.py +4 -2
- anemoi/datasets/create/__init__.py +22 -6
- anemoi/datasets/create/check.py +1 -1
- anemoi/datasets/create/functions/__init__.py +15 -1
- anemoi/datasets/create/functions/filters/orog_to_z.py +58 -0
- anemoi/datasets/create/functions/filters/sum.py +71 -0
- anemoi/datasets/create/functions/filters/wz_to_w.py +79 -0
- anemoi/datasets/create/functions/sources/accumulations.py +7 -2
- anemoi/datasets/create/functions/sources/eccc_fstd.py +16 -0
- anemoi/datasets/create/functions/sources/mars.py +5 -1
- anemoi/datasets/create/functions/sources/xarray/__init__.py +3 -3
- anemoi/datasets/create/functions/sources/xarray/field.py +5 -1
- anemoi/datasets/create/functions/sources/xarray/fieldlist.py +10 -1
- anemoi/datasets/create/functions/sources/xarray/metadata.py +5 -11
- anemoi/datasets/create/functions/sources/xarray/patch.py +44 -0
- anemoi/datasets/create/functions/sources/xarray/time.py +15 -0
- anemoi/datasets/create/functions/sources/xarray/variable.py +18 -2
- anemoi/datasets/create/input/repeated_dates.py +18 -0
- anemoi/datasets/create/input/result.py +1 -1
- anemoi/datasets/create/statistics/__init__.py +7 -4
- anemoi/datasets/create/utils.py +4 -0
- anemoi/datasets/data/complement.py +164 -0
- anemoi/datasets/data/dataset.py +68 -5
- anemoi/datasets/data/ensemble.py +55 -0
- anemoi/datasets/data/join.py +1 -2
- anemoi/datasets/data/merge.py +3 -0
- anemoi/datasets/data/misc.py +34 -1
- anemoi/datasets/grids.py +29 -10
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/METADATA +2 -2
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/RECORD +35 -29
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/WHEEL +1 -1
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.12.dist-info → anemoi_datasets-0.5.14.dist-info}/top_level.txt +0 -0
anemoi/datasets/_version.py
CHANGED
|
@@ -83,11 +83,12 @@ class Create(Command):
|
|
|
83
83
|
task("load", options)
|
|
84
84
|
task("finalise", options)
|
|
85
85
|
|
|
86
|
-
task("patch", options)
|
|
87
|
-
|
|
88
86
|
task("init_additions", options)
|
|
89
87
|
task("run_additions", options)
|
|
90
88
|
task("finalise_additions", options)
|
|
89
|
+
|
|
90
|
+
task("patch", options)
|
|
91
|
+
|
|
91
92
|
task("cleanup", options)
|
|
92
93
|
task("verify", options)
|
|
93
94
|
|
|
@@ -153,6 +154,7 @@ class Create(Command):
|
|
|
153
154
|
|
|
154
155
|
with ExecutorClass(max_workers=1) as executor:
|
|
155
156
|
executor.submit(task, "finalise-additions", options).result()
|
|
157
|
+
executor.submit(task, "patch", options).result()
|
|
156
158
|
executor.submit(task, "cleanup", options).result()
|
|
157
159
|
executor.submit(task, "verify", options).result()
|
|
158
160
|
|
|
@@ -79,7 +79,10 @@ def json_tidy(o):
|
|
|
79
79
|
)
|
|
80
80
|
return o.isoformat()
|
|
81
81
|
|
|
82
|
-
|
|
82
|
+
if isinstance(o, (np.float32, np.float64)):
|
|
83
|
+
return float(o)
|
|
84
|
+
|
|
85
|
+
raise TypeError(f"{repr(o)} is not JSON serializable {type(o)}")
|
|
83
86
|
|
|
84
87
|
|
|
85
88
|
def build_statistics_dates(dates, start, end):
|
|
@@ -596,6 +599,8 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
|
|
|
596
599
|
# There is one cube to load for each result.
|
|
597
600
|
dates = list(result.group_of_dates)
|
|
598
601
|
|
|
602
|
+
LOG.debug(f"Loading cube for {len(dates)} dates")
|
|
603
|
+
|
|
599
604
|
cube = result.get_cube()
|
|
600
605
|
shape = cube.extended_user_shape
|
|
601
606
|
dates_in_data = cube.user_coords["valid_datetime"]
|
|
@@ -622,10 +627,14 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
|
|
|
622
627
|
|
|
623
628
|
check_shape(cube, dates, dates_in_data)
|
|
624
629
|
|
|
625
|
-
def check_dates_in_data(
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
assert
|
|
630
|
+
def check_dates_in_data(dates_in_data, requested_dates):
|
|
631
|
+
requested_dates = [np.datetime64(_) for _ in requested_dates]
|
|
632
|
+
dates_in_data = [np.datetime64(_) for _ in dates_in_data]
|
|
633
|
+
assert dates_in_data == requested_dates, (
|
|
634
|
+
"Dates in data are not the requested ones:",
|
|
635
|
+
dates_in_data,
|
|
636
|
+
requested_dates,
|
|
637
|
+
)
|
|
629
638
|
|
|
630
639
|
check_dates_in_data(dates_in_data, dates)
|
|
631
640
|
|
|
@@ -638,12 +647,14 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
|
|
|
638
647
|
indexes = dates_to_indexes(self.dates, dates_in_data)
|
|
639
648
|
|
|
640
649
|
array = ViewCacheArray(self.data_array, shape=shape, indexes=indexes)
|
|
650
|
+
LOG.info(f"Loading array shape={shape}, indexes={len(indexes)}")
|
|
641
651
|
self.load_cube(cube, array)
|
|
642
652
|
|
|
643
653
|
stats = compute_statistics(array.cache, self.variables_names, allow_nans=self._get_allow_nans())
|
|
644
654
|
self.tmp_statistics.write(indexes, stats, dates=dates_in_data)
|
|
645
|
-
|
|
655
|
+
LOG.info("Flush data array")
|
|
646
656
|
array.flush()
|
|
657
|
+
LOG.info("Flushed data array")
|
|
647
658
|
|
|
648
659
|
def _get_allow_nans(self):
|
|
649
660
|
config = self.main_config
|
|
@@ -736,6 +747,11 @@ class AdditionsMixin:
|
|
|
736
747
|
if not self.delta.total_seconds() % frequency.total_seconds() == 0:
|
|
737
748
|
LOG.debug(f"Delta {self.delta} is not a multiple of frequency {frequency}. Skipping.")
|
|
738
749
|
return True
|
|
750
|
+
|
|
751
|
+
if self.dataset.zarr_metadata.get("build", {}).get("additions", None) is False:
|
|
752
|
+
LOG.warning(f"Additions are disabled for {self.path} in the recipe.")
|
|
753
|
+
return True
|
|
754
|
+
|
|
739
755
|
return False
|
|
740
756
|
|
|
741
757
|
@cached_property
|
anemoi/datasets/create/check.py
CHANGED
|
@@ -58,7 +58,7 @@ class DatasetName:
|
|
|
58
58
|
raise ValueError(self.error_message)
|
|
59
59
|
|
|
60
60
|
def _parse(self, name):
|
|
61
|
-
pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?([a-zA-Z0-9-]+)?$"
|
|
61
|
+
pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h|\d+m)-v(\d+)-?([a-zA-Z0-9-]+)?$"
|
|
62
62
|
match = re.match(pattern, name)
|
|
63
63
|
|
|
64
64
|
if not match:
|
|
@@ -22,6 +22,7 @@ def assert_is_fieldlist(obj):
|
|
|
22
22
|
def import_function(name, kind):
|
|
23
23
|
|
|
24
24
|
from anemoi.transform.filters import filter_registry
|
|
25
|
+
from anemoi.transform.sources import source_registry
|
|
25
26
|
|
|
26
27
|
name = name.replace("-", "_")
|
|
27
28
|
|
|
@@ -45,7 +46,20 @@ def import_function(name, kind):
|
|
|
45
46
|
if filter_registry.lookup(name, return_none=True):
|
|
46
47
|
|
|
47
48
|
def proc(context, data, *args, **kwargs):
|
|
48
|
-
|
|
49
|
+
filter = filter_registry.create(name, *args, **kwargs)
|
|
50
|
+
filter.context = context
|
|
51
|
+
# filter = filter_registry.create(context, name, *args, **kwargs)
|
|
52
|
+
return filter.forward(data)
|
|
53
|
+
|
|
54
|
+
return proc
|
|
55
|
+
|
|
56
|
+
if kind == "sources":
|
|
57
|
+
if source_registry.lookup(name, return_none=True):
|
|
58
|
+
|
|
59
|
+
def proc(context, data, *args, **kwargs):
|
|
60
|
+
source = source_registry.create(name, *args, **kwargs)
|
|
61
|
+
# source = source_registry.create(context, name, *args, **kwargs)
|
|
62
|
+
return source.forward(data)
|
|
49
63
|
|
|
50
64
|
return proc
|
|
51
65
|
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# (C) Copyright 2024 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
|
|
13
|
+
from earthkit.data.indexing.fieldlist import FieldArray
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class NewDataField:
    """Wrap a field, replacing its data values and its ``param`` name.

    Any attribute not defined on this wrapper is forwarded to the wrapped
    field.
    """

    def __init__(self, field, data, new_name):
        """Store the wrapped ``field``, the replacement ``data`` and the new ``param`` name."""
        self.field = field
        self.data = data
        self.new_name = new_name

    def to_numpy(self, *args, **kwargs):
        """Return the replacement data; all arguments are accepted and ignored."""
        return self.data

    def metadata(self, key=None, **kwargs):
        """Return metadata of the wrapped field, with ``param`` overridden.

        Without ``key`` the call is forwarded unchanged to the wrapped field.
        With ``key == "param"`` the new name is returned; any other key is
        looked up on the wrapped field.
        """
        if key is None:
            # NOTE(review): whatever the wrapped field returns here is passed
            # through as-is, so "param" is not renamed in this form.
            return self.field.metadata(**kwargs)

        if key == "param":
            # Answer directly: no need to query the wrapped field for a value
            # that is discarded anyway.
            return self.new_name

        return self.field.metadata(key, **kwargs)

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped field.

        Raises AttributeError when the wrapper has no ``field`` yet (for
        example on a bare instance created while copying or unpickling)
        instead of recursing until RecursionError.
        """
        try:
            wrapped = self.__dict__["field"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def execute(context, input, orog, z="z"):
    """Convert orography [m] to geopotential ``z`` by multiplying by 9.80665.

    Each field of ``input`` whose MARS ``param`` equals ``orog`` is replaced
    in the output by a ``NewDataField`` named ``z`` holding the scaled
    values. Every other field is passed through unchanged. ``context`` is
    not used.

    Raises ValueError when two ``orog`` fields share the same remaining MARS
    metadata.
    """
    gravity = 9.80665  # multiplier applied to the orography values

    result = FieldArray()
    seen = set()

    for f in input:
        key = f.metadata(namespace="mars")
        param = key.pop("param")

        if param != orog:
            result.append(f)
            continue

        key = tuple(key.items())
        if key in seen:
            raise ValueError(f"Duplicate field {param} for {key}")
        # Record the key: without this the duplicate check can never fire.
        seen.add(key)

        output = f.to_numpy(flatten=True) * gravity
        result.append(NewDataField(f, output, z))

    return result
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
# (C) Copyright 2024 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
|
|
13
|
+
from earthkit.data.indexing.fieldlist import FieldArray
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class NewDataField:
    """Wrap a field, replacing its data values and its ``param`` name.

    Any attribute not defined on this wrapper is forwarded to the wrapped
    field.
    """

    def __init__(self, field, data, new_name):
        """Store the wrapped ``field``, the replacement ``data`` and the new ``param`` name."""
        self.field = field
        self.data = data
        self.new_name = new_name

    def to_numpy(self, *args, **kwargs):
        """Return the replacement data; all arguments are accepted and ignored."""
        return self.data

    def metadata(self, key=None, **kwargs):
        """Return metadata of the wrapped field, with ``param`` overridden.

        Without ``key`` the call is forwarded unchanged to the wrapped field.
        With ``key == "param"`` the new name is returned; any other key is
        looked up on the wrapped field.
        """
        if key is None:
            # NOTE(review): whatever the wrapped field returns here is passed
            # through as-is, so "param" is not renamed in this form.
            return self.field.metadata(**kwargs)

        if key == "param":
            # Answer directly: no need to query the wrapped field for a value
            # that is discarded anyway.
            return self.new_name

        return self.field.metadata(key, **kwargs)

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped field.

        Raises AttributeError when the wrapper has no ``field`` yet (for
        example on a bare instance created while copying or unpickling)
        instead of recursing until RecursionError.
        """
        try:
            wrapped = self.__dict__["field"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def execute(context, input, params, output):
    """Replace the fields listed in ``params`` by their sum, named ``output``.

    Fields of ``input`` whose MARS ``param`` is in ``params`` are grouped by
    their remaining MARS metadata. For each group one ``NewDataField`` named
    ``output`` is appended, carrying the element-wise sum of the group's
    values and wrapping the first field of the group. All other fields are
    passed through unchanged. ``context`` is not used.

    Raises ValueError when a group contains the same ``param`` twice or
    lacks one of the requested ``params``.
    """
    result = FieldArray()
    needed_fields = defaultdict(dict)

    for f in input:
        key = f.metadata(namespace="mars")
        param = key.pop("param")

        if param not in params:
            result.append(f)
            continue

        key = tuple(key.items())
        if param in needed_fields[key]:
            raise ValueError(f"Duplicate field {param} for {key}")
        needed_fields[key][param] = f

    for keys, values in needed_fields.items():
        # Name what is missing instead of only comparing lengths; this also
        # tolerates a ``params`` list that repeats a name.
        missing = [p for p in params if p not in values]
        if missing:
            raise ValueError(f"Missing fields: {missing} for {keys}")

        total = None
        for field in values.values():
            current = field.to_numpy(flatten=True)
            # ``total + current`` builds a new array. The previous ``+=``
            # wrote into the array returned by the first field, altering that
            # field's values if the array is shared or cached.
            total = current if total is None else total + current

        first = next(iter(values.values()))
        result.append(NewDataField(first, total, output))

    return result
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# (C) Copyright 2024 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
|
|
13
|
+
from earthkit.data.indexing.fieldlist import FieldArray
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class NewDataField:
    """Wrap a field, replacing its data values and its ``param`` name.

    Any attribute not defined on this wrapper is forwarded to the wrapped
    field.
    """

    def __init__(self, field, data, new_name):
        """Store the wrapped ``field``, the replacement ``data`` and the new ``param`` name."""
        self.field = field
        self.data = data
        self.new_name = new_name

    def to_numpy(self, *args, **kwargs):
        """Return the replacement data; all arguments are accepted and ignored."""
        return self.data

    def metadata(self, key=None, **kwargs):
        """Return metadata of the wrapped field, with ``param`` overridden.

        Without ``key`` the call is forwarded unchanged to the wrapped field.
        With ``key == "param"`` the new name is returned; any other key is
        looked up on the wrapped field.
        """
        if key is None:
            # NOTE(review): whatever the wrapped field returns here is passed
            # through as-is, so "param" is not renamed in this form.
            return self.field.metadata(**kwargs)

        if key == "param":
            # Answer directly: no need to query the wrapped field for a value
            # that is discarded anyway.
            return self.new_name

        return self.field.metadata(key, **kwargs)

    def __getattr__(self, name):
        """Forward unknown attributes to the wrapped field.

        Raises AttributeError when the wrapper has no ``field`` yet (for
        example on a bare instance created while copying or unpickling)
        instead of recursing until RecursionError.
        """
        try:
            wrapped = self.__dict__["field"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _pressure_level(keys):
    """Return the level value stored in a grouping key.

    ``keys`` is the tuple of ``(name, value)`` pairs built from a field's
    MARS metadata. The level is looked up by name; if no known name is
    present, the historical fixed position (index 4) is used.
    """
    named = dict(keys)
    # NOTE(review): "levelist" is the usual MARS key for the level; "level"
    # is accepted as an assumed alias -- confirm against the sources in use.
    for name in ("levelist", "level"):
        if name in named:
            return named[name]
    return keys[4][1]


def execute(context, input, wz, t, w="w"):
    """Convert geometric vertical velocity (m/s) to vertical velocity (Pa / s).

    Fields of ``input`` whose MARS ``param`` is ``wz`` or ``t`` are paired by
    their remaining MARS metadata. For each pair a ``NewDataField`` named
    ``w`` is appended, computed by ``wz_to_w`` from the two fields and the
    pair's level multiplied by 100. The ``t`` fields and all unrelated
    fields are passed through; the ``wz`` fields are not. ``context`` is not
    used.

    Raises ValueError when a group contains the same ``param`` twice or does
    not contain both ``wz`` and ``t``.
    """
    result = FieldArray()

    params = (wz, t)
    pairs = defaultdict(dict)

    for f in input:
        key = f.metadata(namespace="mars")
        param = key.pop("param")

        if param not in params:
            result.append(f)
            continue

        key = tuple(key.items())
        if param in pairs[key]:
            raise ValueError(f"Duplicate field {param} for {key}")
        pairs[key][param] = f

        # Only the temperature is kept in the output; wz is replaced by w.
        if param == t:
            result.append(f)

    for keys, values in pairs.items():
        missing = [p for p in params if p not in values]
        if missing:
            raise ValueError(f"Missing fields: {missing} for {keys}")

        wz_pl = values[wz].to_numpy(flatten=True)
        t_pl = values[t].to_numpy(flatten=True)
        # Level * 100: presumably hPa converted to Pa -- TODO confirm units.
        pressure = _pressure_level(keys) * 100

        w_pl = wz_to_w(wz_pl, t_pl, pressure)
        result.append(NewDataField(values[wz], w_pl, w))

    return result
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def wz_to_w(wz, t, pressure):
    """Return ``-wz * g * pressure / (t * Rd)`` with g = 9.81 and Rd = 287.058.

    The caller in this file passes the geometric vertical velocity as ``wz``,
    the temperature as ``t`` and the level multiplied by 100 as ``pressure``.
    """
    gravity = 9.81
    gas_constant = 287.058

    denominator = t * gas_constant
    return -wz * gravity * pressure / denominator
|
|
@@ -253,6 +253,7 @@ def _compute_accumulations(
|
|
|
253
253
|
data_accumulation_period=None,
|
|
254
254
|
patch=_identity,
|
|
255
255
|
base_times=None,
|
|
256
|
+
use_cdsapi_dataset=None,
|
|
256
257
|
):
|
|
257
258
|
adjust_step = isinstance(user_accumulation_period, int)
|
|
258
259
|
|
|
@@ -311,7 +312,9 @@ def _compute_accumulations(
|
|
|
311
312
|
|
|
312
313
|
requests.append(patch(r))
|
|
313
314
|
|
|
314
|
-
ds = mars(
|
|
315
|
+
ds = mars(
|
|
316
|
+
context, dates, *requests, request_already_using_valid_datetime=True, use_cdsapi_dataset=use_cdsapi_dataset
|
|
317
|
+
)
|
|
315
318
|
|
|
316
319
|
accumulations = {}
|
|
317
320
|
for a in [AccumulationClass(out, frequency=frequency, **r) for r in requests]:
|
|
@@ -366,7 +369,7 @@ def _scda(request):
|
|
|
366
369
|
return request
|
|
367
370
|
|
|
368
371
|
|
|
369
|
-
def accumulations(context, dates, **request):
|
|
372
|
+
def accumulations(context, dates, use_cdsapi_dataset=None, **request):
|
|
370
373
|
_to_list(request["param"])
|
|
371
374
|
class_ = request.get("class", "od")
|
|
372
375
|
stream = request.get("stream", "oper")
|
|
@@ -379,6 +382,7 @@ def accumulations(context, dates, **request):
|
|
|
379
382
|
KWARGS = {
|
|
380
383
|
("od", "oper"): dict(patch=_scda),
|
|
381
384
|
("od", "elda"): dict(base_times=(6, 18)),
|
|
385
|
+
("od", "enfo"): dict(base_times=(0, 6, 12, 18)),
|
|
382
386
|
("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)),
|
|
383
387
|
("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)),
|
|
384
388
|
("rr", "oper"): dict(base_times=(0, 3, 6, 9, 12, 15, 18, 21)),
|
|
@@ -394,6 +398,7 @@ def accumulations(context, dates, **request):
|
|
|
394
398
|
dates,
|
|
395
399
|
request,
|
|
396
400
|
user_accumulation_period=user_accumulation_period,
|
|
401
|
+
use_cdsapi_dataset=use_cdsapi_dataset,
|
|
397
402
|
**kwargs,
|
|
398
403
|
)
|
|
399
404
|
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# (C) Copyright 2025 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
from .xarray import load_many
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def execute(context, dates, path, *args, **kwargs):
    """Load ``path`` through ``load_many`` with ``options={"engine": "fstd"}``.

    ``context``, ``dates``, ``path`` and any extra positional or keyword
    arguments are forwarded unchanged; the result of ``load_many`` is
    returned.
    """
    return load_many("🍁", context, dates, path, *args, options={"engine": "fstd"}, **kwargs)
|
|
@@ -246,6 +246,7 @@ def mars(
|
|
|
246
246
|
*requests,
|
|
247
247
|
request_already_using_valid_datetime=False,
|
|
248
248
|
date_key="date",
|
|
249
|
+
use_cdsapi_dataset=None,
|
|
249
250
|
**kwargs,
|
|
250
251
|
):
|
|
251
252
|
|
|
@@ -305,7 +306,10 @@ def mars(
|
|
|
305
306
|
f"⚠️ Unknown key {k}={v} in MARS request. Did you mean '{did_you_mean(k, MARS_KEYS)}' ?"
|
|
306
307
|
)
|
|
307
308
|
try:
|
|
308
|
-
|
|
309
|
+
if use_cdsapi_dataset:
|
|
310
|
+
ds = ds + from_source("cds", use_cdsapi_dataset, r)
|
|
311
|
+
else:
|
|
312
|
+
ds = ds + from_source("mars", **r)
|
|
309
313
|
except Exception as e:
|
|
310
314
|
if "File is empty:" not in str(e):
|
|
311
315
|
raise
|
|
@@ -29,7 +29,7 @@ def check(what, ds, paths, **kwargs):
|
|
|
29
29
|
raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})")
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs):
|
|
32
|
+
def load_one(emoji, context, dates, dataset, *, options={}, flavour=None, patch=None, **kwargs):
|
|
33
33
|
import xarray as xr
|
|
34
34
|
|
|
35
35
|
"""
|
|
@@ -54,10 +54,10 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs)
|
|
|
54
54
|
else:
|
|
55
55
|
data = xr.open_dataset(dataset, **options)
|
|
56
56
|
|
|
57
|
-
fs = XarrayFieldList.from_xarray(data, flavour)
|
|
57
|
+
fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)
|
|
58
58
|
|
|
59
59
|
if len(dates) == 0:
|
|
60
|
-
|
|
60
|
+
result = fs.sel(**kwargs)
|
|
61
61
|
else:
|
|
62
62
|
result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])
|
|
63
63
|
|
|
@@ -92,6 +92,10 @@ class XArrayField(Field):
|
|
|
92
92
|
def grid_points(self):
|
|
93
93
|
return self.owner.grid_points()
|
|
94
94
|
|
|
95
|
+
def to_latlon(self, flatten=True):
|
|
96
|
+
assert flatten
|
|
97
|
+
return dict(lat=self.latitudes, lon=self.longitudes)
|
|
98
|
+
|
|
95
99
|
@property
|
|
96
100
|
def resolution(self):
|
|
97
101
|
return None
|
|
@@ -120,6 +124,6 @@ class XArrayField(Field):
|
|
|
120
124
|
def __repr__(self):
|
|
121
125
|
return repr(self._metadata)
|
|
122
126
|
|
|
123
|
-
def _values(self):
|
|
127
|
+
def _values(self, dtype=None):
|
|
124
128
|
# we don't use .values as this will download the data
|
|
125
129
|
return self.selection
|
|
@@ -16,6 +16,7 @@ from earthkit.data.core.fieldlist import FieldList
|
|
|
16
16
|
|
|
17
17
|
from .field import EmptyFieldList
|
|
18
18
|
from .flavour import CoordinateGuesser
|
|
19
|
+
from .patch import patch_dataset
|
|
19
20
|
from .time import Time
|
|
20
21
|
from .variable import FilteredVariable
|
|
21
22
|
from .variable import Variable
|
|
@@ -49,7 +50,11 @@ class XarrayFieldList(FieldList):
|
|
|
49
50
|
raise IndexError(k)
|
|
50
51
|
|
|
51
52
|
@classmethod
|
|
52
|
-
def from_xarray(cls, ds, flavour=None):
|
|
53
|
+
def from_xarray(cls, ds, *, flavour=None, patch=None):
|
|
54
|
+
|
|
55
|
+
if patch is not None:
|
|
56
|
+
ds = patch_dataset(ds, patch)
|
|
57
|
+
|
|
53
58
|
variables = []
|
|
54
59
|
|
|
55
60
|
if isinstance(flavour, str):
|
|
@@ -83,6 +88,8 @@ class XarrayFieldList(FieldList):
|
|
|
83
88
|
_skip_attr(variable, "bounds")
|
|
84
89
|
_skip_attr(variable, "grid_mapping")
|
|
85
90
|
|
|
91
|
+
LOG.debug("Xarray data_vars: %s", ds.data_vars)
|
|
92
|
+
|
|
86
93
|
# Select only geographical variables
|
|
87
94
|
for name in ds.data_vars:
|
|
88
95
|
|
|
@@ -97,6 +104,7 @@ class XarrayFieldList(FieldList):
|
|
|
97
104
|
c = guess.guess(ds[coord], coord)
|
|
98
105
|
assert c, f"Could not guess coordinate for {coord}"
|
|
99
106
|
if coord not in variable.dims:
|
|
107
|
+
LOG.debug("%s: coord=%s (not a dimension): dims=%s", variable, coord, variable.dims)
|
|
100
108
|
c.is_dim = False
|
|
101
109
|
coordinates.append(c)
|
|
102
110
|
|
|
@@ -104,6 +112,7 @@ class XarrayFieldList(FieldList):
|
|
|
104
112
|
assert grid_coords <= 2
|
|
105
113
|
|
|
106
114
|
if grid_coords < 2:
|
|
115
|
+
LOG.debug("Skipping %s (not 2D): %s", variable, [(c, c.is_grid, c.is_dim) for c in coordinates])
|
|
107
116
|
continue
|
|
108
117
|
|
|
109
118
|
v = Variable(
|
|
@@ -24,6 +24,7 @@ class _MDMapping:
|
|
|
24
24
|
def __init__(self, variable):
|
|
25
25
|
self.variable = variable
|
|
26
26
|
self.time = variable.time
|
|
27
|
+
# Aliases
|
|
27
28
|
self.mapping = dict(param="variable")
|
|
28
29
|
for c in variable.coordinates:
|
|
29
30
|
for v in c.mars_names:
|
|
@@ -34,7 +35,6 @@ class _MDMapping:
|
|
|
34
35
|
return self.mapping.get(key, key)
|
|
35
36
|
|
|
36
37
|
def from_user(self, kwargs):
|
|
37
|
-
print("from_user", kwargs, self)
|
|
38
38
|
return {self._from_user(k): v for k, v in kwargs.items()}
|
|
39
39
|
|
|
40
40
|
def __repr__(self):
|
|
@@ -81,22 +81,16 @@ class XArrayMetadata(RawMetadata):
|
|
|
81
81
|
def _valid_datetime(self):
|
|
82
82
|
return self._get("valid_datetime")
|
|
83
83
|
|
|
84
|
-
def
|
|
84
|
+
def get(self, key, astype=None, **kwargs):
|
|
85
85
|
|
|
86
86
|
if key in self._d:
|
|
87
|
+
if astype is not None:
|
|
88
|
+
return astype(self._d[key])
|
|
87
89
|
return self._d[key]
|
|
88
90
|
|
|
89
|
-
if key.startswith("mars."):
|
|
90
|
-
key = key[5:]
|
|
91
|
-
if key not in self.MARS_KEYS:
|
|
92
|
-
if kwargs.get("raise_on_missing", False):
|
|
93
|
-
raise KeyError(f"Invalid key '{key}' in namespace='mars'")
|
|
94
|
-
else:
|
|
95
|
-
return kwargs.get("default", None)
|
|
96
|
-
|
|
97
91
|
key = self._mapping._from_user(key)
|
|
98
92
|
|
|
99
|
-
return super().
|
|
93
|
+
return super().get(key, astype=astype, **kwargs)
|
|
100
94
|
|
|
101
95
|
|
|
102
96
|
class XArrayFieldGeography(Geography):
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# (C) Copyright 2024 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
LOG = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def patch_attributes(ds, attributes):
    """Update the ``attrs`` of selected variables of ``ds`` in place.

    ``attributes`` maps a variable name to a mapping of attribute values,
    which is merged into ``ds[name].attrs``. Returns ``ds`` itself.
    """
    for target in attributes:
        ds[target].attrs.update(attributes[target])

    return ds
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def patch_coordinates(ds, coordinates):
    """Assign each variable named in ``coordinates`` as a coordinate of ``ds``.

    For every name, ``assign_coords`` is called with ``{name: ds[name]}`` and
    its result is used for the next step. Returns the final result (``ds``
    itself when ``coordinates`` is empty).
    """
    patched = ds
    for coord_name in coordinates:
        promoted = {coord_name: patched[coord_name]}
        patched = patched.assign_coords(promoted)

    return patched
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Patch handlers, keyed by the patch type accepted by ``patch_dataset``.
PATCHES = dict(
    attributes=patch_attributes,
    coordinates=patch_coordinates,
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def patch_dataset(ds, patch):
    """Apply every entry of ``patch`` to ``ds`` and return the result.

    ``patch`` maps a patch type (a key of ``PATCHES``) to the values handed
    to the matching handler. Handlers run in the iteration order of
    ``patch``, each receiving the output of the previous one.

    Raises ValueError for a patch type that is not in ``PATCHES``.
    """
    for what in patch:
        handler = PATCHES.get(what)
        if handler is None:
            raise ValueError(f"Unknown patch type {what!r}")

        ds = handler(ds, patch[what])

    return ds
|
|
@@ -62,12 +62,18 @@ class Time:
|
|
|
62
62
|
|
|
63
63
|
raise NotImplementedError(f"{len(date_coordinate)=} {len(time_coordinate)=} {len(step_coordinate)=}")
|
|
64
64
|
|
|
65
|
+
def select_valid_datetime(self, variable):
|
|
66
|
+
raise NotImplementedError(f"{self.__class__.__name__}.select_valid_datetime()")
|
|
67
|
+
|
|
65
68
|
|
|
66
69
|
class Constant(Time):
|
|
67
70
|
|
|
68
71
|
def fill_time_metadata(self, coords_values, metadata):
|
|
69
72
|
return None
|
|
70
73
|
|
|
74
|
+
def select_valid_datetime(self, variable):
|
|
75
|
+
return None
|
|
76
|
+
|
|
71
77
|
|
|
72
78
|
class Analysis(Time):
|
|
73
79
|
|
|
@@ -83,6 +89,9 @@ class Analysis(Time):
|
|
|
83
89
|
|
|
84
90
|
return valid_datetime
|
|
85
91
|
|
|
92
|
+
def select_valid_datetime(self, variable):
|
|
93
|
+
return self.time_coordinate_name
|
|
94
|
+
|
|
86
95
|
|
|
87
96
|
class ForecastFromValidTimeAndStep(Time):
|
|
88
97
|
|
|
@@ -116,6 +125,9 @@ class ForecastFromValidTimeAndStep(Time):
|
|
|
116
125
|
|
|
117
126
|
return valid_datetime
|
|
118
127
|
|
|
128
|
+
def select_valid_datetime(self, variable):
|
|
129
|
+
return self.time_coordinate_name
|
|
130
|
+
|
|
119
131
|
|
|
120
132
|
class ForecastFromValidTimeAndBaseTime(Time):
|
|
121
133
|
|
|
@@ -138,6 +150,9 @@ class ForecastFromValidTimeAndBaseTime(Time):
|
|
|
138
150
|
|
|
139
151
|
return valid_datetime
|
|
140
152
|
|
|
153
|
+
def select_valid_datetime(self, variable):
|
|
154
|
+
return self.time_coordinate_name
|
|
155
|
+
|
|
141
156
|
|
|
142
157
|
class ForecastFromBaseTimeAndDate(Time):
|
|
143
158
|
|