anemoi-datasets 0.5.28__py3-none-any.whl → 0.5.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/create/__init__.py +4 -12
- anemoi/datasets/create/config.py +50 -53
- anemoi/datasets/create/input/result/field.py +1 -3
- anemoi/datasets/create/sources/accumulate.py +517 -0
- anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
- anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
- anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +153 -0
- anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
- anemoi/datasets/create/sources/grib_index.py +79 -51
- anemoi/datasets/create/sources/mars.py +56 -27
- anemoi/datasets/create/sources/xarray_support/__init__.py +1 -0
- anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
- anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
- anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
- anemoi/datasets/data/complement.py +26 -17
- anemoi/datasets/data/dataset.py +6 -0
- anemoi/datasets/data/masked.py +74 -13
- anemoi/datasets/data/missing.py +5 -0
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/METADATA +8 -7
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/RECORD +25 -23
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/sources/accumulations.py +0 -1042
- anemoi/datasets/create/sources/accumulations2.py +0 -618
- anemoi/datasets/create/sources/tendencies.py +0 -171
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/licenses/LICENSE +0 -0
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
# (C) Copyright 2025 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
import itertools
|
|
11
|
+
import logging
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from dataclasses import field
|
|
14
|
+
from datetime import datetime
|
|
15
|
+
from datetime import timedelta
|
|
16
|
+
from heapq import heappop
|
|
17
|
+
from heapq import heappush
|
|
18
|
+
from typing import Callable
|
|
19
|
+
from typing import List
|
|
20
|
+
from typing import Optional
|
|
21
|
+
|
|
22
|
+
LOG = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
# This module implements an algorithm to cover a target time interval with a sequence of signed intervals
|
|
25
|
+
# (which can be positive or negative in length)
|
|
26
|
+
# It is general purpose but is mainly designed to support accumulation over variable periods,
|
|
27
|
+
# e.g. to cover a 6h accumulation with available intervals of +/-3h.
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SignedInterval:
    """A time interval going from ``start`` to ``end`` that may run backwards.

    The interval has a negative length when ``end`` is earlier than ``start``.
    An optional ``base`` datetime can be attached; it takes part in equality
    and hashing, and is used by ``__repr__`` to display hour offsets.
    """

    def __init__(self, start: datetime, end: datetime, base: Optional[datetime] = None):
        self.start = start
        self.end = end
        self.base = base

    @property
    def length(self) -> float:
        """Signed duration ``end - start`` in seconds (negative when reversed)."""
        delta = self.end - self.start
        return delta.total_seconds()

    @property
    def sign(self) -> int:
        """``-1`` when the length is negative, ``1`` otherwise (zero included)."""
        if self.length < 0:
            return -1
        return 1

    @property
    def min(self):
        """The earlier of the two bounds."""
        return self.end if self.end < self.start else self.start

    @property
    def max(self):
        """The later of the two bounds."""
        return self.end if self.end > self.start else self.start

    def __neg__(self):
        # Same bounds and base, opposite direction.
        return SignedInterval(self.end, self.start, self.base)

    def __eq__(self, other):
        if not isinstance(other, SignedInterval):
            return NotImplemented
        return (self.start, self.end, self.base) == (other.start, other.end, other.base)

    def __hash__(self):
        # Must stay consistent with __eq__: same three attributes.
        return hash((self.start, self.end, self.base))

    def __rich__(self):
        return self.__repr__(colored=True)

    def __repr__(self, colored: bool = False):
        """Return a compact description; with ``colored`` add rich colour markup."""
        try:
            # Optional dependency: only used to pretty-print durations,
            # this class must keep working without anemoi.utils.
            from anemoi.utils.dates import frequency_to_string
        except ImportError:
            frequency_to_string = str

        stamp_format = "%Y%m%d.%H%M"
        start = self.start.strftime(stamp_format)
        end = self.end.strftime(stamp_format)
        if start[:9] == end[:9]:
            # Same day: blank out the repeated date part of the end stamp.
            end = " " * 9 + end[9:]

        base_str = ""
        if self.base is not None:
            base = self.base.strftime(stamp_format)

            def hours_from_base(when):
                return int((when - self.base).total_seconds() / 3600)

            if self.sign > 0:
                steps = [hours_from_base(self.start), hours_from_base(self.end)]
            else:
                # Reversed interval: the first value is negated, so the
                # rendered range starts with a "-" marker.
                steps = [-hours_from_base(self.end), hours_from_base(self.start)]
            base_str = f", base={base}, [{steps[0]}-{steps[1]}]"

        forward = self.start < self.end
        instant = self.start == self.end

        if forward:
            period = f"+{frequency_to_string(self.end - self.start)}"
        elif instant:
            period = "0s"
        else:
            period = f"-{frequency_to_string(self.start - self.end)}"
        period = period.ljust(4)

        if colored:
            # using rich colors
            start = f"[blue]{start}[/blue]"
            end = f"[blue]{end}[/blue]"
            if forward:
                period = f"[green]{period}[/green]"
            elif instant:
                period = f"[yellow]{period}[/yellow]"
            else:
                period = f"[red]{period}[/red]"

        return f"SignedInterval({start}{period}->{end}{base_str} )"
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@dataclass(order=True)
class HeapState:
    """One entry of the priority queue used by ``covering_intervals``.

    Instances are ordered field by field in declaration order; ``path`` is
    excluded from comparisons.
    """

    # Accumulated cost of the path so far (primary ordering key).
    total_cost: float
    # Signed number of seconds covered so far.
    covered: float
    # Insertion counter, only there to break ties between equal costs.
    counter: int
    # End time of the last interval of the path.
    current_time: datetime
    # Base of the last interval of the path, None when there is none.
    current_base: Optional[datetime]
    # Intervals selected so far; ignored when comparing states.
    path: "List[SignedInterval]" = field(compare=False)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def covering_intervals(
    start: datetime,
    end: datetime,
    candidates: Callable,
    /,
    switch_penalty: int = 24 * 3600 * 7,
    max_delta: timedelta = timedelta(hours=24 * 2),
    error_on_fail: bool = True,
) -> List[SignedInterval] | None:
    """Search for a chain of intervals leading from ``start`` to ``end``.

    A Dijkstra-like search is run on a priority queue. The cost of a chain is
    the sum of the absolute lengths of its intervals, plus ``switch_penalty``
    each time two consecutive intervals have different bases.

    Args:
        start: Start datetime of the target interval.

        end: End datetime of the target interval.

        candidates: Callable invoked as ``candidates(current_time)``; it must return
            an iterable of SignedInterval objects that all start at ``current_time``.

        switch_penalty: Penalty (in seconds) for switching bases between intervals.

        max_delta: Once more than 1000 states have been visited, the search is
            abandoned as soon as it reaches a time further than this from
            ``[start, end]``.

        error_on_fail: Whether to raise an error if coverage cannot be found.

    Returns:
        The list of selected SignedInterval objects, or None if no coverage was
        found and error_on_fail is False.

    Raises:
        ValueError: If a candidate does not start at the current time, or if no
            coverage is found and error_on_fail is True.
    """

    def give_up(msg):
        # Shared failure exit: raise, or log and return None.
        if error_on_fail:
            raise ValueError(msg)
        LOG.warning(msg)
        return None

    target_length = (end - start).total_seconds()
    counter = itertools.count()

    # A single-element list is already a valid heap.
    pq: List[HeapState] = [
        HeapState(total_cost=0.0, covered=0.0, counter=next(counter), current_time=start, current_base=None, path=[])
    ]

    # Best known cost for each (time, base, covered) combination.
    visited: dict[tuple[datetime, Optional[datetime], float], float] = {}

    while pq:
        state = heappop(pq)
        key = (state.current_time, state.current_base, state.covered)

        best_so_far = visited.get(key)
        if best_so_far is not None and state.total_cost >= best_so_far:
            continue
        visited[key] = state.total_cost

        # Goal: cumulative coverage matches target
        if state.covered == target_length:
            return state.path

        in_range = start - max_delta <= state.current_time <= end + max_delta
        if len(visited) > 1000 and not in_range:
            return give_up(
                f"Exceeded search limits: visited={len(visited)}, current_time={state.current_time}, target=({start} → {end}), max_delta={max_delta}"
            )

        for interval in candidates(state.current_time):
            if interval.start != state.current_time:
                raise ValueError(
                    f"Candidate interval {interval} does not start or end at current_time {state.current_time}"
                )

            # Edge cost = abs(length) + switch penalty if base changes
            base_changed = state.current_base is not None and state.current_base != interval.base
            edge_cost = abs(interval.length) + (switch_penalty if base_changed else 0)

            successor = HeapState(
                total_cost=state.total_cost + edge_cost,
                covered=state.covered + interval.length,
                counter=next(counter),  # counter only used to break ties in heapq
                current_time=interval.end,
                current_base=interval.base,
                path=state.path + [interval],
            )
            heappush(pq, successor)

    return give_up(f"Cannot find coverage of {start} → {end}")
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
# (C) Copyright 2025 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import datetime
|
|
12
|
+
import logging
|
|
13
|
+
|
|
14
|
+
from anemoi.utils.dates import frequency_to_timedelta
|
|
15
|
+
|
|
16
|
+
from .covering_intervals import SignedInterval
|
|
17
|
+
|
|
18
|
+
LOG = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FieldToInterval:
    """Convert a field to its accumulation interval, applying patches if needed.

    ``patches`` is a dict whose keys select how inconsistent ``startStep`` /
    ``endStep`` metadata are repaired before the interval is built.
    """

    def __init__(self, patches: dict | None = None):
        patches = {} if patches is None else patches
        assert isinstance(patches, dict), ("patches must be a dict", patches)

        self.patches = patches
        for key in patches:
            if key not in (
                "start_step_is_zero",
                "start_step_is_end_step",
                "start_step_greater_than_end_step",
                "set_start_step_to_zero",
            ):
                raise ValueError(f"Unknown patch key: {key}")

    @staticmethod
    def _parse_datetime(date, time) -> datetime.datetime:
        # date is zero-padded to 8 characters (YYYYMMDD), time to 4 (HHMM).
        return datetime.datetime.strptime(str(date).zfill(8) + str(time).zfill(4), "%Y%m%d%H%M")

    def __call__(self, field) -> SignedInterval:
        """Build the interval of ``field`` from its metadata.

        Reads the ``date``, ``time``, ``startStep``, ``endStep``, ``validityDate``
        and ``validityTime`` metadata. Steps are interpreted as hours from the
        base datetime. The later bound of the interval must match the validity
        date and time of the field.
        """
        base_datetime = self._parse_datetime(field.metadata("date"), field.metadata("time"))

        endStep = field.metadata("endStep")
        startStep = field.metadata("startStep")

        LOG.debug(f" field before patching: {startStep=}, {endStep=}")

        if self.patches.get("set_start_step_to_zero", False):
            startStep = 0

        # At most one of the three patch hooks is applied.
        if startStep > endStep:
            startStep, endStep = self.start_step_greater_than_end_step(startStep, endStep, field=field)
        elif startStep == endStep:
            startStep, endStep = self.start_step_is_end_step(startStep, endStep, field=field)
        elif frequency_to_timedelta(startStep).total_seconds() == 0:
            startStep, endStep = self.start_step_is_zero(startStep, endStep, field=field)

        LOG.debug(f" field after patching : {startStep=}, {endStep=}")

        start_step = datetime.timedelta(hours=startStep)
        end_step = datetime.timedelta(hours=endStep)

        assert startStep >= 0, ("After patching, startStep must be >= 0", field, startStep, endStep)
        assert startStep < endStep, ("After patching, startStep must be < endStep", field, startStep, endStep)

        interval = SignedInterval(start=base_datetime + start_step, end=base_datetime + end_step, base=base_datetime)

        valid_date = self._parse_datetime(field.metadata("validityDate"), field.metadata("validityTime"))
        assert valid_date == interval.max, (valid_date, interval)

        return interval

    def start_step_is_zero(self, startStep, endStep, field=None):
        """Hook for fields whose start step is zero.

        No patch is implemented yet: only ``False``/``None`` (do nothing) are accepted.
        """
        options = self.patches.get("start_step_is_zero", None)
        if options is not False and options is not None:
            raise ValueError(f"Unknown option for patch.start_step_is_zero: {options}")

        return startStep, endStep

    def start_step_is_end_step(self, startStep, endStep, field=None):
        """Hook for fields whose start step equals their end step.

        This should not happen in normal cases but some datasets have this issue.
        The default is to set the start step to zero; setting the patch to
        ``False`` or ``None`` disables it.
        """
        options = self.patches.get("start_step_is_end_step", "set_start_step_to_zero")

        if options is False or options is None:
            return startStep, endStep

        if options == "set_from_end_step_ceiled_to_24_hours":
            return _set_start_step_from_end_step_ceiled_to_24_hours(startStep, endStep, field=field)

        if options == "set_start_step_to_zero":
            return 0, endStep

        raise ValueError(f"Unknown option for patch.start_step_is_end_step: {options}")

    def start_step_greater_than_end_step(self, startStep, endStep, field=None):
        """Hook for fields whose start step is greater than their end step.

        This should not happen in normal cases but some datasets have this issue.
        By default (patch unset, ``False`` or ``None``) nothing is changed; the
        ``"swap"`` option exchanges the two values.
        """
        options = self.patches.get("start_step_greater_than_end_step", None)

        if options is False or options is None:
            return startStep, endStep

        if options == "swap":
            return endStep, startStep

        raise ValueError(f"Unknown option for patch.start_step_greater_than_end_step: {options}")
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _set_start_step_from_end_step_ceiled_to_24_hours(startStep, endStep, field=None):
|
|
130
|
+
# Because the data wrongly encode start_step, but end_step is correct
|
|
131
|
+
# and we know that accumulations are always reseted every multiple of 24 hours
|
|
132
|
+
#
|
|
133
|
+
# 1-1 -> 0-1
|
|
134
|
+
# 2-2 -> 0-2
|
|
135
|
+
# ...
|
|
136
|
+
# 23-23 -> 0-23
|
|
137
|
+
# 24-24 -> 0-24
|
|
138
|
+
# 25-25 -> 24-25
|
|
139
|
+
# 26-26 -> 24-26
|
|
140
|
+
# ...
|
|
141
|
+
# 47-47 -> 24-47
|
|
142
|
+
# 48-48 -> 24-48
|
|
143
|
+
# 49-49 -> 48-49
|
|
144
|
+
# 50-50 -> 48-50
|
|
145
|
+
# etc.
|
|
146
|
+
if endStep % 24 == 0:
|
|
147
|
+
# Special case: endStep is exactly 24, 48, 72, etc.
|
|
148
|
+
# Map to previous 24-hour boundary (24 -> 0, 48 -> 24, etc.)
|
|
149
|
+
return endStep - 24, endStep
|
|
150
|
+
|
|
151
|
+
# General case: floor to the nearest 24-hour boundary
|
|
152
|
+
# (1-23 -> 0, 25-47 -> 24, etc.)
|
|
153
|
+
return endStep - (endStep % 24), endStep
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
# (C) Copyright 2025 Anemoi contributors.
|
|
2
|
+
#
|
|
3
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
4
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
5
|
+
#
|
|
6
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
7
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
8
|
+
# nor does it submit to any jurisdiction.
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
import datetime
|
|
12
|
+
import logging
|
|
13
|
+
from abc import abstractmethod
|
|
14
|
+
from typing import Iterable
|
|
15
|
+
|
|
16
|
+
from anemoi.utils.dates import frequency_to_timedelta
|
|
17
|
+
|
|
18
|
+
from anemoi.datasets.create.sources.accumulate_utils.covering_intervals import SignedInterval
|
|
19
|
+
from anemoi.datasets.create.sources.accumulate_utils.covering_intervals import covering_intervals
|
|
20
|
+
|
|
21
|
+
LOG = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def build_interval(
    current_time: datetime.datetime, start_step: int, end_step: int, base_time: str | int | None
) -> SignedInterval:
    """Build the SignedInterval for the given steps on the day of ``current_time``.

    The steps are hour offsets from ``base_time`` (an hour of the day, midnight
    when ``base_time`` is None) on the calendar day of ``current_time``.
    The returned interval only carries a base datetime when ``base_time`` is
    not None.
    """
    if base_time is None:
        base_hour = 0
    else:
        try:
            base_hour = int(base_time)
        except ValueError:
            raise ValueError(f"Invalid base_time: {base_time} ({type(base_time)})")

    base = datetime.datetime(current_time.year, current_time.month, current_time.day, base_hour)

    return SignedInterval(
        start=base + datetime.timedelta(hours=start_step),
        end=base + datetime.timedelta(hours=end_step),
        base=None if base_time is None else base,
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class IntervalGenerator:
    """Base class to generate intervals.

    Calling an IntervalGenerator provides the candidate intervals to be
    selected by the covering_intervals method.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``abstractmethod`` markers below are not enforced at instantiation time;
    subclasses are expected to override both methods.
    """

    @abstractmethod
    def covering_intervals(self, start: datetime.datetime, end: datetime.datetime) -> "Iterable[SignedInterval]":
        """Return the intervals selected to cover the period ``start`` -> ``end``."""
        pass

    @abstractmethod
    def __call__(self, current_time: datetime.datetime) -> "Iterable[SignedInterval]":
        """Return the candidate intervals for ``current_time``."""
        # ``self`` was missing from the original signature, which made any
        # instance relying on this default implementation uncallable.
        pass
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Pattern:
    """Common format of the config arguments used to build a SearchableIntervalGenerator.

    A pattern stores a base time, a list of ``[start_step, end_step]`` pairs,
    a list of day offsets to search around a given time, and an optional
    restriction on the day of the month of the base date.
    """

    def __init__(
        self,
        base_time: str | datetime.datetime,
        steps: str | list[str],
        search_range: list[int] | None = None,
        base_date: dict | None = None,
    ):
        self.steps = normalise_steps(steps)

        # "*" means that the intervals have no specific base time.
        self.base_time = None if base_time == "*" else base_time

        # Day offsets; by default the previous, current and next day.
        day_offsets = [-1, 0, 1] if search_range is None else search_range
        self.search_range = [datetime.timedelta(days=d) for d in day_offsets]

        if base_date:
            # Only {"day_of_month": ...} is supported.
            assert isinstance(base_date, dict), base_date
            assert "day_of_month" in base_date, base_date
            assert len(base_date) == 1, base_date
        self.base_date = base_date

    def filter(self, interval: SignedInterval) -> bool:
        """Return False when ``base_date`` is set and the interval base is on another day of the month."""
        if self.base_date and interval.base.day != self.base_date["day_of_month"]:
            return False
        return True
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class SearchableIntervalGenerator(IntervalGenerator):
    """Interval generator backed by a list of Pattern objects.

    ``config`` is either a dict of Pattern keyword arguments (a single pattern),
    or a list/tuple of ``(base_time, steps)`` pairs (one pattern per pair).
    """

    def __init__(self, config: tuple | list | dict):
        if isinstance(config, dict):
            patterns = [Pattern(**config)]
        elif isinstance(config, (tuple, list)):
            patterns = [Pattern(base_time=base_time, steps=steps) for base_time, steps in config]
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``patterns``.
            raise TypeError(f"Unsupported interval generator config type: {type(config)} ({config})")

        self.patterns: list[Pattern] = patterns

    def covering_intervals(self, start: datetime.datetime, end: datetime.datetime) -> Iterable[SignedInterval]:
        """Perform interval search among candidates with minimal base switches and length.

        Candidates are given by calling self.
        Return available SignedIntervals covering the period start->end (where start>end is possible).
        """
        return covering_intervals(start, end, self)

    def __call__(
        self,
        current_time: datetime.datetime,
    ) -> Iterable[SignedInterval]:
        """Generate the candidate intervals starting at the given current_time.

        Candidates correspond to the pairs of base_time, steps stored in patterns,
        built for each day of the search range of the pattern. Intervals ending at
        current_time are returned reversed, so that every returned interval starts
        at current_time.
        """
        intervals = []
        seen = set()  # SignedInterval is hashable: O(1) duplicate detection
        for p in self.patterns:
            for delta in p.search_range:
                for start_step, end_step in p.steps:
                    interval = build_interval(current_time + delta, start_step, end_step, p.base_time)
                    if not p.filter(interval):
                        continue

                    if interval not in seen:
                        seen.add(interval)
                        intervals.append(interval)

        # keep only the intervals starting at current_time,
        # and the reverse of those ending at current_time
        filtered = []
        for i in intervals:
            if i.start == current_time:
                filtered.append(i)
            elif i.end == current_time:
                filtered.append(-i)

        # quite important to sort by reversed base to prioritise most recent base in case of ties
        # in some cases, we may want to sort by other criteria
        return sorted(filtered, key=lambda x: -(x.base or x.start).timestamp())
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def normalise_steps(steps_list: str | list[str]) -> list[list[int]]:
|
|
153
|
+
"""Convert the input step_list to a list of [start,end] pairs"""
|
|
154
|
+
res = []
|
|
155
|
+
if isinstance(steps_list, str):
|
|
156
|
+
steps_list = steps_list.split("/")
|
|
157
|
+
assert isinstance(steps_list, list), steps_list
|
|
158
|
+
|
|
159
|
+
for start_end_step in steps_list:
|
|
160
|
+
if isinstance(start_end_step, str):
|
|
161
|
+
assert "-" in start_end_step, start_end_step
|
|
162
|
+
start_end_step = start_end_step.split("-")
|
|
163
|
+
assert isinstance(start_end_step, (list, tuple)) and len(start_end_step) == 2, start_end_step
|
|
164
|
+
start_step, end_step = int(start_end_step[0]), int(start_end_step[1])
|
|
165
|
+
res.append([start_step, end_step])
|
|
166
|
+
return res
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class AccumulatedFromStartIntervalGenerator(SearchableIntervalGenerator):
    """Generator whose intervals all start at step 0 of each base time.

    For each base, the steps are "0-f", "0-2f", ... built from ``frequency`` (f)
    for every multiple of ``frequency`` below ``last_step``.
    """

    def __init__(self, basetime: str | datetime.datetime, frequency: int, last_step: int):
        # NOTE(review): basetime is iterated, so it presumably is a collection
        # of base times despite the annotation - confirm against callers.
        config = [
            [base, [f"0-{i+frequency}"]]
            for base in basetime
            for i in range(0, last_step, frequency)
        ]
        super().__init__(config)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class AccumulatedFromPreviousStepIntervalGenerator(SearchableIntervalGenerator):
    """Generator whose intervals each span one ``frequency`` from the previous step.

    For each base, the steps are "0-f", "f-2f", ... built from ``frequency`` (f)
    for every multiple of ``frequency`` below ``last_step``.
    """

    def __init__(self, basetime: str | datetime.datetime, frequency: int, last_step: int):
        # NOTE(review): basetime is iterated, so it presumably is a collection
        # of base times despite the annotation - confirm against callers.
        config = [
            [base, [f"{i}-{i+frequency}"]]
            for base in basetime
            for i in range(0, last_step, frequency)
        ]
        super().__init__(config)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _match_mars_config(_class: str, _stream: str | None = None, _origin: str | None = None) -> list | tuple:
|
|
188
|
+
"""Match MARS configuration (class, stream, origin) to interval generator config.
|
|
189
|
+
|
|
190
|
+
Parameters
|
|
191
|
+
----------
|
|
192
|
+
_class
|
|
193
|
+
MARS class (e.g., 'ea', 'od', 'rr', 'l5').
|
|
194
|
+
_stream
|
|
195
|
+
MARS stream (e.g., 'oper', 'enda', 'elda', 'enfo'). Defaults to 'oper'.
|
|
196
|
+
_origin
|
|
197
|
+
MARS origin (e.g., 'se-al-ec', 'fr-ms-ec'). Defaults to None.
|
|
198
|
+
|
|
199
|
+
Returns
|
|
200
|
+
-------
|
|
201
|
+
list | tuple
|
|
202
|
+
Interval generator configuration.
|
|
203
|
+
|
|
204
|
+
Raises
|
|
205
|
+
------
|
|
206
|
+
NotImplementedError
|
|
207
|
+
If the combination is not yet implemented.
|
|
208
|
+
ValueError
|
|
209
|
+
If the combination is unknown.
|
|
210
|
+
"""
|
|
211
|
+
_stream = _stream or "oper"
|
|
212
|
+
|
|
213
|
+
match (_class, _stream, _origin):
|
|
214
|
+
case ("ea", "oper", _):
|
|
215
|
+
return [
|
|
216
|
+
(6, "0-1/1-2/2-3/3-4/4-5/5-6/6-7/7-8/8-9/9-10/10-11/11-12/12-13/13-14/14-15/15-16/16-17/17-18"),
|
|
217
|
+
(18, "0-1/1-2/2-3/3-4/4-5/5-6/6-7/7-8/8-9/9-10/10-11/11-12/12-13/13-14/14-15/15-16/16-17/17-18"),
|
|
218
|
+
]
|
|
219
|
+
case ("ea", "enda", _):
|
|
220
|
+
return [
|
|
221
|
+
(6, "0-3/3-6/6-9/9-12/12-15/15-18"),
|
|
222
|
+
(18, "0-3/3-6/6-9/9-12/12-15/15-18"),
|
|
223
|
+
]
|
|
224
|
+
|
|
225
|
+
case ("od", "oper", _):
|
|
226
|
+
# https://apps.ecmwf.int/mars-catalogue/?stream=oper&levtype=sfc&time=00%3A00%3A00&expver=1&month=aug&year=2020&date=2020-08-25&type=fc&class=od
|
|
227
|
+
steps = [f"{0}-{i}" for i in range(1, 91)]
|
|
228
|
+
return ((0, steps), (12, steps))
|
|
229
|
+
case ("od", "elda", _):
|
|
230
|
+
# https://apps.ecmwf.int/mars-catalogue/?stream=elda&levtype=sfc&time=06%3A00%3A00&expver=1&month=aug&year=2020&date=2020-08-31&type=fc&class=od
|
|
231
|
+
# (6, "0-1/0-2/0-3/0-4/0-5/0-6/0-7/0-8/0-9/0-10/0-11/0-12"),
|
|
232
|
+
# (18, "0-1/0-2/0-3/0-4/0-5/0-6/0-7/0-8/0-9/0-10/0-11/0-12")
|
|
233
|
+
steps = [f"{0}-{i}" for i in range(1, 13)]
|
|
234
|
+
return ((6, steps), (18, steps))
|
|
235
|
+
case ("od", "enfo", _):
|
|
236
|
+
# https://apps.ecmwf.int/mars-catalogue/?class=od&stream=enfo&expver=1&type=fc&year=2020&month=aug&levtype=sfc&date=2020-08-31&time=06:00:00
|
|
237
|
+
raise NotImplementedError("od-enfo interval generator not implemented yet")
|
|
238
|
+
|
|
239
|
+
case ("rr", _, "se-al-ec"):
|
|
240
|
+
# https://apps.ecmwf.int/mars-catalogue/?class=rr&expver=prod&origin=se-al-ec&stream=oper&type=fc&year=2020&month=aug&levtype=sfc
|
|
241
|
+
return [[0, [(0, i) for i in [1, 2, 3, 4, 5, 6, 9, 12, 15, 18, 21, 24, 27, 30]]]]
|
|
242
|
+
case ("rr", _, "fr-ms-ec"):
|
|
243
|
+
# https://apps.ecmwf.int/mars-catalogue/?origin=fr-ms-ec&stream=oper&levtype=sfc&time=06%3A00%3A00&expver=prod&month=aug&year=2020&date=2020-08-31&type=fc&class=rr
|
|
244
|
+
return [[0, [(0, i) for i in range(1, 22, 3)]]]
|
|
245
|
+
|
|
246
|
+
case ("l5", "oper", _):
|
|
247
|
+
# https://apps.ecmwf.int/mars-catalogue/?class=l5&stream=oper&expver=1&type=fc&year=2020&month=aug&levtype=sfc&date=2020-08-25&time=00:00:00
|
|
248
|
+
return [[0, [(int(i), int(i) + 1) for i in range(0, 24)]]]
|
|
249
|
+
|
|
250
|
+
case _:
|
|
251
|
+
raise ValueError(f"Unknown MARS configuration: class={_class}, stream={_stream}, origin={_origin}")
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def _interval_generator_factory(
|
|
255
|
+
config, source_name: str | None = None, source: dict | None = None
|
|
256
|
+
) -> IntervalGenerator | list | dict:
|
|
257
|
+
match config:
|
|
258
|
+
case IntervalGenerator():
|
|
259
|
+
return config
|
|
260
|
+
|
|
261
|
+
case {"type": "accumulated-from-start", **params}:
|
|
262
|
+
return AccumulatedFromStartIntervalGenerator(**params)
|
|
263
|
+
case {"accumulated-from-start": params}:
|
|
264
|
+
return AccumulatedFromStartIntervalGenerator(**params)
|
|
265
|
+
|
|
266
|
+
case {"accumulated-from-previous-step": params}:
|
|
267
|
+
return AccumulatedFromPreviousStepIntervalGenerator(**params)
|
|
268
|
+
case {"type": "accumulated-from-previous-step", **params}:
|
|
269
|
+
return AccumulatedFromPreviousStepIntervalGenerator(**params)
|
|
270
|
+
|
|
271
|
+
case {"type": _, **params}:
|
|
272
|
+
raise NotImplementedError(f"Unknown availability config {config}")
|
|
273
|
+
|
|
274
|
+
case {"mars": mars_config}:
|
|
275
|
+
_class = mars_config.get("class")
|
|
276
|
+
_stream = mars_config.get("stream")
|
|
277
|
+
_origin = mars_config.get("origin")
|
|
278
|
+
assert _class is not None, "mars config must have a 'class' key"
|
|
279
|
+
return _match_mars_config(_class, _stream, _origin)
|
|
280
|
+
|
|
281
|
+
case dict() | list() | tuple():
|
|
282
|
+
return SearchableIntervalGenerator(config)
|
|
283
|
+
|
|
284
|
+
case "auto":
|
|
285
|
+
assert None not in (source_name, source), "Source must be specified when using 'auto' discovery"
|
|
286
|
+
assert source_name == "mars", "Only 'mars' source is currently supported for 'auto' availability discovery"
|
|
287
|
+
|
|
288
|
+
_class, _stream, _origin = source.get("class", None), source.get("stream", None), source.get("origin", None)
|
|
289
|
+
|
|
290
|
+
assert (
|
|
291
|
+
_class is not None
|
|
292
|
+
), "Availability should be automatically determined from mars source, but the mars source has no 'class'"
|
|
293
|
+
|
|
294
|
+
if (_stream is None) or (_origin is None):
|
|
295
|
+
LOG.warning(
|
|
296
|
+
f"Stream and/or origin unspecified for class {_class}, "
|
|
297
|
+
f"stream and/or origin will be set as defaults.",
|
|
298
|
+
)
|
|
299
|
+
|
|
300
|
+
return {"mars": {"class": _class, "stream": _stream, "origin": _origin}}
|
|
301
|
+
|
|
302
|
+
case str():
|
|
303
|
+
try:
|
|
304
|
+
data_accumulation_period = frequency_to_timedelta(config)
|
|
305
|
+
except Exception as e:
|
|
306
|
+
raise ValueError(f"Unknown interval generator config: {config}") from e
|
|
307
|
+
|
|
308
|
+
hours = data_accumulation_period.total_seconds() / 3600
|
|
309
|
+
if not (hours.is_integer() and hours > 0):
|
|
310
|
+
raise ValueError("Only accumulation periods multiple of 1 hour are supported for now")
|
|
311
|
+
|
|
312
|
+
return [["*", [f"{i}-{i+1}" for i in range(0, 24)]]]
|
|
313
|
+
|
|
314
|
+
case _:
|
|
315
|
+
raise ValueError(f"Unknown interval generator config: {config}")
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def interval_generator_factory(config, source_name: str | None = None, source: dict | None = None) -> IntervalGenerator:
    """Resolve ``config`` into an IntervalGenerator.

    The config is passed repeatedly through ``_interval_generator_factory``
    until an IntervalGenerator instance is obtained.
    """
    resolved = config
    while not isinstance(resolved, IntervalGenerator):
        resolved = _interval_generator_factory(resolved, source_name, source)
    return resolved
|