anemoi-datasets 0.5.25__py3-none-any.whl → 0.5.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/grib-index.py +1 -1
- anemoi/datasets/create/filter.py +22 -24
- anemoi/datasets/create/input/step.py +2 -16
- anemoi/datasets/create/sources/planetary_computer.py +44 -0
- anemoi/datasets/create/sources/xarray_support/__init__.py +6 -22
- anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -0
- anemoi/datasets/create/sources/xarray_support/field.py +1 -4
- anemoi/datasets/create/sources/xarray_support/flavour.py +44 -6
- anemoi/datasets/create/sources/xarray_support/patch.py +44 -1
- anemoi/datasets/create/sources/xarray_support/variable.py +6 -2
- anemoi/datasets/data/complement.py +44 -10
- anemoi/datasets/data/forwards.py +8 -2
- anemoi/datasets/data/stores.py +7 -56
- anemoi/datasets/grids.py +6 -3
- {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/METADATA +3 -2
- {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/RECORD +21 -40
- anemoi/datasets/create/filters/__init__.py +0 -33
- anemoi/datasets/create/filters/empty.py +0 -37
- anemoi/datasets/create/filters/legacy.py +0 -93
- anemoi/datasets/create/filters/noop.py +0 -37
- anemoi/datasets/create/filters/orog_to_z.py +0 -58
- anemoi/datasets/create/filters/pressure_level_relative_humidity_to_specific_humidity.py +0 -83
- anemoi/datasets/create/filters/pressure_level_specific_humidity_to_relative_humidity.py +0 -84
- anemoi/datasets/create/filters/rename.py +0 -205
- anemoi/datasets/create/filters/rotate_winds.py +0 -105
- anemoi/datasets/create/filters/single_level_dewpoint_to_relative_humidity.py +0 -78
- anemoi/datasets/create/filters/single_level_relative_humidity_to_dewpoint.py +0 -84
- anemoi/datasets/create/filters/single_level_relative_humidity_to_specific_humidity.py +0 -163
- anemoi/datasets/create/filters/single_level_specific_humidity_to_relative_humidity.py +0 -451
- anemoi/datasets/create/filters/speeddir_to_uv.py +0 -95
- anemoi/datasets/create/filters/sum.py +0 -68
- anemoi/datasets/create/filters/transform.py +0 -51
- anemoi/datasets/create/filters/unrotate_winds.py +0 -105
- anemoi/datasets/create/filters/uv_to_speeddir.py +0 -94
- anemoi/datasets/create/filters/wz_to_w.py +0 -98
- anemoi/datasets/create/testing.py +0 -76
- {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/WHEEL +0 -0
- {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/licenses/LICENSE +0 -0
- {anemoi_datasets-0.5.25.dist-info → anemoi_datasets-0.5.26.dist-info}/top_level.txt +0 -0
anemoi/datasets/_version.py
CHANGED

anemoi/datasets/commands/grib-index.py
CHANGED

@@ -81,7 +81,7 @@ class GribIndexCmd(Command):
         bool
             True if the path matches, False otherwise.
         """
-        return fnmatch.fnmatch(path, args.match)
+        return fnmatch.fnmatch(os.path.basename(path), args.match)

        from anemoi.datasets.create.sources.grib_index import GribIndex
anemoi/datasets/create/filter.py
CHANGED

@@ -7,44 +7,42 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.

-from abc import ABC
-from abc import abstractmethod
 from typing import Any
+from typing import Dict

 import earthkit.data as ekd


-class
-    """
+class TransformFilter:
+    """Calls filters from anemoi.transform.filters

-
-
+    Parameters
+    ----------
+    context : Any
+        The context in which the filter is created.
+    name : str
+        The name of the filter.
+    config : Dict[str, Any]
+        The configuration for the filter.
+    """

-
-
-    context : Any
-        The context in which the filter is created.
-    *args : tuple
-        Positional arguments.
-    **kwargs : dict
-        Keyword arguments.
-    """
+    def __init__(self, context: Any, name: str, config: Dict[str, Any]) -> None:
+        from anemoi.transform.filters import create_filter

-    self.
+        self.name = name
+        self.transform_filter = create_filter(context, config)

-
-
-    """Execute the filter.
+    def execute(self, input: ekd.FieldList) -> ekd.FieldList:
+        """Execute the transformation filter.

         Parameters
         ----------
-
-            The input data.
+        input : ekd.FieldList
+            The input data to be transformed.

         Returns
         -------
         ekd.FieldList
-            The
+            The transformed data.
         """
-
-        pass
+        return self.transform_filter.forward(input)
anemoi/datasets/create/input/step.py
CHANGED

@@ -8,7 +8,6 @@
 # nor does it submit to any jurisdiction.

 import logging
-import warnings
 from copy import deepcopy
 from typing import Any
 from typing import Dict

@@ -165,24 +164,11 @@ def step_factory(config: Dict[str, Any], context: ActionContext, action_path: Li
     if cls is not None:
         return cls(context, action_path, previous_step, *args, **kwargs)

-    # Try filters from
+    # Try filters from transform filter registry
     from anemoi.transform.filters import filter_registry as transform_filter_registry

-    from ..filters import create_filter as create_datasets_filter
-    from ..filters import filter_registry as datasets_filter_registry
-
-    if datasets_filter_registry.is_registered(key):
-
-        if transform_filter_registry.is_registered(key):
-            warnings.warn(f"Filter `{key}` is registered in both datasets and transform filter registries")
-
-        filter = create_datasets_filter(None, config)
-        return FunctionStepAction(context, action_path + [key], previous_step, key, filter)
-
-    # Use filters from transform registry
-
     if transform_filter_registry.is_registered(key):
-        from ..
+        from ..filter import TransformFilter

         return FunctionStepAction(
             context, action_path + [key], previous_step, key, TransformFilter(context, key, config)
anemoi/datasets/create/sources/planetary_computer.py
ADDED

@@ -0,0 +1,44 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+from . import source_registry
+from .xarray import XarraySourceBase
+
+
+@source_registry.register("planetary_computer")
+class PlanetaryComputerSource(XarraySourceBase):
+    """An Xarray data source for the planetary_computer."""
+
+    emoji = "🪐"
+
+    def __init__(self, context, data_catalog_id, version="v1", *args, **kwargs: dict):
+
+        import planetary_computer
+        import pystac_client
+
+        self.data_catalog_id = data_catalog_id
+        self.flavour = kwargs.pop("flavour", None)
+        self.patch = kwargs.pop("patch", None)
+        self.options = kwargs.pop("options", {})
+
+        catalog = pystac_client.Client.open(
+            f"https://planetarycomputer.microsoft.com/api/stac/{version}/",
+            modifier=planetary_computer.sign_inplace,
+        )
+        collection = catalog.get_collection(self.data_catalog_id)
+
+        asset = collection.assets["zarr-abfs"]
+
+        if "xarray:storage_options" in asset.extra_fields:
+            self.options["storage_options"] = asset.extra_fields["xarray:storage_options"]
+
+        self.options.update(asset.extra_fields["xarray:open_kwargs"])
+
+        super().__init__(context, url=asset.href, *args, **kwargs)
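This new source supersedes the PlanetaryComputerStore that is removed from anemoi/datasets/data/stores.py later in this diff. For orientation, a minimal standalone sketch of the same access pattern, written directly against pystac-client and xarray rather than through anemoi-datasets; the collection id "daymet-daily-na" and the consolidated=True flag are illustrative assumptions, and the sketch needs pystac-client, planetary-computer, xarray, zarr and adlfs installed.

import planetary_computer
import pystac_client
import xarray as xr

# Open the Planetary Computer STAC API and sign asset URLs in place.
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1/",
    modifier=planetary_computer.sign_inplace,
)

# "daymet-daily-na" is a made-up example id; the source takes it as data_catalog_id.
asset = catalog.get_collection("daymet-daily-na").assets["zarr-abfs"]

ds = xr.open_zarr(
    asset.href,
    storage_options=asset.extra_fields.get("xarray:storage_options", {}),
    consolidated=True,  # assumption; the real source forwards the asset's xarray:open_kwargs
)
print(ds)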
anemoi/datasets/create/sources/xarray_support/__init__.py
CHANGED

@@ -20,7 +20,6 @@ import xarray as xr
 from earthkit.data.core.fieldlist import MultiFieldList

 from anemoi.datasets.create.sources.patterns import iterate_patterns
-from anemoi.datasets.data.stores import name_to_zarr_store

 from ..legacy import legacy_source
 from .fieldlist import XarrayFieldList

@@ -89,37 +88,22 @@ def load_one(
         The loaded dataset.
     """

-    """
-    We manage the S3 client ourselves, bypassing fsspec and s3fs layers, because sometimes something on the stack
-    zarr/fsspec/s3fs/boto3 (?) seem to flags files as missing when they actually are not (maybe when S3 reports some sort of
-    connection error). In that case, Zarr will silently fill the chunks that could not be downloaded with NaNs.
-    See https://github.com/pydata/xarray/issues/8842
-
-    We have seen this bug triggered when we run many clients in parallel, for example, when we create a new dataset using `xarray-zarr`.
-    """
-
    if options is None:
        options = {}

    context.trace(emoji, dataset, options, kwargs)

-    if isinstance(dataset, str) and ".zarr"
-
-
-
-
-        data = xr.open_zarr(**store)
-    if "filename_or_obj" in store:
-        data = xr.open_dataset(**store)
-    else:
-        data = xr.open_dataset(dataset, **options)
+    if isinstance(dataset, str) and dataset.endswith(".zarr"):
+        # If the dataset is a zarr store, we need to use the zarr engine
+        options["engine"] = "zarr"
+
+    data = xr.open_dataset(dataset, **options)

    fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)

    if len(dates) == 0:
        result = fs.sel(**kwargs)
    else:
-        print("dates", dates, kwargs)
        result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])

    if len(result) == 0:

@@ -130,7 +114,7 @@ def load_one(
            a = ["valid_datetime", k.metadata("valid_datetime", default=None)]
            for n in kwargs.keys():
                a.extend([n, k.metadata(n, default=None)])
-
+            LOG.warning(f"{[str(x) for x in a]}")

            if i > 16:
                break
anemoi/datasets/create/sources/xarray_support/coordinates.py
CHANGED

@@ -95,6 +95,7 @@ class Coordinate:
     is_member = False
     is_x = False
     is_y = False
+    is_point = False

     def __init__(self, variable: xr.DataArray) -> None:
         """Initialize the coordinate.

@@ -390,6 +391,13 @@ class EnsembleCoordinate(Coordinate):
         return value


+class PointCoordinate(Coordinate):
+    """Coordinate class for point data."""
+
+    is_point = True
+    mars_names = ("point",)
+
+
 class LongitudeCoordinate(Coordinate):
     """Coordinate class for longitude."""

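The new PointCoordinate (together with the is_point flag and the guesser hooks added in flavour.py below) lets the xarray support treat station-like dimensions named or tagged "cell", "station", "poi" or "point" as point data rather than as a regular grid axis. A hedged sketch of the kind of dataset this targets; the variable name and station values are made up for illustration.

import numpy as np
import xarray as xr

# Toy station-style dataset: a single "station" dimension carries the horizontal
# locations instead of separate latitude/longitude grid dimensions.
n_stations = 3
ds = xr.Dataset(
    {"t2m": (("time", "station"), 280.0 + np.random.rand(2, n_stations))},
    coords={
        "time": np.array(["2020-01-01", "2020-01-02"], dtype="datetime64[ns]"),
        "station": np.arange(n_stations),
        "latitude": ("station", np.array([50.1, 48.9, 52.3])),
        "longitude": ("station", np.array([8.7, 2.4, 13.4])),
    },
)
print(ds)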
anemoi/datasets/create/sources/xarray_support/field.py
CHANGED

@@ -87,13 +87,10 @@ class XArrayField(Field):
             coordinate = owner.by_name[coord_name]
             self._md[coord_name] = coordinate.normalise(extract_single_value(coord_value))

-        # print(values.ndim, values.shape, selection.dims)
         # By now, the only dimensions should be latitude and longitude
         self._shape = tuple(list(self.selection.shape)[-2:])
         if math.prod(self._shape) != math.prod(self.selection.shape):
-
-            print(self.selection)
-            raise ValueError("Invalid shape for selection")
+            raise ValueError(f"Invalid shape for selection {self._shape=}, {self.selection.shape=} {self.selection=}")

     @property
     def shape(self) -> Tuple[int, int]:
anemoi/datasets/create/sources/xarray_support/flavour.py
CHANGED

@@ -26,6 +26,7 @@ from .coordinates import EnsembleCoordinate
 from .coordinates import LatitudeCoordinate
 from .coordinates import LevelCoordinate
 from .coordinates import LongitudeCoordinate
+from .coordinates import PointCoordinate
 from .coordinates import ScalarCoordinate
 from .coordinates import StepCoordinate
 from .coordinates import TimeCoordinate

@@ -134,6 +135,10 @@ class CoordinateGuesser(ABC):

         d: Optional[Coordinate] = None

+        d = self._is_point(coordinate, attributes)
+        if d is not None:
+            return d
+
         d = self._is_longitude(coordinate, attributes)
         if d is not None:
             return d

@@ -308,9 +313,9 @@ class CoordinateGuesser(ABC):
             return self._grid_cache[(x.name, y.name, dim_vars)]

         grid_mapping = variable.attrs.get("grid_mapping", None)
-        if grid_mapping is not None:
-
-
+        # if grid_mapping is not None:
+        #     print(f"grid_mapping: {grid_mapping}")
+        #     print(self.ds[grid_mapping])

         if grid_mapping is None:
             LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'")

@@ -392,6 +397,10 @@ class CoordinateGuesser(ABC):
         """
         pass

+    @abstractmethod
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        pass
+
     @abstractmethod
     def _is_latitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LatitudeCoordinate]:
         """Checks if the coordinate is a latitude.

@@ -550,6 +559,15 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         """
         super().__init__(ds)

+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        if attributes.standard_name in ["cell", "station", "poi", "point"]:
+            return PointCoordinate(c)
+
+        if attributes.name in ["cell", "station", "poi", "point"]:  # WeatherBench
+            return PointCoordinate(c)
+
+        return None
+
     def _is_longitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LongitudeCoordinate]:
         """Checks if the coordinate is a longitude.

@@ -750,6 +768,9 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         if attributes.standard_name == "air_pressure" and attributes.units == "hPa":
             return LevelCoordinate(c, "pl")

+        if attributes.long_name == "pressure" and attributes.units in ["hPa", "Pa"]:
+            return LevelCoordinate(c, "pl")
+
         if attributes.name == "level":
             return LevelCoordinate(c, "pl")

@@ -759,9 +780,6 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         if attributes.standard_name == "depth":
             return LevelCoordinate(c, "depth")

-        if attributes.name == "vertical" and attributes.units == "hPa":
-            return LevelCoordinate(c, "pl")
-
         return None

     def _is_number(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[EnsembleCoordinate]:

@@ -1040,3 +1058,23 @@ class FlavourCoordinateGuesser(CoordinateGuesser):
             return EnsembleCoordinate(c)

         return None
+
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        """Checks if the coordinate is a point coordinate using the flavour rules.
+
+        Parameters
+        ----------
+        c : xr.DataArray
+            The coordinate to check.
+        attributes : CoordinateAttributes
+            The attributes of the coordinate.
+
+        Returns
+        -------
+        Optional[PointCoordinate]
+            The PointCoordinate if matched, else None.
+        """
+        if self._match(c, "point", attributes):
+            return PointCoordinate(c)
+
+        return None
anemoi/datasets/create/sources/xarray_support/patch.py
CHANGED

@@ -61,9 +61,50 @@ def patch_coordinates(ds: xr.Dataset, coordinates: List[str]) -> Any:
     return ds


+def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
+    """Rename variables in the dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    renames : dict[str, str]
+        Mapping from old variable names to new variable names.
+
+    Returns
+    -------
+    Any
+        The patched dataset.
+    """
+    return ds.rename(renames)
+
+
+def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: List[str]) -> Any:
+    """Sort the coordinates of the dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    sort_coordinates : List[str]
+        The coordinates to sort.
+
+    Returns
+    -------
+    Any
+        The patched dataset.
+    """
+
+    for name in sort_coordinates:
+        ds = ds.sortby(name)
+    return ds
+
+
 PATCHES = {
     "attributes": patch_attributes,
     "coordinates": patch_coordinates,
+    "rename": patch_rename,
+    "sort_coordinates": patch_sort_coordinate,
 }


@@ -82,7 +123,9 @@ def patch_dataset(ds: xr.Dataset, patch: Dict[str, Dict[str, Any]]) -> Any:
     Any
         The patched dataset.
     """
-
+
+    ORDER = ["coordinates", "attributes", "rename", "sort_coordinates"]
+    for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])):
         if what not in PATCHES:
             raise ValueError(f"Unknown patch type {what!r}")

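A hedged sketch of what the two new patch types do, expressed with plain xarray calls instead of the anemoi helpers; the variable name "2t" and the unsorted "time" axis are made-up examples.

import numpy as np
import xarray as xr

# Toy dataset with a mis-named variable and an unsorted time coordinate.
ds = xr.Dataset(
    {"2t": ("time", np.array([285.0, 283.0]))},
    coords={"time": np.array(["2020-01-02", "2020-01-01"], dtype="datetime64[ns]")},
)

# "rename" patch: map old variable names to new ones.
ds = ds.rename({"2t": "t2m"})

# "sort_coordinates" patch: sort the dataset along each listed coordinate.
for name in ["time"]:
    ds = ds.sortby(name)

print(ds)

Note that patch_dataset now applies patch types in the fixed ORDER (coordinates, attributes, rename, sort_coordinates), regardless of the order of the keys in the patch dictionary.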
anemoi/datasets/create/sources/xarray_support/variable.py
CHANGED

@@ -82,8 +82,12 @@ class Variable:

         self.time = time

-        self.shape = tuple(
-
+        self.shape = tuple(
+            len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point
+        )
+        self.names = {
+            c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point
+        }
         self.by_name = {c.variable.name: c for c in coordinates}

         # We need that alias for the time dimension
anemoi/datasets/data/complement.py
CHANGED

@@ -7,7 +7,6 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.

-
 import datetime
 import logging
 from abc import abstractmethod

@@ -19,6 +18,7 @@ from typing import Optional
 from typing import Set
 from typing import Tuple

+import numpy as np
 from numpy.typing import NDArray

 from ..grids import nearest_grid_points

@@ -85,6 +85,7 @@ class Complement(Combined):
         for v in self._source.variables:
             if v not in self._target.variables:
                 self._variables.append(v)
+        LOG.info(f"The following variables will be complemented: {self._variables}")

         if not self._variables:
             raise ValueError("Augment: no missing variables")

@@ -96,9 +97,11 @@ class Complement(Combined):

     @property
     def statistics(self) -> Dict[str, NDArray[Any]]:
-
-
-
+        datasets = [self._source, self._target]
+        return {
+            k: [d.statistics[k][d.name_to_index[i]] for d in datasets for i in d.variables if i in self.variables]
+            for k in datasets[0].statistics
+        }

     def statistics_tendencies(self, delta: Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]:
         index = [self._source.name_to_index[v] for v in self._variables]

@@ -120,7 +123,11 @@ class Complement(Combined):
     @property
     def variables_metadata(self) -> Dict[str, Any]:
         """Returns the metadata of the variables to be added to the target dataset."""
-
+        # Merge the two dicts first
+        all_meta = {**self._source.variables_metadata, **self._target.variables_metadata}
+
+        # Filter to keep only desired variables
+        return {k: v for k, v in all_meta.items() if k in self._variables}

     def check_same_variables(self, d1: Dataset, d2: Dataset) -> None:
         """Checks if the variables in two datasets are the same.

@@ -231,7 +238,7 @@ class ComplementNone(Complement):
 class ComplementNearest(Complement):
     """A class to complement a target dataset with variables from a source dataset using nearest neighbor interpolation."""

-    def __init__(self, target: Any, source: Any, max_distance: float = None) -> None:
+    def __init__(self, target: Any, source: Any, max_distance: float = None, k: int = 1) -> None:
         """Initializes the ComplementNearest class.

         Parameters

@@ -242,17 +249,25 @@ class ComplementNearest(Complement):
             The source dataset.
         max_distance : float, optional
             The maximum distance for nearest neighbor interpolation, default is None.
+        k : int, optional
+            The number of k closest neighbors to consider for interpolation
         """
         super().__init__(target, source)

-        self.
+        self.k = k
+        self._distances, self._nearest_grid_points = nearest_grid_points(
             self._source.latitudes,
             self._source.longitudes,
             self._target.latitudes,
             self._target.longitudes,
             max_distance=max_distance,
+            k=k,
         )

+        if k == 1:
+            self._distances = np.expand_dims(self._distances, axis=1)
+            self._nearest_grid_points = np.expand_dims(self._nearest_grid_points, axis=1)
+
     def check_compatibility(self, d1: Dataset, d2: Dataset) -> None:
         """Checks the compatibility of two datasets for nearest neighbor interpolation.

@@ -285,7 +300,19 @@ class ComplementNearest(Complement):
         source_data = self._source[index[0], source_index, index[2], ...]
         target_data = source_data[..., self._nearest_grid_points]

-
+        epsilon = 1e-8  # prevent division by zero
+        weights = 1.0 / (self._distances + epsilon)
+        weights = weights.astype(target_data.dtype)
+        weights /= weights.sum(axis=1, keepdims=True)  # normalize
+
+        # Reshape weights to broadcast correctly
+        # Add leading singleton dimensions so it matches target_data shape
+        while weights.ndim < target_data.ndim:
+            weights = np.expand_dims(weights, axis=0)
+
+        # Compute weighted average along the last dimension
+        final_point = np.sum(target_data * weights, axis=-1)
+        result = final_point[..., index[3]]

         return apply_index_to_slices_changes(result, changes)

@@ -330,6 +357,13 @@ def complement_factory(args: Tuple, kwargs: dict) -> Dataset:
         "nearest": ComplementNearest,
     }[interpolation]

-
+    if interpolation == "nearest":
+        k = kwargs.pop("k", "1")
+        complement = Class(target=target, source=source, k=k)._subset(**kwargs)
+
+    else:
+        complement = Class(target=target, source=source)._subset(**kwargs)
+
+    joined = _open_dataset([target, complement])

-    return _open_dataset(
+    return _open_dataset(joined, reorder=sorted(joined.variables))
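ComplementNearest now accepts k > 1 neighbours and blends them with inverse-distance weights, as in the hunk above. A hedged, self-contained numpy sketch of that weighting scheme; the array shapes and values are illustrative and not taken from the datasets API.

import numpy as np

# Toy setup: 4 source points, 2 target points, k = 2 nearest neighbours per target.
source_values = np.array([1.0, 2.0, 3.0, 4.0])      # values on the source grid
nearest = np.array([[0, 1], [2, 3]])                # indices of the k nearest source points
distances = np.array([[10.0, 30.0], [5.0, 20.0]])   # distances to those points

epsilon = 1e-8                                      # prevent division by zero
weights = 1.0 / (distances + epsilon)
weights /= weights.sum(axis=1, keepdims=True)       # normalise per target point

neighbour_values = source_values[nearest]           # shape (n_target, k)
interpolated = np.sum(neighbour_values * weights, axis=-1)
print(interpolated)                                 # closer neighbours dominate the average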
anemoi/datasets/data/forwards.py
CHANGED

@@ -330,8 +330,14 @@ class Combined(Forwards):
         ValueError
             If the grids are not the same.
         """
-
-
+
+        # note: not a proper implementation, should be handled
+        # in a more consolidated way ...
+        rtol = 1.0e-7
+        if not np.allclose(d1.latitudes, d2.latitudes, rtol=rtol) or not np.allclose(
+            d1.longitudes, d2.longitudes, rtol=rtol
+        ):
+            raise ValueError(f"Incompatible grid ({d1.longitudes} {d2.longitudes})")

     def check_same_shape(self, d1: Dataset, d2: Dataset) -> None:
         """Checks if the shapes of two datasets are the same.
anemoi/datasets/data/stores.py
CHANGED

@@ -107,51 +107,6 @@ class S3Store(ReadOnlyStore):
         return response["Body"].read()


-class PlanetaryComputerStore(ReadOnlyStore):
-    """We write our own Store to access catalogs on Planetary Computer,
-    as it requires some extra arguments to use xr.open_zarr.
-    """
-
-    def __init__(self, data_catalog_id: str) -> None:
-        """Initialize the PlanetaryComputerStore with a data catalog ID.
-
-        Parameters
-        ----------
-        data_catalog_id : str
-            The data catalog ID.
-        """
-        self.data_catalog_id = data_catalog_id
-
-        import planetary_computer
-        import pystac_client
-
-        catalog = pystac_client.Client.open(
-            "https://planetarycomputer.microsoft.com/api/stac/v1/",
-            modifier=planetary_computer.sign_inplace,
-        )
-        collection = catalog.get_collection(self.data_catalog_id)
-
-        asset = collection.assets["zarr-abfs"]
-
-        if "xarray:storage_options" in asset.extra_fields:
-            store = {
-                "store": asset.href,
-                "storage_options": asset.extra_fields["xarray:storage_options"],
-                **asset.extra_fields["xarray:open_kwargs"],
-            }
-        else:
-            store = {
-                "filename_or_obj": asset.href,
-                **asset.extra_fields["xarray:open_kwargs"],
-            }
-
-        self.store = store
-
-    def __getitem__(self, key: str) -> bytes:
-        """Retrieve an item from the store."""
-        raise NotImplementedError()
-
-
 class DebugStore(ReadOnlyStore):
     """A store to debug the zarr loading."""


@@ -190,11 +145,11 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore:

     if store.startswith("http://") or store.startswith("https://"):

-        parsed = urlparse(store)
-
         if store.endswith(".zip"):
             import multiurl

+            parsed = urlparse(store)
+
             # Zarr cannot handle zip files over HTTP
             tmpdir = tempfile.gettempdir()
             name = os.path.basename(parsed.path)

@@ -210,15 +165,7 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore:
             os.rename(path + ".tmp", path)
             return name_to_zarr_store(path)

-
-        if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"):
-            s3_url = f"s3://{bits[0]}{parsed.path}"
-            store = S3Store(s3_url, region=bits[2])
-        elif store.startswith("https://planetarycomputer.microsoft.com/"):
-            data_catalog_id = store.rsplit("/", 1)[-1]
-            store = PlanetaryComputerStore(data_catalog_id).store
-        else:
-            store = HTTPStore(store)
+        return HTTPStore(store)

     return store


@@ -565,6 +512,10 @@ def zarr_lookup(name: str, fail: bool = True) -> Optional[str]:
     config = load_config()["datasets"]
     use_search_path_not_found = config.get("use_search_path_not_found", False)

+    if name.endswith(".zarr/"):
+        LOG.warning("Removing trailing slash from path: %s", name)
+        name = name[:-1]
+
     if name.endswith(".zarr") or name.endswith(".zip"):

         if os.path.exists(name):