anemoi-datasets 0.5.24__py3-none-any.whl → 0.5.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/finalise-additions.py +2 -1
- anemoi/datasets/commands/finalise.py +2 -1
- anemoi/datasets/commands/grib-index.py +1 -1
- anemoi/datasets/commands/init-additions.py +2 -1
- anemoi/datasets/commands/load-additions.py +2 -1
- anemoi/datasets/commands/load.py +2 -1
- anemoi/datasets/create/__init__.py +24 -33
- anemoi/datasets/create/filter.py +22 -24
- anemoi/datasets/create/input/__init__.py +0 -20
- anemoi/datasets/create/input/step.py +2 -16
- anemoi/datasets/create/sources/accumulations.py +7 -6
- anemoi/datasets/create/sources/planetary_computer.py +44 -0
- anemoi/datasets/create/sources/xarray_support/__init__.py +6 -22
- anemoi/datasets/create/sources/xarray_support/coordinates.py +8 -0
- anemoi/datasets/create/sources/xarray_support/field.py +1 -4
- anemoi/datasets/create/sources/xarray_support/flavour.py +44 -6
- anemoi/datasets/create/sources/xarray_support/patch.py +44 -1
- anemoi/datasets/create/sources/xarray_support/variable.py +6 -2
- anemoi/datasets/data/complement.py +44 -10
- anemoi/datasets/data/dataset.py +29 -0
- anemoi/datasets/data/forwards.py +8 -2
- anemoi/datasets/data/misc.py +74 -16
- anemoi/datasets/data/observations/__init__.py +316 -0
- anemoi/datasets/data/observations/legacy_obs_dataset.py +200 -0
- anemoi/datasets/data/observations/multi.py +64 -0
- anemoi/datasets/data/padded.py +227 -0
- anemoi/datasets/data/records/__init__.py +442 -0
- anemoi/datasets/data/records/backends/__init__.py +157 -0
- anemoi/datasets/data/stores.py +7 -56
- anemoi/datasets/data/subset.py +5 -0
- anemoi/datasets/grids.py +6 -3
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/METADATA +3 -2
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/RECORD +38 -51
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/filters/__init__.py +0 -33
- anemoi/datasets/create/filters/empty.py +0 -37
- anemoi/datasets/create/filters/legacy.py +0 -93
- anemoi/datasets/create/filters/noop.py +0 -37
- anemoi/datasets/create/filters/orog_to_z.py +0 -58
- anemoi/datasets/create/filters/pressure_level_relative_humidity_to_specific_humidity.py +0 -83
- anemoi/datasets/create/filters/pressure_level_specific_humidity_to_relative_humidity.py +0 -84
- anemoi/datasets/create/filters/rename.py +0 -205
- anemoi/datasets/create/filters/rotate_winds.py +0 -105
- anemoi/datasets/create/filters/single_level_dewpoint_to_relative_humidity.py +0 -78
- anemoi/datasets/create/filters/single_level_relative_humidity_to_dewpoint.py +0 -84
- anemoi/datasets/create/filters/single_level_relative_humidity_to_specific_humidity.py +0 -163
- anemoi/datasets/create/filters/single_level_specific_humidity_to_relative_humidity.py +0 -451
- anemoi/datasets/create/filters/speeddir_to_uv.py +0 -95
- anemoi/datasets/create/filters/sum.py +0 -68
- anemoi/datasets/create/filters/transform.py +0 -51
- anemoi/datasets/create/filters/unrotate_winds.py +0 -105
- anemoi/datasets/create/filters/uv_to_speeddir.py +0 -94
- anemoi/datasets/create/filters/wz_to_w.py +0 -98
- anemoi/datasets/create/testing.py +0 -76
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/licenses/LICENSE +0 -0
- {anemoi_datasets-0.5.24.dist-info → anemoi_datasets-0.5.26.dist-info}/top_level.txt +0 -0
anemoi/datasets/commands/grib-index.py
CHANGED

```diff
@@ -81,7 +81,7 @@ class GribIndexCmd(Command):
         bool
             True if the path matches, False otherwise.
         """
-        return fnmatch.fnmatch(path, args.match)
+        return fnmatch.fnmatch(os.path.basename(path), args.match)
 
         from anemoi.datasets.create.sources.grib_index import GribIndex
```
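A minimal sketch (not from the package) of the behavioural change: the match pattern (`args.match`) is now applied to the file's basename instead of the full path. The path and pattern below are made-up values.

```python
import fnmatch
import os

path = "/data/grib/2024/era5-2024-01.grib"  # hypothetical input path

# Before: the pattern had to match the *whole* path.
print(fnmatch.fnmatch(path, "era5-*.grib"))                    # False
# After: only the file name is matched against the pattern.
print(fnmatch.fnmatch(os.path.basename(path), "era5-*.grib"))  # True
```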
anemoi/datasets/create/__init__.py
CHANGED

```diff
@@ -44,7 +44,7 @@ from .check import check_data_values
 from .chunks import ChunkFilter
 from .config import build_output
 from .config import loader_config
-from .input import build_input
+from .input import InputBuilder
 from .statistics import Summary
 from .statistics import TmpStatistics
 from .statistics import check_variance
```
```diff
@@ -101,7 +101,9 @@ def json_tidy(o: Any) -> Any:
 
 
 def build_statistics_dates(
-    dates: list[datetime.datetime],
+    dates: list[datetime.datetime],
+    start: Optional[datetime.datetime],
+    end: Optional[datetime.datetime],
 ) -> tuple[str, str]:
     """Compute the start and end dates for the statistics.
 
```
```diff
@@ -551,36 +553,16 @@ class HasElementForDataMixin:
 
         self.output = build_output(config.output, parent=self)
 
-        self.input = build_input_(config, self.output)
-
-
-def build_input_(main_config: Any, output_config: Any) -> Any:
-    """Build the input builder.
-
-    Parameters
-    ----------
-    main_config : Any
-        The main configuration.
-    output_config : Any
-        The output configuration.
-
-    Returns
-    -------
-    Any
-        The input builder.
-    """
-    builder = build_input(
-        main_config.input,
-        data_sources=main_config.get("data_sources", {}),
-        order_by=output_config.order_by,
-        flatten_grid=output_config.flatten_grid,
-        remapping=build_remapping(output_config.remapping),
-        use_grib_paramid=main_config.build.use_grib_paramid,
-    )
-    LOG.debug("✅ INPUT_BUILDER")
-    LOG.debug(builder)
-    return builder
+        self.input = InputBuilder(
+            config.input,
+            data_sources=config.get("data_sources", {}),
+            order_by=self.output.order_by,
+            flatten_grid=self.output.flatten_grid,
+            remapping=build_remapping(self.output.remapping),
+            use_grib_paramid=config.build.use_grib_paramid,
+        )
+        LOG.debug("✅ INPUT_BUILDER")
+        LOG.debug(self.input)
 
 
 class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin):
```
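A schematic sketch of the refactor above, with simplified stand-ins rather than the package's real signatures: the removed module-level factory added nothing over calling the class directly.

```python
class InputBuilder:
    """Toy stand-in for anemoi.datasets.create.input.InputBuilder."""

    def __init__(self, config, **kwargs):
        self.config, self.kwargs = config, kwargs

def build_input(config, **kwargs):
    # The removed indirection: it only forwarded its arguments.
    return InputBuilder(config, **kwargs)

before = build_input({"source": "mars"}, order_by="valid_datetime")  # old call site
after = InputBuilder({"source": "mars"}, order_by="valid_datetime")  # new call site
assert before.config == after.config
```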
```diff
@@ -1541,7 +1523,16 @@ class Statistics(Actor, HasStatisticTempMixin, HasRegistryMixin):
         if not all(self.registry.get_flags(sync=False)):
             raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.")
 
-        for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
+        for k in [
+            "mean",
+            "stdev",
+            "minimum",
+            "maximum",
+            "sums",
+            "squares",
+            "count",
+            "has_nans",
+        ]:
             self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",))
 
         self.registry.add_to_history("compute_statistics_end")
```
anemoi/datasets/create/filter.py
CHANGED

```diff
@@ -7,44 +7,42 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
-from abc import ABC
-from abc import abstractmethod
 from typing import Any
+from typing import Dict
 
 import earthkit.data as ekd
 
 
-class Filter(ABC):
-    """A filter to transform a fieldlist.
+class TransformFilter:
+    """Calls filters from anemoi.transform.filters
 
+    Parameters
+    ----------
+    context : Any
+        The context in which the filter is created.
+    name : str
+        The name of the filter.
+    config : Dict[str, Any]
+        The configuration for the filter.
+    """
 
-    Parameters
-    ----------
-    context : Any
-        The context in which the filter is created.
-    *args : tuple
-        Positional arguments.
-    **kwargs : dict
-        Keyword arguments.
-    """
+    def __init__(self, context: Any, name: str, config: Dict[str, Any]) -> None:
+        from anemoi.transform.filters import create_filter
 
-    def __init__(self, context: Any, *args: tuple, **kwargs: dict) -> None:
-        self.context = context
+        self.name = name
+        self.transform_filter = create_filter(context, config)
 
-    @abstractmethod
-    def execute(self, data: ekd.FieldList) -> ekd.FieldList:
-        """Execute the filter.
+    def execute(self, input: ekd.FieldList) -> ekd.FieldList:
+        """Execute the transformation filter.
 
         Parameters
         ----------
-        data : ekd.FieldList
-            The input data.
+        input : ekd.FieldList
+            The input data to be transformed.
 
         Returns
         -------
         ekd.FieldList
-            The result.
+            The transformed data.
         """
-
-        pass
+        return self.transform_filter.forward(input)
```
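A runnable stand-in (toy types, nothing from anemoi-transform) for the delegation pattern the new `TransformFilter` uses: resolve a filter object once in `__init__`, then forward every `execute()` call to it.

```python
from typing import Any, Callable

class TransformFilterSketch:
    """Schematic stand-in for the TransformFilter above."""

    def __init__(self, name: str, create: Callable[[dict], Any], config: dict) -> None:
        self.name = name
        self.inner = create(config)  # stands in for anemoi.transform.filters.create_filter

    def execute(self, data: list) -> list:
        return self.inner.forward(data)  # same delegation as the real class

class Doubler:
    def forward(self, data: list) -> list:
        return [2 * x for x in data]

tf = TransformFilterSketch("double", lambda cfg: Doubler(), {})
print(tf.execute([1, 2, 3]))  # [2, 4, 6]
```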
anemoi/datasets/create/input/__init__.py
CHANGED

```diff
@@ -104,23 +104,3 @@ class InputBuilder:
             Trace string.
         """
         return f"InputBuilder({group_of_dates})"
-
-
-def build_input(config: dict, data_sources: Union[dict, list], **kwargs: Any) -> InputBuilder:
-    """Build an InputBuilder instance.
-
-    Parameters
-    ----------
-    config : dict
-        Configuration dictionary.
-    data_sources : Union[dict, list]
-        Data sources.
-    **kwargs : Any
-        Additional keyword arguments.
-
-    Returns
-    -------
-    InputBuilder
-        An instance of InputBuilder.
-    """
-    return InputBuilder(config, data_sources, **kwargs)
```
anemoi/datasets/create/input/step.py
CHANGED

```diff
@@ -8,7 +8,6 @@
 # nor does it submit to any jurisdiction.
 
 import logging
-import warnings
 from copy import deepcopy
 from typing import Any
 from typing import Dict
```

```diff
@@ -165,24 +164,11 @@ def step_factory(config: Dict[str, Any], context: ActionContext, action_path: Li
     if cls is not None:
         return cls(context, action_path, previous_step, *args, **kwargs)
 
-    # Try filters from
+    # Try filters from transform filter registry
    from anemoi.transform.filters import filter_registry as transform_filter_registry
 
-    from ..filters import create_filter as create_datasets_filter
-    from ..filters import filter_registry as datasets_filter_registry
-
-    if datasets_filter_registry.is_registered(key):
-
-        if transform_filter_registry.is_registered(key):
-            warnings.warn(f"Filter `{key}` is registered in both datasets and transform filter registries")
-
-        filter = create_datasets_filter(None, config)
-        return FunctionStepAction(context, action_path + [key], previous_step, key, filter)
-
-    # Use filters from transform registry
-
     if transform_filter_registry.is_registered(key):
-        from ..filters.transform import TransformFilter
+        from ..filter import TransformFilter
 
         return FunctionStepAction(
             context, action_path + [key], previous_step, key, TransformFilter(context, key, config)
```
anemoi/datasets/create/sources/accumulations.py
CHANGED

```diff
@@ -459,12 +459,13 @@ class AccumulationFromStart(Accumulation):
             A tuple representing the MARS date-time step.
         """
         assert user_date is None, user_date
-        assert not frequency, frequency
 
         steps = (step1 + add_step, step2 + add_step)
         if steps[0] == 0:
             steps = (steps[1],)
 
+        assert frequency == 0 or frequency == (step2 - step1), frequency
+
         return (
             base_date.year * 10000 + base_date.month * 100 + base_date.day,
             base_date.hour * 100 + base_date.minute,
```

```diff
@@ -824,6 +825,11 @@ def _compute_accumulations(
     step1, step2 = user_accumulation_period
     assert step1 < step2, user_accumulation_period
 
+    if accumulations_reset_frequency is not None:
+        AccumulationClass = AccumulationFromLastReset
+    else:
+        AccumulationClass = AccumulationFromStart if data_accumulation_period in (0, None) else AccumulationFromLastStep
+
     if data_accumulation_period is None:
         data_accumulation_period = user_accumulation_period[1] - user_accumulation_period[0]
 
```

```diff
@@ -838,11 +844,6 @@ def _compute_accumulations(
 
     base_times = [t // 100 if t > 100 else t for t in base_times]
 
-    if accumulations_reset_frequency is not None:
-        AccumulationClass = AccumulationFromLastReset
-    else:
-        AccumulationClass = AccumulationFromStart if data_accumulation_period in (0, None) else AccumulationFromLastStep
-
     mars_date_time_steps = AccumulationClass.mars_date_time_steps(
         dates=dates,
         step1=step1,
```
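The move above is behavioural, not cosmetic: the accumulation strategy is now chosen before `data_accumulation_period` is defaulted to `step2 - step1`, so an unspecified period can still select `AccumulationFromStart`. A toy sketch of the selection logic (package class names as strings, everything else simplified):

```python
def choose(data_accumulation_period, reset_frequency):
    # Mirrors the if/else added in the hunk above.
    if reset_frequency is not None:
        return "AccumulationFromLastReset"
    if data_accumulation_period in (0, None):
        return "AccumulationFromStart"
    return "AccumulationFromLastStep"

step1, step2 = 0, 6

# New order: choose first, while the period is still None.
print(choose(None, None))           # AccumulationFromStart
# Old order: the default ran first, so the same call saw period == 6.
print(choose(step2 - step1, None))  # AccumulationFromLastStep
```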
anemoi/datasets/create/sources/planetary_computer.py
ADDED

```diff
@@ -0,0 +1,44 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+from . import source_registry
+from .xarray import XarraySourceBase
+
+
+@source_registry.register("planetary_computer")
+class PlanetaryComputerSource(XarraySourceBase):
+    """An Xarray data source for the planetary_computer."""
+
+    emoji = "🪐"
+
+    def __init__(self, context, data_catalog_id, version="v1", *args, **kwargs: dict):
+
+        import planetary_computer
+        import pystac_client
+
+        self.data_catalog_id = data_catalog_id
+        self.flavour = kwargs.pop("flavour", None)
+        self.patch = kwargs.pop("patch", None)
+        self.options = kwargs.pop("options", {})
+
+        catalog = pystac_client.Client.open(
+            f"https://planetarycomputer.microsoft.com/api/stac/{version}/",
+            modifier=planetary_computer.sign_inplace,
+        )
+        collection = catalog.get_collection(self.data_catalog_id)
+
+        asset = collection.assets["zarr-abfs"]
+
+        if "xarray:storage_options" in asset.extra_fields:
+            self.options["storage_options"] = asset.extra_fields["xarray:storage_options"]
+
+        self.options.update(asset.extra_fields["xarray:open_kwargs"])
+
+        super().__init__(context, url=asset.href, *args, **kwargs)
```
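A stand-alone sketch of what the new source does under the hood, assuming `pystac-client` and `planetary-computer` are installed. The collection id below is illustrative (any collection that exposes a `zarr-abfs` asset would do); only the calls themselves come from the diff.

```python
import planetary_computer
import pystac_client
import xarray as xr

catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1/",
    modifier=planetary_computer.sign_inplace,  # signs asset URLs for access
)
collection = catalog.get_collection("daymet-daily-hi")  # hypothetical choice
asset = collection.assets["zarr-abfs"]

# Collect the xarray options advertised by the asset, as the source does.
options = dict(asset.extra_fields.get("xarray:open_kwargs", {}))
if "xarray:storage_options" in asset.extra_fields:
    options["storage_options"] = asset.extra_fields["xarray:storage_options"]

ds = xr.open_dataset(asset.href, **options)
print(ds)
```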
anemoi/datasets/create/sources/xarray_support/__init__.py
CHANGED

```diff
@@ -20,7 +20,6 @@ import xarray as xr
 from earthkit.data.core.fieldlist import MultiFieldList
 
 from anemoi.datasets.create.sources.patterns import iterate_patterns
-from anemoi.datasets.data.stores import name_to_zarr_store
 
 from ..legacy import legacy_source
 from .fieldlist import XarrayFieldList
```

```diff
@@ -89,37 +88,22 @@ def load_one(
         The loaded dataset.
     """
 
-    """
-    We manage the S3 client ourselves, bypassing fsspec and s3fs layers, because sometimes something on the stack
-    zarr/fsspec/s3fs/boto3 (?) seem to flags files as missing when they actually are not (maybe when S3 reports some sort of
-    connection error). In that case, Zarr will silently fill the chunks that could not be downloaded with NaNs.
-    See https://github.com/pydata/xarray/issues/8842
-
-    We have seen this bug triggered when we run many clients in parallel, for example, when we create a new dataset using `xarray-zarr`.
-    """
-
     if options is None:
         options = {}
 
     context.trace(emoji, dataset, options, kwargs)
 
-    if isinstance(dataset, str) and ".zarr" in dataset:
-        store = name_to_zarr_store(dataset)
-
-        if "store" in store:
-            data = xr.open_zarr(**store)
-        if "filename_or_obj" in store:
-            data = xr.open_dataset(**store)
-    else:
-        data = xr.open_dataset(dataset, **options)
+    if isinstance(dataset, str) and dataset.endswith(".zarr"):
+        # If the dataset is a zarr store, we need to use the zarr engine
+        options["engine"] = "zarr"
+
+    data = xr.open_dataset(dataset, **options)
 
     fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)
 
     if len(dates) == 0:
         result = fs.sel(**kwargs)
     else:
-        print("dates", dates, kwargs)
         result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])
 
     if len(result) == 0:
```

```diff
@@ -130,7 +114,7 @@ def load_one(
             a = ["valid_datetime", k.metadata("valid_datetime", default=None)]
             for n in kwargs.keys():
                 a.extend([n, k.metadata(n, default=None)])
-
+            LOG.warning(f"{[str(x) for x in a]}")
 
             if i > 16:
                 break
```
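The replacement leans on xarray's `zarr` engine: `xr.open_dataset(..., engine="zarr")` reads a zarr store directly, so the custom store handling could be dropped. A small self-contained check (assuming `zarr` is installed, writing a throwaway store first):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"t": ("x", np.arange(3.0))})
ds.to_zarr("demo.zarr", mode="w")

# What load_one does now: no zarr special-casing beyond selecting the engine.
back = xr.open_dataset("demo.zarr", engine="zarr")
assert back.identical(ds)
```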
anemoi/datasets/create/sources/xarray_support/coordinates.py
CHANGED

```diff
@@ -95,6 +95,7 @@ class Coordinate:
     is_member = False
     is_x = False
     is_y = False
+    is_point = False
 
     def __init__(self, variable: xr.DataArray) -> None:
         """Initialize the coordinate.
```

```diff
@@ -390,6 +391,13 @@ class EnsembleCoordinate(Coordinate):
         return value
 
 
+class PointCoordinate(Coordinate):
+    """Coordinate class for point data."""
+
+    is_point = True
+    mars_names = ("point",)
+
+
 class LongitudeCoordinate(Coordinate):
     """Coordinate class for longitude."""
 
```
anemoi/datasets/create/sources/xarray_support/field.py
CHANGED

```diff
@@ -87,13 +87,10 @@ class XArrayField(Field):
             coordinate = owner.by_name[coord_name]
             self._md[coord_name] = coordinate.normalise(extract_single_value(coord_value))
 
-        # print(values.ndim, values.shape, selection.dims)
         # By now, the only dimensions should be latitude and longitude
         self._shape = tuple(list(self.selection.shape)[-2:])
         if math.prod(self._shape) != math.prod(self.selection.shape):
-
-            print(self.selection)
-            raise ValueError("Invalid shape for selection")
+            raise ValueError(f"Invalid shape for selection {self._shape=}, {self.selection.shape=} {self.selection=}")
 
     @property
     def shape(self) -> Tuple[int, int]:
```
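The rewritten message uses f-string debug specifiers (Python 3.8+): `f"{expr=}"` expands to the expression text plus its value, which is why the new error carries the offending shapes without extra formatting code.

```python
shape = (10, 20)
print(f"Invalid shape for selection {shape=}")
# Invalid shape for selection shape=(10, 20)
```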
anemoi/datasets/create/sources/xarray_support/flavour.py
CHANGED

```diff
@@ -26,6 +26,7 @@ from .coordinates import EnsembleCoordinate
 from .coordinates import LatitudeCoordinate
 from .coordinates import LevelCoordinate
 from .coordinates import LongitudeCoordinate
+from .coordinates import PointCoordinate
 from .coordinates import ScalarCoordinate
 from .coordinates import StepCoordinate
 from .coordinates import TimeCoordinate
```

```diff
@@ -134,6 +135,10 @@ class CoordinateGuesser(ABC):
 
         d: Optional[Coordinate] = None
 
+        d = self._is_point(coordinate, attributes)
+        if d is not None:
+            return d
+
         d = self._is_longitude(coordinate, attributes)
         if d is not None:
             return d
```

```diff
@@ -308,9 +313,9 @@ class CoordinateGuesser(ABC):
             return self._grid_cache[(x.name, y.name, dim_vars)]
 
         grid_mapping = variable.attrs.get("grid_mapping", None)
-        if grid_mapping is not None:
-            print(f"grid_mapping: {grid_mapping}")
-            print(self.ds[grid_mapping])
+        # if grid_mapping is not None:
+        #     print(f"grid_mapping: {grid_mapping}")
+        #     print(self.ds[grid_mapping])
 
         if grid_mapping is None:
             LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'")
```

```diff
@@ -392,6 +397,10 @@ class CoordinateGuesser(ABC):
         """
         pass
 
+    @abstractmethod
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        pass
+
     @abstractmethod
     def _is_latitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LatitudeCoordinate]:
         """Checks if the coordinate is a latitude.
```

```diff
@@ -550,6 +559,15 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         """
         super().__init__(ds)
 
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        if attributes.standard_name in ["cell", "station", "poi", "point"]:
+            return PointCoordinate(c)
+
+        if attributes.name in ["cell", "station", "poi", "point"]:  # WeatherBench
+            return PointCoordinate(c)
+
+        return None
+
     def _is_longitude(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[LongitudeCoordinate]:
         """Checks if the coordinate is a longitude.
 
```

```diff
@@ -750,6 +768,9 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         if attributes.standard_name == "air_pressure" and attributes.units == "hPa":
             return LevelCoordinate(c, "pl")
 
+        if attributes.long_name == "pressure" and attributes.units in ["hPa", "Pa"]:
+            return LevelCoordinate(c, "pl")
+
         if attributes.name == "level":
             return LevelCoordinate(c, "pl")
 
```

```diff
@@ -759,9 +780,6 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         if attributes.standard_name == "depth":
             return LevelCoordinate(c, "depth")
 
-        if attributes.name == "vertical" and attributes.units == "hPa":
-            return LevelCoordinate(c, "pl")
-
         return None
 
     def _is_number(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[EnsembleCoordinate]:
```

```diff
@@ -1040,3 +1058,23 @@ class FlavourCoordinateGuesser(CoordinateGuesser):
             return EnsembleCoordinate(c)
 
         return None
+
+    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> Optional[PointCoordinate]:
+        """Checks if the coordinate is a point coordinate using the flavour rules.
+
+        Parameters
+        ----------
+        c : xr.DataArray
+            The coordinate to check.
+        attributes : CoordinateAttributes
+            The attributes of the coordinate.
+
+        Returns
+        -------
+        Optional[PointCoordinate]
+            The StepCoorPointCoordinateinate if matched, else None.
+        """
+        if self._match(c, "point", attributes):
+            return PointCoordinate(c)
+
+        return None
```
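A minimal sketch of the new point-dimension detection, reduced to plain strings (the real `_is_point` works on coordinate objects and their attributes):

```python
from typing import Optional

POINT_NAMES = ["cell", "station", "poi", "point"]

def is_point(standard_name: Optional[str], name: str) -> bool:
    # Mirrors the two checks in DefaultCoordinateGuesser._is_point;
    # the name-based check covers datasets such as WeatherBench.
    return standard_name in POINT_NAMES or name in POINT_NAMES

print(is_point(None, "station"))  # True
print(is_point("cell", "obs"))    # True
print(is_point(None, "level"))    # False
```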
anemoi/datasets/create/sources/xarray_support/patch.py
CHANGED

```diff
@@ -61,9 +61,50 @@ def patch_coordinates(ds: xr.Dataset, coordinates: List[str]) -> Any:
     return ds
 
 
+def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
+    """Rename variables in the dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    renames : dict[str, str]
+        Mapping from old variable names to new variable names.
+
+    Returns
+    -------
+    Any
+        The patched dataset.
+    """
+    return ds.rename(renames)
+
+
+def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: List[str]) -> Any:
+    """Sort the coordinates of the dataset.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    sort_coordinates : List[str]
+        The coordinates to sort.
+
+    Returns
+    -------
+    Any
+        The patched dataset.
+    """
+
+    for name in sort_coordinates:
+        ds = ds.sortby(name)
+    return ds
+
+
 PATCHES = {
     "attributes": patch_attributes,
     "coordinates": patch_coordinates,
+    "rename": patch_rename,
+    "sort_coordinates": patch_sort_coordinate,
 }
 
 
```

```diff
@@ -82,7 +123,9 @@ def patch_dataset(ds: xr.Dataset, patch: Dict[str, Dict[str, Any]]) -> Any:
     Any
         The patched dataset.
     """
-    for what, values in patch.items():
+
+    ORDER = ["coordinates", "attributes", "rename", "sort_coordinates"]
+    for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])):
         if what not in PATCHES:
             raise ValueError(f"Unknown patch type {what!r}")
 
```
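What the two new patch types amount to, shown on a throwaway dataset (the variable and coordinate names here are made up). Note that `patch_dataset` now applies patches in the fixed `ORDER`, so renames always run before coordinate sorting.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"t2m": ("latitude", np.array([3.0, 1.0, 2.0]))},
    coords={"latitude": [30.0, 10.0, 20.0]},
)

ds = ds.rename({"t2m": "2t"})  # what patch_rename does
for name in ["latitude"]:      # what patch_sort_coordinate does per name
    ds = ds.sortby(name)

print(ds["2t"].values)  # [1. 2. 3.] -- data reordered with latitude 10, 20, 30
```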
anemoi/datasets/create/sources/xarray_support/variable.py
CHANGED

```diff
@@ -82,8 +82,12 @@ class Variable:
 
         self.time = time
 
-        self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid)
-        self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid}
+        self.shape = tuple(
+            len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point
+        )
+        self.names = {
+            c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid and not c.is_point
+        }
         self.by_name = {c.variable.name: c for c in coordinates}
 
         # We need that alias for the time dimension
```
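Schematic of the effect: a dimension flagged as a point dimension no longer contributes to the variable's shape, mirroring the comprehensions above (a toy stand-in for the package's Coordinate objects):

```python
from dataclasses import dataclass

@dataclass
class Coord:
    name: str
    size: int
    is_dim: bool = True
    scalar: bool = False
    is_grid: bool = False
    is_point: bool = False

coords = [Coord("time", 4), Coord("station", 100, is_point=True)]
shape = tuple(
    c.size for c in coords if c.is_dim and not c.scalar and not c.is_grid and not c.is_point
)
print(shape)  # (4,) -- "station" is excluded, just like grid dimensions
```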
|