anemoi-datasets 0.5.0__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/inspect.py +1 -1
- anemoi/datasets/commands/publish.py +30 -0
- anemoi/datasets/create/__init__.py +42 -3
- anemoi/datasets/create/check.py +6 -0
- anemoi/datasets/create/functions/filters/rename.py +2 -3
- anemoi/datasets/create/functions/sources/__init__.py +7 -1
- anemoi/datasets/create/functions/sources/accumulations.py +2 -0
- anemoi/datasets/create/functions/sources/grib.py +1 -1
- anemoi/datasets/create/functions/sources/xarray/__init__.py +7 -2
- anemoi/datasets/create/functions/sources/xarray/coordinates.py +12 -1
- anemoi/datasets/create/functions/sources/xarray/field.py +13 -4
- anemoi/datasets/create/functions/sources/xarray/fieldlist.py +16 -16
- anemoi/datasets/create/functions/sources/xarray/flavour.py +130 -13
- anemoi/datasets/create/functions/sources/xarray/grid.py +106 -17
- anemoi/datasets/create/functions/sources/xarray/metadata.py +3 -11
- anemoi/datasets/create/functions/sources/xarray/time.py +1 -5
- anemoi/datasets/create/functions/sources/xarray/variable.py +10 -10
- anemoi/datasets/create/input/__init__.py +69 -0
- anemoi/datasets/create/input/action.py +123 -0
- anemoi/datasets/create/input/concat.py +92 -0
- anemoi/datasets/create/input/context.py +59 -0
- anemoi/datasets/create/input/data_sources.py +71 -0
- anemoi/datasets/create/input/empty.py +42 -0
- anemoi/datasets/create/input/filter.py +76 -0
- anemoi/datasets/create/input/function.py +122 -0
- anemoi/datasets/create/input/join.py +57 -0
- anemoi/datasets/create/input/misc.py +85 -0
- anemoi/datasets/create/input/pipe.py +33 -0
- anemoi/datasets/create/input/repeated_dates.py +217 -0
- anemoi/datasets/create/input/result.py +413 -0
- anemoi/datasets/create/input/step.py +99 -0
- anemoi/datasets/create/{template.py → input/template.py} +0 -42
- anemoi/datasets/create/statistics/__init__.py +1 -1
- anemoi/datasets/create/zarr.py +4 -2
- anemoi/datasets/dates/__init__.py +1 -0
- anemoi/datasets/dates/groups.py +12 -4
- anemoi/datasets/fields.py +66 -0
- anemoi/datasets/utils/fields.py +47 -0
- {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/METADATA +1 -1
- {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/RECORD +46 -30
- anemoi/datasets/create/input.py +0 -1087
- /anemoi/datasets/create/{trace.py → input/trace.py} +0 -0
- {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/WHEEL +0 -0
- {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# (C) Copyright 2024 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
import logging
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
|
|
12
|
+
from anemoi.utils.dates import as_datetime as as_datetime
|
|
13
|
+
from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
|
|
14
|
+
|
|
15
|
+
from anemoi.datasets.dates import DatesProvider as DatesProvider
|
|
16
|
+
from anemoi.datasets.fields import FieldArray as FieldArray
|
|
17
|
+
from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
|
|
18
|
+
|
|
19
|
+
from .misc import assert_fieldlist
|
|
20
|
+
from .result import Result
|
|
21
|
+
from .trace import trace_datasource
|
|
22
|
+
|
|
23
|
+
LOG = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class EmptyResult(Result):
    """A result that carries no fields.

    The class attribute ``empty = True`` lets containers recognise and skip
    it. Its ``datasource`` is earthkit-data's ``"empty"`` source.
    """

    empty = True

    def __init__(self, context, action_path, dates):
        # Suffix the action path so traces show where the empty result came from.
        path = action_path + ["empty"]
        super().__init__(context, path, dates)

    @cached_property
    @assert_fieldlist
    @trace_datasource
    def datasource(self):
        """Return the (cached) empty earthkit-data source."""
        from earthkit.data import from_source

        empty_source = from_source("empty")
        return empty_source

    @property
    def variables(self):
        """An empty result exposes no variables."""
        return []
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# (C) Copyright 2024 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
import logging
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
|
|
12
|
+
from anemoi.utils.dates import as_datetime as as_datetime
|
|
13
|
+
from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
|
|
14
|
+
|
|
15
|
+
from anemoi.datasets.dates import DatesProvider as DatesProvider
|
|
16
|
+
from anemoi.datasets.fields import FieldArray as FieldArray
|
|
17
|
+
from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
|
|
18
|
+
|
|
19
|
+
from ..functions import import_function
|
|
20
|
+
from .function import FunctionContext
|
|
21
|
+
from .misc import _tidy
|
|
22
|
+
from .misc import assert_fieldlist
|
|
23
|
+
from .step import StepAction
|
|
24
|
+
from .step import StepResult
|
|
25
|
+
from .template import notify_result
|
|
26
|
+
from .trace import trace_datasource
|
|
27
|
+
|
|
28
|
+
LOG = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class FilterStepResult(StepResult):
    """Result of a filter step that sub-selects fields from the upstream result."""

    @property
    @notify_result
    @assert_fieldlist
    @trace_datasource
    def datasource(self):
        """Apply the owning action's keyword arguments as ``sel(**kwargs)`` to the upstream fields.

        This is a plain ``property`` (not cached), so the selection is
        recomputed on every access.
        """
        upstream = self.upstream_result.datasource
        return _tidy(upstream.sel(**self.action.kwargs))
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class FilterStepAction(StepAction):
    """Pipeline step whose per-date results are ``FilterStepResult`` instances."""

    # StepAction builds its results from this class.
    result_class = FilterStepResult
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class StepFunctionResult(StepResult):
    """Result of a pipeline step that runs a filter function over the upstream fields."""

    @cached_property
    @assert_fieldlist
    @notify_result
    @trace_datasource
    def datasource(self):
        """Run the step's filter function on the upstream datasource (cached).

        The call is ``function(FunctionContext(self), upstream, *args[1:], **kwargs)``.
        ``args[0]`` is skipped because it holds the filter's name. Any
        exception is logged with its traceback and then re-raised.
        """
        try:
            filtered = self.action.function(
                FunctionContext(self),
                self.upstream_result.datasource,
                *self.action.args[1:],
                **self.action.kwargs,
            )
            return _tidy(filtered)
        except Exception:
            # Logger.exception == Logger.error(..., exc_info=True)
            LOG.exception(f"Error in {self.action.name}")
            raise

    def _trace_datasource(self, *args, **kwargs):
        """Label used by the tracing helpers: ``<filter name>(<dates>)``."""
        label = f"{self.action.name}({self.group_of_dates})"
        return label
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class FunctionStepAction(StepAction):
    """Pipeline step that applies a named filter function to the previous step's output."""

    result_class = StepFunctionResult

    def __init__(self, context, action_path, previous_step, *args, **kwargs):
        super().__init__(context, action_path, previous_step, *args, **kwargs)
        # The first positional argument names the filter to load from the "filters" family.
        filter_name = args[0]
        self.name = filter_name
        self.function = import_function(filter_name, "filters")
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
# (C) Copyright 2024 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
import logging
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
|
|
12
|
+
from anemoi.utils.dates import as_datetime as as_datetime
|
|
13
|
+
from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
|
|
14
|
+
|
|
15
|
+
from anemoi.datasets.dates import DatesProvider as DatesProvider
|
|
16
|
+
from anemoi.datasets.fields import FieldArray as FieldArray
|
|
17
|
+
from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
|
|
18
|
+
|
|
19
|
+
from ..functions import import_function
|
|
20
|
+
from .action import Action
|
|
21
|
+
from .misc import _tidy
|
|
22
|
+
from .misc import assert_fieldlist
|
|
23
|
+
from .result import Result
|
|
24
|
+
from .template import notify_result
|
|
25
|
+
from .template import resolve
|
|
26
|
+
from .template import substitute
|
|
27
|
+
from .trace import trace
|
|
28
|
+
from .trace import trace_datasource
|
|
29
|
+
from .trace import trace_select
|
|
30
|
+
|
|
31
|
+
LOG = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class FunctionContext:
    """Context object handed to every source/filter function.

    It exposes information owned by the result that invokes the function
    (dates provider, ``partial_ok`` flag, GRIB param-id setting) plus the
    module's tracing and logging helpers.
    """

    def __init__(self, owner):
        self.owner = owner
        # Copied once from the owner's context when the FunctionContext is built.
        self.use_grib_paramid = owner.context.use_grib_paramid

    def trace(self, emoji, *args):
        """Forward to the module-level ``trace`` helper."""
        trace(emoji, *args)

    def info(self, *args, **kwargs):
        """Log at INFO level on this module's logger."""
        LOG.info(*args, **kwargs)

    @property
    def dates_provider(self):
        """The ``provider`` of the owner's group of dates."""
        group = self.owner.group_of_dates
        return group.provider

    @property
    def partial_ok(self):
        """The ``partial_ok`` flag of the owner's group of dates."""
        group = self.owner.group_of_dates
        return group.partial_ok
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class FunctionAction(Action):
    """Action that calls a named source function for a group of dates."""

    def __init__(self, context, action_path, _name, **kwargs):
        super().__init__(context, action_path, **kwargs)
        self.name = _name

    @trace_select
    def select(self, group_of_dates):
        """Return a ``FunctionResult`` binding this action to ``group_of_dates``."""
        return FunctionResult(self.context, self.action_path, group_of_dates, action=self)

    @property
    def function(self):
        """Import and return the source function named ``self.name``.

        The import is repeated on every access (nothing is cached here).
        """
        return import_function(self.name, "sources")

    def __repr__(self):
        # Join positional and keyword arguments as one list with a single
        # separator. Concatenating two separate joins ran the last positional
        # argument straight into the first keyword argument ("a,bx=1").
        parts = [self._short_str(a) for a in self.args]
        parts += [self._short_str(f"{k}={v}") for k, v in self.kwargs.items()]
        content = self._short_str(", ".join(parts))
        return super().__repr__(_inline_=content, _indent_=" ")

    def _trace_select(self, group_of_dates):
        """Label used by ``trace_select``: ``<name>(<dates>)``."""
        return f"{self.name}({group_of_dates})"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class FunctionResult(Result):
    """Result of calling an action's source function on a group of dates."""

    def __init__(self, context, action_path, group_of_dates, action):
        super().__init__(context, action_path, group_of_dates)
        assert isinstance(action, Action), type(action)
        self.action = action

        # ``substitute`` is applied once here.
        # ``resolve`` is applied later, when ``datasource`` is first computed.
        self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs))

    def _trace_datasource(self, *args, **kwargs):
        """Label used by ``trace_datasource``: ``<action name>(<dates>)``."""
        return f"{self.action.name}({self.group_of_dates})"

    @cached_property
    @assert_fieldlist
    @notify_result
    @trace_datasource
    def datasource(self):
        """Call the source function for this group of dates and return its fields (cached).

        The function receives a ``FunctionContext``, the dates as a list, then
        the resolved positional and keyword arguments. Any exception is logged
        with its traceback and then re-raised.
        """
        args, kwargs = resolve(self.context, (self.args, self.kwargs))

        try:
            return _tidy(
                self.action.function(
                    FunctionContext(self),
                    list(self.group_of_dates),  # Will provide a list of datetime objects
                    *args,
                    **kwargs,
                )
            )
        except Exception:
            # Identify the failure by the action's name. For FunctionAction,
            # ``self.action.function`` re-imports the function on each access.
            # Evaluating it here again would fail a second time if the import
            # was what failed, and the log line would never be emitted.
            LOG.error(f"Error in {self.action.name}", exc_info=True)
            raise

    def __repr__(self):
        try:
            return f"{self.action.name}({self.group_of_dates})"
        except Exception:
            # __repr__ may run before __init__ has set ``action``.
            return f"{self.__class__.__name__}(uninitialised)"

    @property
    def function(self):
        """Not supported on results; the function lives on the action."""
        raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# (C) Copyright 2024 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
import logging
|
|
10
|
+
from functools import cached_property
|
|
11
|
+
|
|
12
|
+
from .action import Action
|
|
13
|
+
from .action import action_factory
|
|
14
|
+
from .empty import EmptyResult
|
|
15
|
+
from .misc import _tidy
|
|
16
|
+
from .misc import assert_fieldlist
|
|
17
|
+
from .result import Result
|
|
18
|
+
from .template import notify_result
|
|
19
|
+
from .trace import trace_datasource
|
|
20
|
+
from .trace import trace_select
|
|
21
|
+
|
|
22
|
+
LOG = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class JoinResult(Result):
    """Concatenation of the fields of several sub-results for the same group of dates."""

    def __init__(self, context, action_path, group_of_dates, results, **kwargs):
        # Extra keyword arguments are accepted but ignored.
        super().__init__(context, action_path, group_of_dates)
        # Sub-results flagged ``empty`` contribute nothing; drop them up front.
        self.results = [sub for sub in results if not sub.empty]

    @cached_property
    @assert_fieldlist
    @notify_result
    @trace_datasource
    def datasource(self):
        """Sum the datasources of all non-empty sub-results, starting from an empty source."""
        combined = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
        for sub in self.results:
            combined += sub.datasource
        return _tidy(combined)

    def __repr__(self):
        return super().__repr__("\n".join(map(str, self.results)))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class JoinAction(Action):
    """Action that runs several child actions and joins their results."""

    def __init__(self, context, action_path, *configs):
        super().__init__(context, action_path, *configs)
        children = []
        for index, config in enumerate(configs):
            # Each child gets its position appended to the action path.
            children.append(action_factory(config, context, action_path + [str(index)]))
        self.actions = children

    def __repr__(self):
        return super().__repr__("\n".join(map(str, self.actions)))

    @trace_select
    def select(self, group_of_dates):
        """Select every child for ``group_of_dates`` and wrap them in a ``JoinResult``."""
        sub_results = [child.select(group_of_dates) for child in self.actions]
        return JoinResult(self.context, self.action_path, group_of_dates, sub_results)
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
# (C) Copyright 2024 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
import logging
|
|
10
|
+
from functools import wraps
|
|
11
|
+
|
|
12
|
+
from anemoi.utils.dates import as_datetime as as_datetime
|
|
13
|
+
from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta
|
|
14
|
+
from earthkit.data.core.fieldlist import MultiFieldList
|
|
15
|
+
from earthkit.data.indexing.fieldlist import FieldList
|
|
16
|
+
|
|
17
|
+
from anemoi.datasets.dates import DatesProvider as DatesProvider
|
|
18
|
+
from anemoi.datasets.fields import FieldArray as FieldArray
|
|
19
|
+
from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField
|
|
20
|
+
|
|
21
|
+
from ..functions import import_function
|
|
22
|
+
|
|
23
|
+
LOG = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def parse_function_name(name):
    """Split an optional hour-offset suffix (``<name>+<N>h`` / ``<name>-<N>h``) off a function name.

    Returns ``(name, delta)``. ``delta`` is a signed number of hours, or
    ``None`` when no offset is recognised.

    NOTE(review): the guard below requires ``name[:-1].isdigit()``. That means
    everything except the trailing ``h`` must be digits (e.g. ``"12h"``). Such
    a name can never contain ``"-"`` or ``"+"``, so both offset branches are
    unreachable and this function currently always returns ``(name, None)``.
    The intent was presumably to test only the characters between the sign
    and ``h`` (e.g. ``"mars-6h"``). Confirm before relying on ``delta``.
    """

    if name.endswith("h") and name[:-1].isdigit():

        if "-" in name:
            # Negative offset: "<name>-<N>h"
            name, delta = name.split("-")
            sign = -1

        elif "+" in name:
            # Positive offset: "<name>+<N>h"
            name, delta = name.split("+")
            sign = 1

        else:
            return name, None

        assert delta[-1] == "h", (name, delta)
        # Strip the trailing "h" and apply the sign.
        delta = sign * int(delta[:-1])
        return name, delta

    return name, None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def is_function(name, kind):
    """Return True if ``name`` can be imported as a function of family ``kind``.

    ``kind`` is the family passed to ``import_function``, e.g. ``"sources"``
    or ``"filters"``. An ``ImportError`` means "not a function": the reason is
    logged through the module logger rather than printed to stdout.
    """
    name, _ = parse_function_name(name)
    try:
        import_function(name, kind)
    except ImportError as e:
        # Keep the diagnostic visible. The ImportError may come from a missing
        # dependency inside the module rather than from an unknown name.
        LOG.warning("Cannot import %r as a %s function: %s", name, kind, e)
        return False
    return True
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def assert_fieldlist(method):
    """Decorator asserting that ``method`` returns an earthkit ``FieldList``.

    The return value is passed through unchanged. Otherwise an
    ``AssertionError`` carrying the actual type is raised.
    """

    @wraps(method)
    def checked(self, *args, **kwargs):
        value = method(self, *args, **kwargs)
        assert isinstance(value, FieldList), type(value)
        return value

    return checked
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def assert_is_fieldlist(obj):
    """Assert that ``obj`` is an earthkit ``FieldList``; the failure message is its actual type."""
    is_fieldlist = isinstance(obj, FieldList)
    assert is_fieldlist, type(obj)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _flatten(ds):
    """Recursively expand nested ``MultiFieldList`` objects into a flat list of sources.

    Reads the private ``_indexes`` attribute of ``MultiFieldList``. Any other
    object is returned as a one-element list.
    """
    if not isinstance(ds, MultiFieldList):
        return [ds]

    flat = []
    for child in ds._indexes:
        flat.extend(_tidy(leaf) for leaf in _flatten(child))
    return flat
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _tidy(ds, indent=0):
    """Simplify a possibly nested ``MultiFieldList``.

    Nested multi-lists are flattened and members with ``len() == 0`` are
    dropped. If exactly one member remains, it is returned on its own.
    Non-multi inputs are returned unchanged. ``indent`` is unused.
    """
    if not isinstance(ds, MultiFieldList):
        return ds

    non_empty = [source for source in _flatten(ds) if len(source) > 0]
    return non_empty[0] if len(non_empty) == 1 else MultiFieldList(non_empty)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# (C) Copyright 2024 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
import logging
|
|
10
|
+
|
|
11
|
+
from .action import Action
|
|
12
|
+
from .action import action_factory
|
|
13
|
+
from .step import step_factory
|
|
14
|
+
from .trace import trace_select
|
|
15
|
+
|
|
16
|
+
LOG = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PipeAction(Action):
    """Chain of actions.

    The first config builds a regular action. Each following config builds a
    step fed by the one before it. Selecting the pipe selects the last step.
    """

    def __init__(self, context, action_path, *configs):
        super().__init__(context, action_path, *configs)
        assert len(configs) > 1, configs
        head, *rest = configs
        step = action_factory(head, context, action_path + ["0"])
        for position, config in enumerate(rest, start=1):
            step = step_factory(config, context, action_path + [str(position)], previous_step=step)
        self.last_step = step

    @trace_select
    def select(self, group_of_dates):
        """Delegate to the final step of the pipe."""
        last = self.last_step
        return last.select(group_of_dates)

    def __repr__(self):
        return super().__repr__(self.last_step)
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
# (C) Copyright 2023 ECMWF.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
6
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
|
+
# nor does it submit to any jurisdiction.
|
|
8
|
+
#
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
from anemoi.utils.dates import as_datetime
|
|
15
|
+
from anemoi.utils.dates import frequency_to_timedelta
|
|
16
|
+
|
|
17
|
+
from anemoi.datasets.fields import FieldArray
|
|
18
|
+
from anemoi.datasets.fields import NewValidDateTimeField
|
|
19
|
+
|
|
20
|
+
from .action import Action
|
|
21
|
+
from .action import action_factory
|
|
22
|
+
from .join import JoinResult
|
|
23
|
+
from .result import Result
|
|
24
|
+
from .trace import trace_select
|
|
25
|
+
|
|
26
|
+
LOG = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class DateMapper:
    """Base class for strategies mapping requested dates onto the dates actually fetched."""

    @staticmethod
    def from_mode(mode, source, config):
        """Instantiate the mapper registered under ``mode``.

        ``config`` is expanded as keyword arguments to the mapper's constructor.
        Raises ValueError for an unknown ``mode``.
        """
        registry = {
            "closest": DateMapperClosest,
            "climatology": DateMapperClimatology,
            "constant": DateMapperConstant,
        }

        if mode not in registry:
            raise ValueError(f"Invalid mode for DateMapper: {mode}")

        mapper_class = registry[mode]
        return mapper_class(source, **config)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class DateMapperClosest(DateMapper):
    """Map each requested date to the closest date for which the source returns data.

    Candidate dates are probed every ``frequency`` within ``±maximum`` of each
    requested date. Probed dates (``tried``) and dates found to have data
    (``found``) are remembered across calls, so no date is requested twice.
    With ``skip_all_nans`` set, fields whose values are all NaN do not count
    as found.
    """

    def __init__(self, source, frequency="1h", maximum="30d", skip_all_nans=False):
        self.source = source
        self.maximum = frequency_to_timedelta(maximum)
        self.frequency = frequency_to_timedelta(frequency)
        # A non-positive step would make the probing loops in transform() never terminate.
        if self.frequency.total_seconds() <= 0:
            raise ValueError(f"DateMapperClosest: frequency must be positive, got {frequency!r}")
        self.skip_all_nans = skip_all_nans
        self.tried = set()
        self.found = set()

    def transform(self, group_of_dates):
        """Yield ``(found_date_group, requested_dates_group)`` pairs.

        Each pair holds one found date and every requested date mapped onto
        it. On equal distance the later found date wins.

        Raises:
            ValueError: if no date with data is known for a requested date.
        """
        from anemoi.datasets.dates.groups import GroupOfDates

        asked_dates = list(group_of_dates)
        if not asked_dates:
            return

        to_try = set()
        for date in asked_dates:
            start = date
            while start >= date - self.maximum:
                to_try.add(start)
                start -= self.frequency

            end = date
            while end <= date + self.maximum:
                to_try.add(end)
                end += self.frequency

        to_try = sorted(to_try - self.tried)

        if to_try:
            result = self.source.select(
                GroupOfDates(
                    to_try,
                    group_of_dates.provider,
                    partial_ok=True,
                )
            )

            for f in result.datasource:
                # We could keep the fields in a dictionary, but we don't want to keep the fields in memory
                date = as_datetime(f.metadata("valid_datetime"))

                if self.skip_all_nans:
                    if np.isnan(f.to_numpy()).all():
                        LOG.warning(f"Skipping {date} because all values are NaN")
                        continue

                self.found.add(date)

        self.tried.update(to_try)

        # Sort once rather than once per requested date.
        found_dates = sorted(self.found)
        new_dates = defaultdict(list)

        for date in asked_dates:
            best = None
            for found_date in found_dates:
                delta = abs(date - found_date)
                # With < we prefer the first date
                # With <= we prefer the last date
                if best is None or delta <= best[0]:
                    best = delta, found_date
            if best is None:
                # Previously this surfaced as "'NoneType' object is not subscriptable".
                raise ValueError(
                    f"No date with data found for {date} (searched ±{self.maximum} every {self.frequency})"
                )
            new_dates[best[1]].append(date)

        for date, dates in new_dates.items():
            yield (
                GroupOfDates([date], group_of_dates.provider),
                GroupOfDates(dates, group_of_dates.provider),
            )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class DateMapperClimatology(DateMapper):
    """Map each requested date onto a fixed ``year`` and ``day``.

    Month and time of day are kept from the requested date.
    """

    def __init__(self, source, year, day):
        # Stored for consistency with the other mappers; transform() does not use it.
        self.source = source
        self.year = year
        self.day = day

    def transform(self, group_of_dates):
        """Yield ``(mapped_date_group, requested_dates_group)`` pairs.

        ``datetime.replace`` raises ValueError if the target day does not
        exist in the month/year, e.g. ``day=31`` for April.
        """
        from anemoi.datasets.dates.groups import GroupOfDates

        asked_dates = list(group_of_dates)
        if not asked_dates:
            return

        new_dates = defaultdict(list)
        for date in asked_dates:
            new_date = date.replace(year=self.year, day=self.day)
            new_dates[new_date].append(date)

        for new_date, original_dates in new_dates.items():
            yield (
                GroupOfDates([new_date], group_of_dates.provider),
                GroupOfDates(original_dates, group_of_dates.provider),
            )
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class DateMapperConstant(DateMapper):
    """Map every requested date onto one fixed ``date``.

    With ``date=None`` the source is selected with an empty group of dates.
    """

    def __init__(self, source, date=None):
        self.source = source
        self.date = date

    def transform(self, group_of_dates):
        """Return a single ``(fixed_date_group, group_of_dates)`` pair."""
        from anemoi.datasets.dates.groups import GroupOfDates

        fixed_dates = [] if self.date is None else [self.date]
        return [(GroupOfDates(fixed_dates, group_of_dates.provider), group_of_dates)]
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class DateMapperResult(Result):
    """Fields of one source result, re-stamped with each originally requested date."""

    def __init__(
        self,
        context,
        action_path,
        group_of_dates,
        source_result,
        mapper,
        original_group_of_dates,
    ):
        super().__init__(context, action_path, group_of_dates)

        self.source_results = source_result
        self.mapper = mapper
        self.original_group_of_dates = original_group_of_dates

    @property
    def datasource(self):
        """Return a FieldArray with one copy of every source field per original date.

        Not cached: the array is rebuilt on every access.
        """
        return FieldArray(
            [
                NewValidDateTimeField(field, date)
                for field in self.source_results.datasource
                for date in self.original_group_of_dates
            ]
        )
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class RepeatedDatesAction(Action):
    """Fetch data for mapped dates and re-label it with the dates actually requested."""

    def __init__(self, context, action_path, source, mode, **kwargs):
        super().__init__(context, action_path, source, mode, **kwargs)

        self.source = action_factory(source, context, action_path + ["source"])
        # Remaining keyword arguments configure the mapper selected by ``mode``.
        self.mapper = DateMapper.from_mode(mode, self.source, kwargs)

    @trace_select
    def select(self, group_of_dates):
        """Build one ``DateMapperResult`` per mapped group and join them."""
        results = [
            DateMapperResult(
                self.context,
                self.action_path,
                one_date_group,
                self.source.select(one_date_group),
                self.mapper,
                many_dates_group,
            )
            for one_date_group, many_dates_group in self.mapper.transform(group_of_dates)
        ]
        return JoinResult(self.context, self.action_path, group_of_dates, results)

    def __repr__(self):
        # Use the real class name instead of the stale hard-coded "MultiDateMatchAction".
        return f"{self.__class__.__name__}({self.source}, {self.mapper})"
|