anemoi-datasets 0.5.28__py3-none-any.whl → 0.5.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/create/__init__.py +4 -12
  3. anemoi/datasets/create/config.py +50 -53
  4. anemoi/datasets/create/input/result/field.py +1 -3
  5. anemoi/datasets/create/sources/accumulate.py +517 -0
  6. anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
  7. anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
  8. anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +149 -0
  9. anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
  10. anemoi/datasets/create/sources/grib_index.py +64 -20
  11. anemoi/datasets/create/sources/mars.py +56 -27
  12. anemoi/datasets/create/sources/xarray_support/__init__.py +1 -0
  13. anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
  14. anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
  15. anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
  16. anemoi/datasets/data/complement.py +26 -17
  17. anemoi/datasets/data/dataset.py +6 -0
  18. anemoi/datasets/data/masked.py +74 -13
  19. anemoi/datasets/data/missing.py +5 -0
  20. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/METADATA +7 -7
  21. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/RECORD +25 -23
  22. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/WHEEL +1 -1
  23. anemoi/datasets/create/sources/accumulations.py +0 -1042
  24. anemoi/datasets/create/sources/accumulations2.py +0 -618
  25. anemoi/datasets/create/sources/tendencies.py +0 -171
  26. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/entry_points.txt +0 -0
  27. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/licenses/LICENSE +0 -0
  28. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,221 @@
1
+ # (C) Copyright 2025 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+ import itertools
11
+ import logging
12
+ from dataclasses import dataclass
13
+ from dataclasses import field
14
+ from datetime import datetime
15
+ from datetime import timedelta
16
+ from heapq import heappop
17
+ from heapq import heappush
18
+ from typing import Callable
19
+ from typing import List
20
+ from typing import Optional
21
+
22
+ LOG = logging.getLogger(__name__)
23
+
24
+ # This module implements an algorithm to cover a target time interval with a sequence of signed intervals
25
+ # (which can be positive or negative in length)
26
+ # It is general purpose but is mainly designed to support accumulation over variable periods,
27
+ # e.g. to cover a 6h accumulation using available intervals of +/-3h.
28
+
29
+
30
class SignedInterval:
    """A directed time interval running from ``start`` to ``end``.

    ``start`` may lie after ``end``; the interval is then "negative" and its length is
    below zero. An optional ``base`` datetime can be attached. It takes part in equality
    and hashing, so two intervals with the same endpoints but different bases are distinct.
    """

    def __init__(self, start: datetime, end: datetime, base: Optional[datetime] = None):
        self.start = start
        self.end = end
        self.base = base

    @property
    def length(self) -> float:
        """Signed length ``end - start`` in seconds (negative when ``start > end``)."""
        elapsed = self.end - self.start
        return elapsed.total_seconds()

    @property
    def sign(self) -> int:
        """``-1`` for negative intervals, ``1`` otherwise (a zero-length interval counts as positive)."""
        if self.length < 0:
            return -1
        return 1

    @property
    def min(self):
        """The earlier of the two endpoints (``start`` when they are equal)."""
        return self.end if self.end < self.start else self.start

    @property
    def max(self):
        """The later of the two endpoints (``start`` when they are equal)."""
        return self.end if self.end > self.start else self.start

    def __neg__(self):
        """The same interval traversed backwards: endpoints swapped, base kept."""
        return SignedInterval(self.end, self.start, self.base)

    def __eq__(self, other):
        if isinstance(other, SignedInterval):
            return (self.start, self.end, self.base) == (other.start, other.end, other.base)
        return NotImplemented

    def __hash__(self):
        return hash((self.start, self.end, self.base))

    def __rich__(self):
        # Hook used by the ``rich`` library: same text as repr, with colour markup.
        return self.__repr__(colored=True)

    def __repr__(self, colored: bool = False):
        try:
            # Prefer anemoi's compact duration formatting when it is installed; this class
            # itself must not depend on anemoi.utils.
            from anemoi.utils.dates import frequency_to_string
        except ImportError:

            def frequency_to_string(delta):
                return str(delta)

        fmt = "%Y%m%d.%H%M"
        start_txt = self.start.strftime(fmt)
        end_txt = self.end.strftime(fmt)
        # When both ends fall on the same day, blank out the "YYYYMMDD." prefix of the end.
        if end_txt[:9] == start_txt[:9]:
            end_txt = " " * 9 + end_txt[9:]

        base_str = ""
        if self.base is not None:
            base_txt = self.base.strftime(fmt)

            def hours_from_base(moment):
                return int((moment - self.base).total_seconds() / 3600)

            if self.sign > 0:
                first, second = hours_from_base(self.start), hours_from_base(self.end)
            else:
                first, second = -hours_from_base(self.end), hours_from_base(self.start)
            base_str = f", base={base_txt}, [{first}-{second}]"

        if self.start < self.end:
            period, colour = f"+{frequency_to_string(self.end - self.start)}", "green"
        elif self.start == self.end:
            period, colour = "0s", "yellow"
        else:
            period, colour = f"-{frequency_to_string(self.start - self.end)}", "red"
        period = period.ljust(4)

        if colored:
            # rich console markup
            start_txt = f"[blue]{start_txt}[/blue]"
            end_txt = f"[blue]{end_txt}[/blue]"
            period = f"[{colour}]{period}[/{colour}]"

        return f"SignedInterval({start_txt}{period}->{end_txt}{base_str} )"
122
+
123
+
124
@dataclass(order=True)
class HeapState:
    """One entry of the priority queue used by ``covering_intervals``.

    ``dataclass(order=True)`` compares instances field by field in declaration order:
    ``total_cost`` first, then ``covered``, then ``counter``, and so on. ``path`` is
    excluded from both ordering and equality.
    """

    # Cost of the path so far (absolute interval lengths plus base-switch penalties).
    total_cost: float
    # Signed number of seconds covered by ``path``.
    covered: float
    # Insertion counter; breaks ties between states with equal cost and coverage.
    counter: int
    # Time reached at the end of ``path``.
    current_time: datetime
    # Base of the last interval in ``path`` (None before any interval is chosen).
    current_base: Optional[datetime]
    # Intervals chosen so far, in chaining order.
    path: List[SignedInterval] = field(compare=False)
132
+
133
+
134
def covering_intervals(
    start: datetime,
    end: datetime,
    candidates: Callable,
    /,
    switch_penalty: int = 24 * 3600 * 7,
    max_delta: timedelta = timedelta(hours=24 * 2),
    error_on_fail: bool = True,
) -> List[SignedInterval] | None:
    """Find a chain of signed intervals whose lengths add up to ``end - start``.

    The search is Dijkstra-like. A state is made of three parts:
    - the time reached so far (initially ``start``);
    - the base of the last interval used;
    - the signed number of seconds covered.

    Every interval returned by ``candidates(current_time)`` is a possible edge from a
    state. Its cost is the absolute interval length, plus ``switch_penalty`` when the
    interval's base differs from the previous interval's base. With a large penalty,
    the cheapest chain first minimises base switches and then the total absolute length.

    Args:
        start: Start datetime of the target interval.

        end: End datetime of the target interval.

        candidates: Callable invoked as ``candidates(current_time)``. It must return an
            iterable of SignedInterval objects that all *start* at ``current_time``.
            Intervals may be negative, i.e. end before they start.

        switch_penalty: Extra cost (in seconds) added whenever consecutive intervals
            have different bases.

        max_delta: Margin of the search window around [start, end]. Once more than 1000
            states have been visited, popping a state outside
            ``[start - max_delta, end + max_delta]`` aborts the search.

        error_on_fail: If True, raise ValueError when no covering is found or the
            search limits are exceeded. Otherwise log a warning and return None.

    Returns:
        The SignedInterval objects in chaining order, or None if no covering was found
        and ``error_on_fail`` is False. An empty list is returned when ``start == end``.

    Raises:
        ValueError: If a candidate does not start at the current time, or if no
            covering is found and ``error_on_fail`` is True.
    """
    target_length = (end - start).total_seconds()

    pq: List[HeapState] = []  # priority queue, ordered by HeapState (total_cost first)
    counter = itertools.count()
    heappush(
        pq,
        HeapState(total_cost=0.0, covered=0.0, counter=next(counter), current_time=start, current_base=None, path=[]),
    )

    # Best cost recorded per (time, base, covered) state. A popped state is skipped
    # unless it is cheaper than every earlier visit of the same state.
    visited: dict[tuple[datetime, Optional[datetime], float], float] = {}

    while pq:
        state = heappop(pq)
        key = (state.current_time, state.current_base, state.covered)

        if key in visited and state.total_cost >= visited[key]:
            continue
        visited[key] = state.total_cost

        # Goal: the signed lengths of the chosen intervals add up to the target length.
        if state.covered == target_length:
            return state.path

        # Safety valve: after enough exploration, stop once the cheapest remaining
        # state has wandered outside the search window.
        if (len(visited) > 1000) and (state.current_time > end + max_delta or state.current_time < start - max_delta):
            msg = f"Exceeded search limits: visited={len(visited)}, current_time={state.current_time}, target=({start} → {end}), max_delta={max_delta}"
            if error_on_fail:
                raise ValueError(msg)
            LOG.warning(msg)
            return None

        for interval in candidates(state.current_time):
            if interval.start != state.current_time:
                raise ValueError(f"Candidate interval {interval} does not start at current_time {state.current_time}")

            # Edge cost = abs(length) + switch penalty if the base changes.
            edge_cost = abs(interval.length)
            if state.current_base is not None and state.current_base != interval.base:
                edge_cost += switch_penalty

            heappush(
                pq,
                HeapState(
                    total_cost=state.total_cost + edge_cost,
                    covered=state.covered + interval.length,
                    counter=next(counter),  # only used to break ties in heapq
                    current_time=interval.end,
                    current_base=interval.base,
                    path=state.path + [interval],
                ),
            )

    msg = f"Cannot find coverage of {start} → {end}"
    if error_on_fail:
        raise ValueError(msg)
    LOG.warning(msg)
    return None
@@ -0,0 +1,149 @@
1
+ # (C) Copyright 2025 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+
11
+ import datetime
12
+ import logging
13
+
14
+ from anemoi.utils.dates import frequency_to_timedelta
15
+
16
+ from .covering_intervals import SignedInterval
17
+
18
+ LOG = logging.getLogger(__name__)
19
+
20
+
21
+ class FieldToInterval:
22
+ """Convert a field to its accumulation interval, applying patches if needed."""
23
+
24
+ def __init__(self, patches: dict | None = None):
25
+ if patches is None:
26
+ patches = {}
27
+ assert isinstance(patches, dict), ("patches must be a dict", patches)
28
+
29
+ self.patches = patches
30
+ for key in patches:
31
+ if key not in (
32
+ "start_step_is_zero",
33
+ "start_step_is_end_step",
34
+ "start_step_greater_than_end_step",
35
+ ):
36
+ raise ValueError(f"Unknown patch key: {key}")
37
+
38
+ def __call__(self, field) -> SignedInterval:
39
+ date_str = str(field.metadata("date")).zfill(8)
40
+ time_str = str(field.metadata("time")).zfill(4)
41
+ base_datetime = datetime.datetime.strptime(date_str + time_str, "%Y%m%d%H%M")
42
+
43
+ endStep = field.metadata("endStep")
44
+ startStep = field.metadata("startStep")
45
+
46
+ LOG.debug(f" field before patching: {startStep=}, {endStep=}")
47
+
48
+ if startStep > endStep:
49
+ startStep, endStep = self.start_step_greater_than_end_step(startStep, endStep, field=field)
50
+ elif startStep == endStep:
51
+ startStep, endStep = self.start_step_is_end_step(startStep, endStep, field=field)
52
+ elif frequency_to_timedelta(startStep).total_seconds() == 0:
53
+ startStep, endStep = self.start_step_is_zero(startStep, endStep, field=field)
54
+
55
+ LOG.debug(f" field after patching : {startStep=}, {endStep=}")
56
+
57
+ start_step = datetime.timedelta(hours=startStep)
58
+ end_step = datetime.timedelta(hours=endStep)
59
+
60
+ assert startStep >= 0, ("After patching, startStep must be >= 0", field, startStep, endStep)
61
+ assert startStep < endStep, ("After patching, startStep must be < endStep", field, startStep, endStep)
62
+
63
+ interval = SignedInterval(start=base_datetime + start_step, end=base_datetime + end_step, base=base_datetime)
64
+
65
+ date_str = str(field.metadata("validityDate")).zfill(8)
66
+ time_str = str(field.metadata("validityTime")).zfill(4)
67
+ valid_date = datetime.datetime.strptime(date_str + time_str, "%Y%m%d%H%M")
68
+ assert valid_date == interval.max, (valid_date, interval)
69
+
70
+ return interval
71
+
72
+ def start_step_is_zero(self, startStep, endStep, field=None):
73
+ # Patch to handle cases where start_step is zero
74
+ # No patch yet implemented
75
+ match self.patches.get("start_step_is_zero", None):
76
+ case False | None:
77
+ pass # do nothing
78
+ case _ as options:
79
+ raise ValueError(f"Unknown option for patch.start_step_is_zero: {options}")
80
+
81
+ return startStep, endStep
82
+
83
+ def start_step_is_end_step(self, startStep, endStep, field=None):
84
+ # Patch to handle cases where start_step equals end_step
85
+ # this should not happen in normal cases but some datasets have this issue
86
+ # The default is to set start_step to zero
87
+ # This can be disabled by setting the patch to False
88
+
89
+ match self.patches.get("start_step_is_end_step", "set_start_step_to_zero"):
90
+ case False | None:
91
+ pass # do nothing
92
+
93
+ case "set_from_end_step_ceiled_to_24_hours":
94
+ startStep, endStep = _set_start_step_from_end_step_ceiled_to_24_hours(startStep, endStep, field=field)
95
+
96
+ case "set_start_step_to_zero":
97
+ startStep, endStep = 0, endStep
98
+
99
+ case _ as options:
100
+ raise ValueError(f"Unknown option for patch.start_step_is_end_step: {options}")
101
+
102
+ return startStep, endStep
103
+
104
+ def start_step_greater_than_end_step(self, startStep, endStep, field=None):
105
+
106
+ # Patch to handle cases where start_step is greater than end_step
107
+ # this should not happen in normal cases but some datasets have this issue
108
+ # The default is to do swap the values of start_step and end_step
109
+ # This can be disabled by setting the patch to False
110
+
111
+ match self.patches.get("start_step_greater_than_end_step", None):
112
+
113
+ case False | None:
114
+ pass # do nothing
115
+
116
+ case "swap":
117
+ startStep, endStep = endStep, startStep
118
+
119
+ case _ as options:
120
+ raise ValueError(f"Unknown option for patch.start_step_greater_than_end_step: {options}")
121
+
122
+ return startStep, endStep
123
+
124
+
125
+ def _set_start_step_from_end_step_ceiled_to_24_hours(startStep, endStep, field=None):
126
+ # Because the data wrongly encode start_step, but end_step is correct
127
+ # and we know that accumulations are always reseted every multiple of 24 hours
128
+ #
129
+ # 1-1 -> 0-1
130
+ # 2-2 -> 0-2
131
+ # ...
132
+ # 23-23 -> 0-23
133
+ # 24-24 -> 0-24
134
+ # 25-25 -> 24-25
135
+ # 26-26 -> 24-26
136
+ # ...
137
+ # 47-47 -> 24-47
138
+ # 48-48 -> 24-48
139
+ # 49-49 -> 48-49
140
+ # 50-50 -> 48-50
141
+ # etc.
142
+ if endStep % 24 == 0:
143
+ # Special case: endStep is exactly 24, 48, 72, etc.
144
+ # Map to previous 24-hour boundary (24 -> 0, 48 -> 24, etc.)
145
+ return endStep - 24, endStep
146
+
147
+ # General case: floor to the nearest 24-hour boundary
148
+ # (1-23 -> 0, 25-47 -> 24, etc.)
149
+ return endStep - (endStep % 24), endStep
@@ -0,0 +1,321 @@
1
+ # (C) Copyright 2025 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+
11
+ import datetime
12
+ import logging
13
+ from abc import abstractmethod
14
+ from typing import Iterable
15
+
16
+ from anemoi.utils.dates import frequency_to_timedelta
17
+
18
+ from anemoi.datasets.create.sources.accumulate_utils.covering_intervals import SignedInterval
19
+ from anemoi.datasets.create.sources.accumulate_utils.covering_intervals import covering_intervals
20
+
21
+ LOG = logging.getLogger(__name__)
22
+
23
+
24
def build_interval(
    current_time: datetime.datetime, start_step: int, end_step: int, base_time: str | int | None
) -> SignedInterval:
    """Build the SignedInterval ``[base + start_step h, base + end_step h]`` on ``current_time``'s day.

    ``base`` is ``current_time``'s calendar date at hour ``int(base_time)``. When
    ``base_time`` is None, the day is anchored at 00:00 and the returned interval has
    no base.

    Raises:
        ValueError: If ``int(base_time)`` raises ValueError.
    """
    if base_time is None:
        hour = 0
    else:
        try:
            hour = int(base_time)
        except ValueError:
            raise ValueError(f"Invalid base_time: {base_time} ({type(base_time)})")

    day_base = datetime.datetime(current_time.year, current_time.month, current_time.day, hour)
    return SignedInterval(
        start=day_base + datetime.timedelta(hours=start_step),
        end=day_base + datetime.timedelta(hours=end_step),
        base=None if base_time is None else day_base,
    )
42
+
43
+
44
class IntervalGenerator:
    """Base class for objects that supply candidate intervals to ``covering_intervals``.

    Subclasses must implement:
    - ``__call__(current_time)``, returning candidate SignedInterval objects that start
      at ``current_time``;
    - ``covering_intervals(start, end)``, returning the chosen cover of a target period.

    The class does not use ``abc.ABCMeta``, so ``@abstractmethod`` is advisory only. The
    default implementations therefore raise NotImplementedError instead of silently
    returning None.
    """

    @abstractmethod
    def covering_intervals(self, start: datetime.datetime, end: datetime.datetime) -> "Iterable[SignedInterval]":
        raise NotImplementedError(f"{type(self).__name__} must implement covering_intervals()")

    @abstractmethod
    def __call__(self, current_time: datetime.datetime) -> "Iterable[SignedInterval]":
        raise NotImplementedError(f"{type(self).__name__} must implement __call__()")
56
+
57
+
58
class Pattern:
    """One availability pattern: a base time plus the steps available from it.

    Used by SearchableIntervalGenerator to produce candidate intervals.

    Args:
        base_time: Hour of the base, as anything accepted by ``int()`` (e.g. ``6`` or
            ``"18"``). Use ``"*"`` for intervals anchored at 00:00 that carry no base.
        steps: Step pairs in any form accepted by ``normalise_steps`` (e.g. ``"0-6/6-12"``).
        search_range: Day offsets searched around the current time (default ``[-1, 0, 1]``).
        base_date: Optional ``{"day_of_month": N}`` keeping only intervals whose base
            falls on day ``N`` of the month.

    Raises:
        ValueError: If ``base_date`` is given while ``base_time`` is ``"*"``/None. Such
            intervals have no base whose day could be checked.
    """

    def __init__(
        self,
        base_time: str | int | None,
        steps: str | list[str],
        search_range: list[int] | None = None,
        base_date: dict | None = None,
    ):
        self.steps = normalise_steps(steps)

        # "*" is a wildcard: build_interval then anchors the day at 00:00 and sets no base.
        self.base_time = None if base_time == "*" else base_time

        day_offsets = [-1, 0, 1] if search_range is None else search_range
        self.search_range = [datetime.timedelta(days=d) for d in day_offsets]

        if base_date:
            assert isinstance(base_date, dict), base_date
            assert "day_of_month" in base_date, base_date
            assert len(base_date) == 1, base_date
            if self.base_time is None:
                raise ValueError(f"base_date={base_date} requires an explicit base_time, not {base_time!r}")
        # Always set, so that filter() can read it even when no base_date was given.
        self.base_date = base_date

    def filter(self, interval: SignedInterval) -> bool:
        """Return True if ``interval`` satisfies the optional ``base_date`` constraint."""
        if self.base_date and interval.base.day != self.base_date["day_of_month"]:
            return False
        return True
94
+
95
+
96
class SearchableIntervalGenerator(IntervalGenerator):
    """Interval generator driven by one or more Pattern objects.

    ``config`` is either a list/tuple of ``(base_time, steps)`` pairs (one Pattern per
    pair) or a dict of Pattern keyword arguments (a single Pattern).

    Raises:
        TypeError: If ``config`` is neither a list/tuple nor a dict.
    """

    def __init__(self, config: tuple | list | dict):
        if isinstance(config, (tuple, list)):
            patterns = [Pattern(base_time=base_time, steps=steps) for base_time, steps in config]
        elif isinstance(config, dict):
            patterns = [Pattern(**config)]
        else:
            raise TypeError(f"Unsupported interval generator config of type {type(config).__name__}: {config!r}")

        self.patterns: list[Pattern] = patterns

    def covering_intervals(self, start: datetime.datetime, end: datetime.datetime) -> Iterable[SignedInterval]:
        """Perform interval search among candidates with minimal base switches and length.

        Candidates are given by calling ``self``. Returns the available SignedIntervals
        covering the period start->end (start > end is possible).
        """
        return covering_intervals(start, end, self)

    def __call__(
        self,
        current_time: datetime.datetime,
    ) -> Iterable[SignedInterval]:
        """Return candidate intervals starting at ``current_time``.

        For every pattern, day offset and step pair, an interval is built on the day of
        ``current_time + offset``. Duplicates are dropped, keeping the first occurrence.
        Intervals that start at ``current_time`` are kept as-is; intervals that end there
        are reversed so they start there; all others are discarded.
        """
        # dict as an insertion-ordered set: O(1) de-duplication, same order as a list scan.
        unique: dict = {}
        for p in self.patterns:
            for delta in p.search_range:
                for start_step, end_step in p.steps:
                    interval = build_interval(current_time + delta, start_step, end_step, p.base_time)
                    if p.filter(interval):
                        unique.setdefault(interval, None)

        candidates = []
        for interval in unique:
            if interval.start == current_time:
                candidates.append(interval)
            elif interval.end == current_time:
                candidates.append(-interval)

        # Most recent base first (start time when there is no base); sorted() is stable,
        # so ties keep generation order. This ordering feeds heap tie-breaking in
        # covering_intervals.
        return sorted(candidates, key=lambda x: -(x.base or x.start).timestamp())
150
+
151
+
152
+ def normalise_steps(steps_list: str | list[str]) -> list[list[int]]:
153
+ """Convert the input step_list to a list of [start,end] pairs"""
154
+ res = []
155
+ if isinstance(steps_list, str):
156
+ steps_list = steps_list.split("/")
157
+ assert isinstance(steps_list, list), steps_list
158
+
159
+ for start_end_step in steps_list:
160
+ if isinstance(start_end_step, str):
161
+ assert "-" in start_end_step, start_end_step
162
+ start_end_step = start_end_step.split("-")
163
+ assert isinstance(start_end_step, (list, tuple)) and len(start_end_step) == 2, start_end_step
164
+ start_step, end_step = int(start_end_step[0]), int(start_end_step[1])
165
+ res.append([start_step, end_step])
166
+ return res
167
+
168
+
169
class AccumulatedFromStartIntervalGenerator(SearchableIntervalGenerator):
    """Availability for fields accumulated from the start of each forecast.

    For every base time in ``basetime``, the available steps are ``0-k`` for
    ``k = frequency, 2 * frequency, ...`` as long as ``k - frequency < last_step``.

    Args:
        basetime: A single base time (e.g. ``6`` or ``"18"``) or an iterable of base times.
        frequency: Step increment, in hours.
        last_step: Bound on the steps, in hours (see above).
    """

    def __init__(self, basetime: str | int | list | tuple, frequency: int, last_step: int):
        if isinstance(basetime, (str, int)):
            # A lone base time: iterating a string would split "18" into "1" and "8".
            basetime = [basetime]
        # One (base, [step]) pattern per base time and step, in the same order as before.
        config = [
            [base, [f"0-{i + frequency}"]]
            for base in basetime
            for i in range(0, last_step, frequency)
        ]
        super().__init__(config)
176
+
177
+
178
class AccumulatedFromPreviousStepIntervalGenerator(SearchableIntervalGenerator):
    """Availability for fields accumulated since the previous output step.

    For every base time in ``basetime``, the available steps are ``i-(i + frequency)`` for
    ``i = 0, frequency, 2 * frequency, ...`` while ``i < last_step``.

    Args:
        basetime: A single base time (e.g. ``6`` or ``"18"``) or an iterable of base times.
        frequency: Step increment, in hours.
        last_step: Bound on the steps, in hours (see above).
    """

    def __init__(self, basetime: str | int | list | tuple, frequency: int, last_step: int):
        if isinstance(basetime, (str, int)):
            # A lone base time: iterating a string would split "18" into "1" and "8".
            basetime = [basetime]
        # One (base, [step]) pattern per base time and step, in the same order as before.
        config = [
            [base, [f"{i}-{i + frequency}"]]
            for base in basetime
            for i in range(0, last_step, frequency)
        ]
        super().__init__(config)
185
+
186
+
187
+ def _match_mars_config(_class: str, _stream: str | None = None, _origin: str | None = None) -> list | tuple:
188
+ """Match MARS configuration (class, stream, origin) to interval generator config.
189
+
190
+ Parameters
191
+ ----------
192
+ _class
193
+ MARS class (e.g., 'ea', 'od', 'rr', 'l5').
194
+ _stream
195
+ MARS stream (e.g., 'oper', 'enda', 'elda', 'enfo'). Defaults to 'oper'.
196
+ _origin
197
+ MARS origin (e.g., 'se-al-ec', 'fr-ms-ec'). Defaults to None.
198
+
199
+ Returns
200
+ -------
201
+ list | tuple
202
+ Interval generator configuration.
203
+
204
+ Raises
205
+ ------
206
+ NotImplementedError
207
+ If the combination is not yet implemented.
208
+ ValueError
209
+ If the combination is unknown.
210
+ """
211
+ _stream = _stream or "oper"
212
+
213
+ match (_class, _stream, _origin):
214
+ case ("ea", "oper", _):
215
+ return [
216
+ (6, "0-1/1-2/2-3/3-4/4-5/5-6/6-7/7-8/8-9/9-10/10-11/11-12/12-13/13-14/14-15/15-16/16-17/17-18"),
217
+ (18, "0-1/1-2/2-3/3-4/4-5/5-6/6-7/7-8/8-9/9-10/10-11/11-12/12-13/13-14/14-15/15-16/16-17/17-18"),
218
+ ]
219
+ case ("ea", "enda", _):
220
+ return [
221
+ (6, "0-3/3-6/6-9/9-12/12-15/15-18"),
222
+ (18, "0-3/3-6/6-9/9-12/12-15/15-18"),
223
+ ]
224
+
225
+ case ("od", "oper", _):
226
+ # https://apps.ecmwf.int/mars-catalogue/?stream=oper&levtype=sfc&time=00%3A00%3A00&expver=1&month=aug&year=2020&date=2020-08-25&type=fc&class=od
227
+ steps = [f"{0}-{i}" for i in range(1, 91)]
228
+ return ((0, steps), (12, steps))
229
+ case ("od", "elda", _):
230
+ # https://apps.ecmwf.int/mars-catalogue/?stream=elda&levtype=sfc&time=06%3A00%3A00&expver=1&month=aug&year=2020&date=2020-08-31&type=fc&class=od
231
+ # (6, "0-1/0-2/0-3/0-4/0-5/0-6/0-7/0-8/0-9/0-10/0-11/0-12"),
232
+ # (18, "0-1/0-2/0-3/0-4/0-5/0-6/0-7/0-8/0-9/0-10/0-11/0-12")
233
+ steps = [f"{0}-{i}" for i in range(1, 13)]
234
+ return ((6, steps), (18, steps))
235
+ case ("od", "enfo", _):
236
+ # https://apps.ecmwf.int/mars-catalogue/?class=od&stream=enfo&expver=1&type=fc&year=2020&month=aug&levtype=sfc&date=2020-08-31&time=06:00:00
237
+ raise NotImplementedError("od-enfo interval generator not implemented yet")
238
+
239
+ case ("rr", _, "se-al-ec"):
240
+ # https://apps.ecmwf.int/mars-catalogue/?class=rr&expver=prod&origin=se-al-ec&stream=oper&type=fc&year=2020&month=aug&levtype=sfc
241
+ return [[0, [(0, i) for i in [1, 2, 3, 4, 5, 6, 9, 12, 15, 18, 21, 24, 27, 30]]]]
242
+ case ("rr", _, "fr-ms-ec"):
243
+ # https://apps.ecmwf.int/mars-catalogue/?origin=fr-ms-ec&stream=oper&levtype=sfc&time=06%3A00%3A00&expver=prod&month=aug&year=2020&date=2020-08-31&type=fc&class=rr
244
+ return [[0, [(0, i) for i in range(1, 22, 3)]]]
245
+
246
+ case ("l5", "oper", _):
247
+ # https://apps.ecmwf.int/mars-catalogue/?class=l5&stream=oper&expver=1&type=fc&year=2020&month=aug&levtype=sfc&date=2020-08-25&time=00:00:00
248
+ return [[0, [(int(i), int(i) + 1) for i in range(0, 24)]]]
249
+
250
+ case _:
251
+ raise ValueError(f"Unknown MARS configuration: class={_class}, stream={_stream}, origin={_origin}")
252
+
253
+
254
+ def _interval_generator_factory(
255
+ config, source_name: str | None = None, source: dict | None = None
256
+ ) -> IntervalGenerator | list | dict:
257
+ match config:
258
+ case IntervalGenerator():
259
+ return config
260
+
261
+ case {"type": "accumulated-from-start", **params}:
262
+ return AccumulatedFromStartIntervalGenerator(**params)
263
+ case {"accumulated-from-start": params}:
264
+ return AccumulatedFromStartIntervalGenerator(**params)
265
+
266
+ case {"accumulated-from-previous-step": params}:
267
+ return AccumulatedFromPreviousStepIntervalGenerator(**params)
268
+ case {"type": "accumulated-from-previous-step", **params}:
269
+ return AccumulatedFromPreviousStepIntervalGenerator(**params)
270
+
271
+ case {"type": _, **params}:
272
+ raise NotImplementedError(f"Unknown availability config {config}")
273
+
274
+ case {"mars": mars_config}:
275
+ _class = mars_config.get("class")
276
+ _stream = mars_config.get("stream")
277
+ _origin = mars_config.get("origin")
278
+ assert _class is not None, "mars config must have a 'class' key"
279
+ return _match_mars_config(_class, _stream, _origin)
280
+
281
+ case dict() | list() | tuple():
282
+ return SearchableIntervalGenerator(config)
283
+
284
+ case "auto":
285
+ assert None not in (source_name, source), "Source must be specified when using 'auto' discovery"
286
+ assert source_name == "mars", "Only 'mars' source is currently supported for 'auto' availability discovery"
287
+
288
+ _class, _stream, _origin = source.get("class", None), source.get("stream", None), source.get("origin", None)
289
+
290
+ assert (
291
+ _class is not None
292
+ ), "Availability should be automatically determined from mars source, but the mars source has no 'class'"
293
+
294
+ if (_stream is None) or (_origin is None):
295
+ LOG.warning(
296
+ f"Stream and/or origin unspecified for class {_class}, "
297
+ f"stream and/or origin will be set as defaults.",
298
+ )
299
+
300
+ return {"mars": {"class": _class, "stream": _stream, "origin": _origin}}
301
+
302
+ case str():
303
+ try:
304
+ data_accumulation_period = frequency_to_timedelta(config)
305
+ except Exception as e:
306
+ raise ValueError(f"Unknown interval generator config: {config}") from e
307
+
308
+ hours = data_accumulation_period.total_seconds() / 3600
309
+ if not (hours.is_integer() and hours > 0):
310
+ raise ValueError("Only accumulation periods multiple of 1 hour are supported for now")
311
+
312
+ return [["*", [f"{i}-{i+1}" for i in range(0, 24)]]]
313
+
314
+ case _:
315
+ raise ValueError(f"Unknown interval generator config: {config}")
316
+
317
+
318
def interval_generator_factory(config, source_name: str | None = None, source: dict | None = None) -> IntervalGenerator:
    """Resolve an availability ``config`` into an IntervalGenerator.

    ``_interval_generator_factory`` is applied repeatedly. Each call returns either a
    generator or a simpler intermediate config, e.g. ``"auto"`` -> ``{"mars": ...}`` ->
    a list of ``(base_time, steps)`` pairs -> SearchableIntervalGenerator.
    """
    resolved = config
    while True:
        if isinstance(resolved, IntervalGenerator):
            return resolved
        resolved = _interval_generator_factory(resolved, source_name, source)