anemoi-datasets 0.4.5__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/create.py +3 -2
- anemoi/datasets/commands/inspect.py +1 -1
- anemoi/datasets/commands/publish.py +30 -0
- anemoi/datasets/create/__init__.py +72 -35
- anemoi/datasets/create/check.py +6 -0
- anemoi/datasets/create/config.py +4 -3
- anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +57 -0
- anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +57 -0
- anemoi/datasets/create/functions/filters/rename.py +2 -3
- anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +54 -0
- anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +59 -0
- anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +115 -0
- anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +390 -0
- anemoi/datasets/create/functions/filters/speeddir_to_uv.py +77 -0
- anemoi/datasets/create/functions/filters/uv_to_speeddir.py +55 -0
- anemoi/datasets/create/functions/sources/__init__.py +7 -1
- anemoi/datasets/create/functions/sources/accumulations.py +2 -0
- anemoi/datasets/create/functions/sources/grib.py +87 -2
- anemoi/datasets/create/functions/sources/hindcasts.py +14 -73
- anemoi/datasets/create/functions/sources/mars.py +9 -3
- anemoi/datasets/create/functions/sources/xarray/__init__.py +6 -1
- anemoi/datasets/create/functions/sources/xarray/coordinates.py +6 -1
- anemoi/datasets/create/functions/sources/xarray/field.py +20 -5
- anemoi/datasets/create/functions/sources/xarray/fieldlist.py +16 -16
- anemoi/datasets/create/functions/sources/xarray/flavour.py +126 -12
- anemoi/datasets/create/functions/sources/xarray/grid.py +106 -17
- anemoi/datasets/create/functions/sources/xarray/metadata.py +6 -12
- anemoi/datasets/create/functions/sources/xarray/time.py +1 -5
- anemoi/datasets/create/functions/sources/xarray/variable.py +10 -10
- anemoi/datasets/create/input/__init__.py +69 -0
- anemoi/datasets/create/input/action.py +123 -0
- anemoi/datasets/create/input/concat.py +92 -0
- anemoi/datasets/create/input/context.py +59 -0
- anemoi/datasets/create/input/data_sources.py +71 -0
- anemoi/datasets/create/input/empty.py +42 -0
- anemoi/datasets/create/input/filter.py +76 -0
- anemoi/datasets/create/input/function.py +122 -0
- anemoi/datasets/create/input/join.py +57 -0
- anemoi/datasets/create/input/misc.py +85 -0
- anemoi/datasets/create/input/pipe.py +33 -0
- anemoi/datasets/create/input/repeated_dates.py +217 -0
- anemoi/datasets/create/input/result.py +413 -0
- anemoi/datasets/create/input/step.py +99 -0
- anemoi/datasets/create/{template.py → input/template.py} +0 -42
- anemoi/datasets/create/persistent.py +1 -1
- anemoi/datasets/create/statistics/__init__.py +1 -1
- anemoi/datasets/create/utils.py +3 -0
- anemoi/datasets/create/zarr.py +4 -2
- anemoi/datasets/data/dataset.py +11 -1
- anemoi/datasets/data/debug.py +5 -1
- anemoi/datasets/data/masked.py +2 -2
- anemoi/datasets/data/rescale.py +147 -0
- anemoi/datasets/data/stores.py +20 -7
- anemoi/datasets/dates/__init__.py +113 -30
- anemoi/datasets/dates/groups.py +92 -19
- anemoi/datasets/fields.py +66 -0
- anemoi/datasets/utils/fields.py +47 -0
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/METADATA +10 -19
- anemoi_datasets-0.5.5.dist-info/RECORD +121 -0
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/input.py +0 -1065
- anemoi_datasets-0.4.5.dist-info/RECORD +0 -96
- /anemoi/datasets/create/{trace.py → input/trace.py} +0 -0
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/top_level.txt +0 -0
|
@@ -8,39 +8,128 @@
|
|
|
8
8
|
#
|
|
9
9
|
|
|
10
10
|
|
|
11
|
+
import logging
|
|
12
|
+
from functools import cached_property
|
|
13
|
+
|
|
11
14
|
import numpy as np
|
|
12
15
|
|
|
16
|
+
LOG = logging.getLogger(__name__)
|
|
17
|
+
|
|
13
18
|
|
|
14
19
|
class Grid:
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
20
|
+
|
|
21
|
+
def __init__(self):
|
|
22
|
+
pass
|
|
18
23
|
|
|
19
24
|
@property
|
|
20
25
|
def latitudes(self):
|
|
21
|
-
return self.grid_points
|
|
26
|
+
return self.grid_points[0]
|
|
22
27
|
|
|
23
28
|
@property
|
|
24
29
|
def longitudes(self):
|
|
25
|
-
return self.grid_points
|
|
30
|
+
return self.grid_points[1]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class LatLonGrid(Grid):
|
|
34
|
+
def __init__(self, lat, lon, variable_dims):
|
|
35
|
+
super().__init__()
|
|
36
|
+
self.lat = lat
|
|
37
|
+
self.lon = lon
|
|
38
|
+
|
|
26
39
|
|
|
40
|
+
class XYGrid(Grid):
|
|
41
|
+
def __init__(self, x, y):
|
|
42
|
+
self.x = x
|
|
43
|
+
self.y = y
|
|
27
44
|
|
|
28
|
-
class MeshedGrid(Grid):
|
|
29
|
-
_cache = None
|
|
30
45
|
|
|
46
|
+
class MeshedGrid(LatLonGrid):
|
|
47
|
+
|
|
48
|
+
@cached_property
|
|
31
49
|
def grid_points(self):
|
|
32
|
-
if self._cache is not None:
|
|
33
|
-
return self._cache
|
|
34
|
-
lat = self.lat.variable.values
|
|
35
|
-
lon = self.lon.variable.values
|
|
36
50
|
|
|
37
|
-
lat, lon = np.meshgrid(
|
|
38
|
-
|
|
39
|
-
|
|
51
|
+
lat, lon = np.meshgrid(
|
|
52
|
+
self.lat.variable.values,
|
|
53
|
+
self.lon.variable.values,
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
return lat.flatten(), lon.flatten()
|
|
57
|
+
|
|
40
58
|
|
|
59
|
+
class UnstructuredGrid(LatLonGrid):
|
|
41
60
|
|
|
42
|
-
|
|
61
|
+
def __init__(self, lat, lon, variable_dims):
|
|
62
|
+
super().__init__(lat, lon, variable_dims)
|
|
63
|
+
assert len(lat) == len(lon), (len(lat), len(lon))
|
|
64
|
+
self.variable_dims = variable_dims
|
|
65
|
+
self.grid_dims = lat.variable.dims
|
|
66
|
+
assert lon.variable.dims == self.grid_dims, (lon.variable.dims, self.grid_dims)
|
|
67
|
+
assert set(self.variable_dims) == set(self.grid_dims), (self.variable_dims, self.grid_dims)
|
|
68
|
+
|
|
69
|
+
@cached_property
|
|
43
70
|
def grid_points(self):
|
|
44
|
-
|
|
45
|
-
|
|
71
|
+
|
|
72
|
+
assert 1 <= len(self.variable_dims) <= 2
|
|
73
|
+
|
|
74
|
+
if len(self.variable_dims) == 1:
|
|
75
|
+
return self.lat.variable.values.flatten(), self.lon.variable.values.flatten()
|
|
76
|
+
|
|
77
|
+
if len(self.variable_dims) == 2 and self.variable_dims == self.grid_dims:
|
|
78
|
+
return self.lat.variable.values.flatten(), self.lon.variable.values.flatten()
|
|
79
|
+
|
|
80
|
+
LOG.warning(
|
|
81
|
+
"UnstructuredGrid: variable indexing %s does not match grid indexing %s", self.variable_dims, self.grid_dims
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
lat = self.lat.variable.values.transpose().flatten()
|
|
85
|
+
lon = self.lon.variable.values.transpose().flatten()
|
|
86
|
+
|
|
46
87
|
return lat, lon
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class ProjectionGrid(XYGrid):
|
|
91
|
+
def __init__(self, x, y, projection):
|
|
92
|
+
super().__init__(x, y)
|
|
93
|
+
self.projection = projection
|
|
94
|
+
|
|
95
|
+
def transformer(self):
|
|
96
|
+
from pyproj import CRS
|
|
97
|
+
from pyproj import Transformer
|
|
98
|
+
|
|
99
|
+
if isinstance(self.projection, dict):
|
|
100
|
+
data_crs = CRS.from_cf(self.projection)
|
|
101
|
+
else:
|
|
102
|
+
data_crs = self.projection
|
|
103
|
+
wgs84_crs = CRS.from_epsg(4326) # WGS84
|
|
104
|
+
|
|
105
|
+
return Transformer.from_crs(data_crs, wgs84_crs, always_xy=True)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class MeshProjectionGrid(ProjectionGrid):
|
|
109
|
+
|
|
110
|
+
@cached_property
|
|
111
|
+
def grid_points(self):
|
|
112
|
+
|
|
113
|
+
transformer = self.transformer()
|
|
114
|
+
xv, yv = np.meshgrid(self.x.variable.values, self.y.variable.values) # , indexing="ij")
|
|
115
|
+
lon, lat = transformer.transform(xv, yv)
|
|
116
|
+
return lat.flatten(), lon.flatten()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class UnstructuredProjectionGrid(XYGrid):
|
|
120
|
+
@cached_property
|
|
121
|
+
def grid_points(self):
|
|
122
|
+
assert False, "Not implemented"
|
|
123
|
+
|
|
124
|
+
# lat, lon = transformer.transform(
|
|
125
|
+
# self.y.variable.values.flatten(),
|
|
126
|
+
# self.x.variable.values.flatten(),
|
|
127
|
+
|
|
128
|
+
# )
|
|
129
|
+
|
|
130
|
+
# lat = lat[::len(lat)//100]
|
|
131
|
+
# lon = lon[::len(lon)//100]
|
|
132
|
+
|
|
133
|
+
# print(len(lat), len(lon))
|
|
134
|
+
|
|
135
|
+
# return np.meshgrid(lat, lon)
|
|
@@ -40,7 +40,9 @@ class _MDMapping:
|
|
|
40
40
|
return f"MDMapping({self.mapping})"
|
|
41
41
|
|
|
42
42
|
def fill_time_metadata(self, field, md):
|
|
43
|
-
|
|
43
|
+
valid_datetime = self.variable.time.fill_time_metadata(field._md, md)
|
|
44
|
+
if valid_datetime is not None:
|
|
45
|
+
md["valid_datetime"] = as_datetime(valid_datetime).isoformat()
|
|
44
46
|
|
|
45
47
|
|
|
46
48
|
class XArrayMetadata(RawMetadata):
|
|
@@ -70,15 +72,7 @@ class XArrayMetadata(RawMetadata):
|
|
|
70
72
|
return self._as_mars()
|
|
71
73
|
|
|
72
74
|
def _as_mars(self):
|
|
73
|
-
return
|
|
74
|
-
param=self["variable"],
|
|
75
|
-
step=self["step"],
|
|
76
|
-
levelist=self["level"],
|
|
77
|
-
levtype=self["levtype"],
|
|
78
|
-
number=self["number"],
|
|
79
|
-
date=self["date"],
|
|
80
|
-
time=self["time"],
|
|
81
|
-
)
|
|
75
|
+
return {}
|
|
82
76
|
|
|
83
77
|
def _base_datetime(self):
|
|
84
78
|
return self._field.forecast_reference_time
|
|
@@ -135,12 +129,12 @@ class XArrayFieldGeography(Geography):
|
|
|
135
129
|
# TODO: implement resolution
|
|
136
130
|
return None
|
|
137
131
|
|
|
138
|
-
@property
|
|
132
|
+
# @property
|
|
139
133
|
def mars_grid(self):
|
|
140
134
|
# TODO: implement mars_grid
|
|
141
135
|
return None
|
|
142
136
|
|
|
143
|
-
@property
|
|
137
|
+
# @property
|
|
144
138
|
def mars_area(self):
|
|
145
139
|
# TODO: code me
|
|
146
140
|
# return [self.north, self.west, self.south, self.east]
|
|
@@ -42,11 +42,7 @@ class Time:
|
|
|
42
42
|
class Constant(Time):
|
|
43
43
|
|
|
44
44
|
def fill_time_metadata(self, coords_values, metadata):
|
|
45
|
-
|
|
46
|
-
# print("Constant", coords_values, metadata)
|
|
47
|
-
# metadata["date"] = time.strftime("%Y%m%d")
|
|
48
|
-
# metadata["time"] = time.strftime("%H%M")
|
|
49
|
-
# metadata["step"] = 0
|
|
45
|
+
return None
|
|
50
46
|
|
|
51
47
|
|
|
52
48
|
class Analysis(Time):
|
|
@@ -24,7 +24,7 @@ class Variable:
|
|
|
24
24
|
self,
|
|
25
25
|
*,
|
|
26
26
|
ds,
|
|
27
|
-
|
|
27
|
+
variable,
|
|
28
28
|
coordinates,
|
|
29
29
|
grid,
|
|
30
30
|
time,
|
|
@@ -32,13 +32,13 @@ class Variable:
|
|
|
32
32
|
array_backend=None,
|
|
33
33
|
):
|
|
34
34
|
self.ds = ds
|
|
35
|
-
self.
|
|
35
|
+
self.variable = variable
|
|
36
36
|
|
|
37
37
|
self.grid = grid
|
|
38
38
|
self.coordinates = coordinates
|
|
39
39
|
|
|
40
40
|
self._metadata = metadata.copy()
|
|
41
|
-
self._metadata.update({"variable":
|
|
41
|
+
self._metadata.update({"variable": variable.name})
|
|
42
42
|
|
|
43
43
|
self.time = time
|
|
44
44
|
|
|
@@ -51,20 +51,20 @@ class Variable:
|
|
|
51
51
|
|
|
52
52
|
@property
|
|
53
53
|
def name(self):
|
|
54
|
-
return self.
|
|
54
|
+
return self.variable.name
|
|
55
55
|
|
|
56
56
|
def __len__(self):
|
|
57
57
|
return self.length
|
|
58
58
|
|
|
59
59
|
@property
|
|
60
60
|
def grid_mapping(self):
|
|
61
|
-
grid_mapping = self.
|
|
61
|
+
grid_mapping = self.variable.attrs.get("grid_mapping", None)
|
|
62
62
|
if grid_mapping is None:
|
|
63
63
|
return None
|
|
64
64
|
return self.ds[grid_mapping].attrs
|
|
65
65
|
|
|
66
66
|
def grid_points(self):
|
|
67
|
-
return self.grid.grid_points
|
|
67
|
+
return self.grid.grid_points
|
|
68
68
|
|
|
69
69
|
@property
|
|
70
70
|
def latitudes(self):
|
|
@@ -76,7 +76,7 @@ class Variable:
|
|
|
76
76
|
|
|
77
77
|
def __repr__(self):
|
|
78
78
|
return "Variable[name=%s,coordinates=%s,metadata=%s]" % (
|
|
79
|
-
self.
|
|
79
|
+
self.variable.name,
|
|
80
80
|
self.coordinates,
|
|
81
81
|
self._metadata,
|
|
82
82
|
)
|
|
@@ -90,7 +90,7 @@ class Variable:
|
|
|
90
90
|
|
|
91
91
|
coords = np.unravel_index(i, self.shape)
|
|
92
92
|
kwargs = {k: v for k, v in zip(self.names, coords)}
|
|
93
|
-
return XArrayField(self, self.
|
|
93
|
+
return XArrayField(self, self.variable.isel(kwargs))
|
|
94
94
|
|
|
95
95
|
def sel(self, missing, **kwargs):
|
|
96
96
|
|
|
@@ -117,7 +117,7 @@ class Variable:
|
|
|
117
117
|
|
|
118
118
|
variable = Variable(
|
|
119
119
|
ds=self.ds,
|
|
120
|
-
var=self.
|
|
120
|
+
var=self.variable.isel({k: i}),
|
|
121
121
|
coordinates=coordinates,
|
|
122
122
|
grid=self.grid,
|
|
123
123
|
time=self.time,
|
|
@@ -136,7 +136,7 @@ class Variable:
|
|
|
136
136
|
name = kwargs.pop("variable")
|
|
137
137
|
if not isinstance(name, (list, tuple)):
|
|
138
138
|
name = [name]
|
|
139
|
-
if self.
|
|
139
|
+
if self.variable.name not in name:
|
|
140
140
|
return False, None
|
|
141
141
|
return True, kwargs
|
|
142
142
|
return True, kwargs
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# (C) Copyright 2023 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
import datetime
|
|
10
|
+
import itertools
|
|
11
|
+
import logging
|
|
12
|
+
import math
|
|
13
|
+
import time
|
|
14
|
+
from collections import defaultdict
|
|
15
|
+
from copy import deepcopy
|
|
16
|
+
from functools import cached_property
|
|
17
|
+
from functools import wraps
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
from anemoi.utils.dates import as_datetime as as_datetime
|
|
21
|
+
from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
|
|
22
|
+
|
|
23
|
+
from anemoi.datasets.dates import DatesProvider as DatesProvider
|
|
24
|
+
from anemoi.datasets.fields import FieldArray as FieldArray
|
|
25
|
+
from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
|
|
26
|
+
|
|
27
|
+
from .trace import trace_select
|
|
28
|
+
|
|
29
|
+
LOG = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class InputBuilder:
|
|
33
|
+
def __init__(self, config, data_sources, **kwargs):
|
|
34
|
+
self.kwargs = kwargs
|
|
35
|
+
|
|
36
|
+
config = deepcopy(config)
|
|
37
|
+
if data_sources:
|
|
38
|
+
config = dict(
|
|
39
|
+
data_sources=dict(
|
|
40
|
+
sources=data_sources,
|
|
41
|
+
input=config,
|
|
42
|
+
)
|
|
43
|
+
)
|
|
44
|
+
self.config = config
|
|
45
|
+
self.action_path = ["input"]
|
|
46
|
+
|
|
47
|
+
@trace_select
|
|
48
|
+
def select(self, group_of_dates):
|
|
49
|
+
from .action import ActionContext
|
|
50
|
+
from .action import action_factory
|
|
51
|
+
|
|
52
|
+
"""This changes the context."""
|
|
53
|
+
context = ActionContext(**self.kwargs)
|
|
54
|
+
action = action_factory(self.config, context, self.action_path)
|
|
55
|
+
return action.select(group_of_dates)
|
|
56
|
+
|
|
57
|
+
def __repr__(self):
|
|
58
|
+
from .action import ActionContext
|
|
59
|
+
from .action import action_factory
|
|
60
|
+
|
|
61
|
+
context = ActionContext(**self.kwargs)
|
|
62
|
+
a = action_factory(self.config, context, self.action_path)
|
|
63
|
+
return repr(a)
|
|
64
|
+
|
|
65
|
+
def _trace_select(self, group_of_dates):
|
|
66
|
+
return f"InputBuilder({group_of_dates})"
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
build_input = InputBuilder
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# (C) Copyright 2024 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
import logging
|
|
10
|
+
from copy import deepcopy
|
|
11
|
+
|
|
12
|
+
from anemoi.utils.dates import as_datetime as as_datetime
|
|
13
|
+
from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
|
|
14
|
+
from earthkit.data.core.order import build_remapping
|
|
15
|
+
|
|
16
|
+
from anemoi.datasets.dates import DatesProvider as DatesProvider
|
|
17
|
+
from anemoi.datasets.fields import FieldArray as FieldArray
|
|
18
|
+
from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
|
|
19
|
+
|
|
20
|
+
from .context import Context
|
|
21
|
+
from .misc import is_function
|
|
22
|
+
|
|
23
|
+
LOG = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Action:
|
|
27
|
+
def __init__(self, context, action_path, /, *args, **kwargs):
|
|
28
|
+
if "args" in kwargs and "kwargs" in kwargs:
|
|
29
|
+
"""We have:
|
|
30
|
+
args = []
|
|
31
|
+
kwargs = {args: [...], kwargs: {...}}
|
|
32
|
+
move the content of kwargs to args and kwargs.
|
|
33
|
+
"""
|
|
34
|
+
assert len(kwargs) == 2, (args, kwargs)
|
|
35
|
+
assert not args, (args, kwargs)
|
|
36
|
+
args = kwargs.pop("args")
|
|
37
|
+
kwargs = kwargs.pop("kwargs")
|
|
38
|
+
|
|
39
|
+
assert isinstance(context, ActionContext), type(context)
|
|
40
|
+
self.context = context
|
|
41
|
+
self.kwargs = kwargs
|
|
42
|
+
self.args = args
|
|
43
|
+
self.action_path = action_path
|
|
44
|
+
|
|
45
|
+
@classmethod
|
|
46
|
+
def _short_str(cls, x):
|
|
47
|
+
x = str(x)
|
|
48
|
+
if len(x) < 1000:
|
|
49
|
+
return x
|
|
50
|
+
return x[:1000] + "..."
|
|
51
|
+
|
|
52
|
+
def __repr__(self, *args, _indent_="\n", _inline_="", **kwargs):
|
|
53
|
+
more = ",".join([str(a)[:5000] for a in args])
|
|
54
|
+
more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])
|
|
55
|
+
|
|
56
|
+
more = more[:5000]
|
|
57
|
+
txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}"
|
|
58
|
+
if _indent_:
|
|
59
|
+
txt = txt.replace("\n", "\n ")
|
|
60
|
+
return txt
|
|
61
|
+
|
|
62
|
+
def select(self, dates, **kwargs):
|
|
63
|
+
self._raise_not_implemented()
|
|
64
|
+
|
|
65
|
+
def _raise_not_implemented(self):
|
|
66
|
+
raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
|
|
67
|
+
|
|
68
|
+
def _trace_select(self, group_of_dates):
|
|
69
|
+
return f"{self.__class__.__name__}({group_of_dates})"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class ActionContext(Context):
|
|
73
|
+
def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid):
|
|
74
|
+
super().__init__()
|
|
75
|
+
self.order_by = order_by
|
|
76
|
+
self.flatten_grid = flatten_grid
|
|
77
|
+
self.remapping = build_remapping(remapping)
|
|
78
|
+
self.use_grib_paramid = use_grib_paramid
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def action_factory(config, context, action_path):
|
|
82
|
+
|
|
83
|
+
from .concat import ConcatAction
|
|
84
|
+
from .data_sources import DataSourcesAction
|
|
85
|
+
from .function import FunctionAction
|
|
86
|
+
from .join import JoinAction
|
|
87
|
+
from .pipe import PipeAction
|
|
88
|
+
from .repeated_dates import RepeatedDatesAction
|
|
89
|
+
|
|
90
|
+
# from .data_sources import DataSourcesAction
|
|
91
|
+
|
|
92
|
+
assert isinstance(context, Context), (type, context)
|
|
93
|
+
if not isinstance(config, dict):
|
|
94
|
+
raise ValueError(f"Invalid input config {config}")
|
|
95
|
+
if len(config) != 1:
|
|
96
|
+
raise ValueError(f"Invalid input config. Expecting dict with only one key, got {list(config.keys())}")
|
|
97
|
+
|
|
98
|
+
config = deepcopy(config)
|
|
99
|
+
key = list(config.keys())[0]
|
|
100
|
+
|
|
101
|
+
if isinstance(config[key], list):
|
|
102
|
+
args, kwargs = config[key], {}
|
|
103
|
+
elif isinstance(config[key], dict):
|
|
104
|
+
args, kwargs = [], config[key]
|
|
105
|
+
else:
|
|
106
|
+
raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}")
|
|
107
|
+
|
|
108
|
+
cls = {
|
|
109
|
+
"data_sources": DataSourcesAction,
|
|
110
|
+
"concat": ConcatAction,
|
|
111
|
+
"join": JoinAction,
|
|
112
|
+
"pipe": PipeAction,
|
|
113
|
+
"function": FunctionAction,
|
|
114
|
+
"repeated_dates": RepeatedDatesAction,
|
|
115
|
+
}.get(key)
|
|
116
|
+
|
|
117
|
+
if cls is None:
|
|
118
|
+
if not is_function(key, "sources"):
|
|
119
|
+
raise ValueError(f"Unknown action '{key}' in {config}")
|
|
120
|
+
cls = FunctionAction
|
|
121
|
+
args = [key] + args
|
|
122
|
+
|
|
123
|
+
return cls(context, action_path + [key], *args, **kwargs)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
# (C) Copyright 2024 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
import logging
|
|
10
|
+
from copy import deepcopy
|
|
11
|
+
from functools import cached_property
|
|
12
|
+
|
|
13
|
+
from anemoi.datasets.dates import DatesProvider
|
|
14
|
+
|
|
15
|
+
from .action import Action
|
|
16
|
+
from .action import action_factory
|
|
17
|
+
from .empty import EmptyResult
|
|
18
|
+
from .misc import _tidy
|
|
19
|
+
from .misc import assert_fieldlist
|
|
20
|
+
from .result import Result
|
|
21
|
+
from .template import notify_result
|
|
22
|
+
from .trace import trace_datasource
|
|
23
|
+
from .trace import trace_select
|
|
24
|
+
|
|
25
|
+
LOG = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ConcatResult(Result):
|
|
29
|
+
def __init__(self, context, action_path, group_of_dates, results, **kwargs):
|
|
30
|
+
super().__init__(context, action_path, group_of_dates)
|
|
31
|
+
self.results = [r for r in results if not r.empty]
|
|
32
|
+
|
|
33
|
+
@cached_property
|
|
34
|
+
@assert_fieldlist
|
|
35
|
+
@notify_result
|
|
36
|
+
@trace_datasource
|
|
37
|
+
def datasource(self):
|
|
38
|
+
ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
|
|
39
|
+
for i in self.results:
|
|
40
|
+
ds += i.datasource
|
|
41
|
+
return _tidy(ds)
|
|
42
|
+
|
|
43
|
+
@property
|
|
44
|
+
def variables(self):
|
|
45
|
+
"""Check that all the results objects have the same variables."""
|
|
46
|
+
variables = None
|
|
47
|
+
for f in self.results:
|
|
48
|
+
if f.empty:
|
|
49
|
+
continue
|
|
50
|
+
if variables is None:
|
|
51
|
+
variables = f.variables
|
|
52
|
+
assert variables == f.variables, (variables, f.variables)
|
|
53
|
+
assert variables is not None, self.results
|
|
54
|
+
return variables
|
|
55
|
+
|
|
56
|
+
def __repr__(self):
|
|
57
|
+
content = "\n".join([str(i) for i in self.results])
|
|
58
|
+
return super().__repr__(content)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ConcatAction(Action):
|
|
62
|
+
def __init__(self, context, action_path, *configs):
|
|
63
|
+
super().__init__(context, action_path, *configs)
|
|
64
|
+
parts = []
|
|
65
|
+
for i, cfg in enumerate(configs):
|
|
66
|
+
if "dates" not in cfg:
|
|
67
|
+
raise ValueError(f"Missing 'dates' in {cfg}")
|
|
68
|
+
cfg = deepcopy(cfg)
|
|
69
|
+
dates_cfg = cfg.pop("dates")
|
|
70
|
+
assert isinstance(dates_cfg, dict), dates_cfg
|
|
71
|
+
filtering_dates = DatesProvider.from_config(**dates_cfg)
|
|
72
|
+
action = action_factory(cfg, context, action_path + [str(i)])
|
|
73
|
+
parts.append((filtering_dates, action))
|
|
74
|
+
self.parts = parts
|
|
75
|
+
|
|
76
|
+
def __repr__(self):
|
|
77
|
+
content = "\n".join([str(i) for i in self.parts])
|
|
78
|
+
return super().__repr__(content)
|
|
79
|
+
|
|
80
|
+
@trace_select
|
|
81
|
+
def select(self, group_of_dates):
|
|
82
|
+
from anemoi.datasets.dates.groups import GroupOfDates
|
|
83
|
+
|
|
84
|
+
results = []
|
|
85
|
+
for filtering_dates, action in self.parts:
|
|
86
|
+
newdates = GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider)
|
|
87
|
+
if newdates:
|
|
88
|
+
results.append(action.select(newdates))
|
|
89
|
+
if not results:
|
|
90
|
+
return EmptyResult(self.context, self.action_path, group_of_dates)
|
|
91
|
+
|
|
92
|
+
return ConcatResult(self.context, self.action_path, group_of_dates, results)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# (C) Copyright 2024 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
import logging
|
|
10
|
+
import textwrap
|
|
11
|
+
|
|
12
|
+
from anemoi.utils.dates import as_datetime as as_datetime
|
|
13
|
+
from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
|
|
14
|
+
from anemoi.utils.humanize import plural
|
|
15
|
+
|
|
16
|
+
from anemoi.datasets.dates import DatesProvider as DatesProvider
|
|
17
|
+
from anemoi.datasets.fields import FieldArray as FieldArray
|
|
18
|
+
from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
|
|
19
|
+
|
|
20
|
+
from .trace import step
|
|
21
|
+
from .trace import trace
|
|
22
|
+
|
|
23
|
+
LOG = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Context:
|
|
27
|
+
def __init__(self):
|
|
28
|
+
# used_references is a set of reference paths that will be needed
|
|
29
|
+
self.used_references = set()
|
|
30
|
+
# results is a dictionary of reference path -> obj
|
|
31
|
+
self.results = {}
|
|
32
|
+
|
|
33
|
+
def will_need_reference(self, key):
|
|
34
|
+
assert isinstance(key, (list, tuple)), key
|
|
35
|
+
key = tuple(key)
|
|
36
|
+
self.used_references.add(key)
|
|
37
|
+
|
|
38
|
+
def notify_result(self, key, result):
|
|
39
|
+
trace(
|
|
40
|
+
"🎯",
|
|
41
|
+
step(key),
|
|
42
|
+
"notify result",
|
|
43
|
+
textwrap.shorten(repr(result).replace(",", ", "), width=40),
|
|
44
|
+
plural(len(result), "field"),
|
|
45
|
+
)
|
|
46
|
+
assert isinstance(key, (list, tuple)), key
|
|
47
|
+
key = tuple(key)
|
|
48
|
+
if key in self.used_references:
|
|
49
|
+
if key in self.results:
|
|
50
|
+
raise ValueError(f"Duplicate result {key}")
|
|
51
|
+
self.results[key] = result
|
|
52
|
+
|
|
53
|
+
def get_result(self, key):
|
|
54
|
+
assert isinstance(key, (list, tuple)), key
|
|
55
|
+
key = tuple(key)
|
|
56
|
+
if key in self.results:
|
|
57
|
+
return self.results[key]
|
|
58
|
+
all_keys = sorted(list(self.results.keys()))
|
|
59
|
+
raise ValueError(f"Cannot find result {key} in {all_keys}")
|