anemoi-datasets 0.5.18__py3-none-any.whl → 0.5.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/compare-lam.py +401 -0
  3. anemoi/datasets/commands/grib-index.py +114 -0
  4. anemoi/datasets/create/__init__.py +8 -1
  5. anemoi/datasets/create/config.py +5 -0
  6. anemoi/datasets/create/filters/pressure_level_relative_humidity_to_specific_humidity.py +3 -1
  7. anemoi/datasets/create/filters/pressure_level_specific_humidity_to_relative_humidity.py +3 -1
  8. anemoi/datasets/create/filters/transform.py +1 -3
  9. anemoi/datasets/create/filters/wz_to_w.py +3 -2
  10. anemoi/datasets/create/input/action.py +4 -1
  11. anemoi/datasets/create/input/function.py +5 -5
  12. anemoi/datasets/create/input/result.py +1 -1
  13. anemoi/datasets/create/sources/accumulations.py +13 -0
  14. anemoi/datasets/create/sources/accumulations2.py +652 -0
  15. anemoi/datasets/create/sources/anemoi_dataset.py +73 -0
  16. anemoi/datasets/create/sources/grib.py +7 -0
  17. anemoi/datasets/create/sources/grib_index.py +614 -0
  18. anemoi/datasets/create/sources/legacy.py +7 -2
  19. anemoi/datasets/create/sources/xarray_support/__init__.py +1 -1
  20. anemoi/datasets/create/sources/xarray_support/fieldlist.py +2 -2
  21. anemoi/datasets/create/sources/xarray_support/flavour.py +6 -0
  22. anemoi/datasets/create/sources/xarray_support/grid.py +15 -4
  23. anemoi/datasets/data/__init__.py +16 -0
  24. anemoi/datasets/data/complement.py +4 -1
  25. anemoi/datasets/data/dataset.py +14 -0
  26. anemoi/datasets/data/interpolate.py +76 -0
  27. anemoi/datasets/data/masked.py +77 -0
  28. anemoi/datasets/data/misc.py +159 -0
  29. anemoi/datasets/grids.py +8 -2
  30. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/METADATA +10 -4
  31. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/RECORD +35 -30
  32. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/WHEEL +1 -1
  33. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/entry_points.txt +0 -0
  34. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/licenses/LICENSE +0 -0
  35. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/sources/accumulations.py
@@ -636,6 +636,19 @@ def accumulations(
     Any
         The computed accumulations.
     """
+
+    if (
+        request.get("class") == "ea"
+        and request.get("stream", "oper") == "oper"
+        and request.get("accumulation_period") == 24
+    ):
+        from .accumulations2 import accumulations as accumulations2
+
+        LOG.warning(
+            "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24"
+        )
+        return accumulations2(context, dates, **request)
+
     _to_list(request["param"])
     class_ = request.get("class", "od")
     stream = request.get("stream", "oper")
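
For orientation, here is a minimal sketch (an editor's illustration, not part of the diff) of a source request that would take the experimental branch above. It follows the call convention of the __main__ block at the end of accumulations2.py further down; the context and dates arguments are assumed to be supplied by the dataset-creation machinery and are not defined here.

# Hypothetical request: class="ea" with the default stream "oper" and
# accumulation_period=24 triggers the dispatch to accumulations2().
request = {
    "class": "ea",
    "expver": "0001",
    "levtype": "sfc",
    "param": ["cp", "tp"],
    "accumulation_period": 24,
}
# fields = accumulations(context, dates, **request)
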
anemoi/datasets/create/sources/accumulations2.py (new file)
@@ -0,0 +1,652 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import datetime
+import logging
+import warnings
+from abc import abstractmethod
+from copy import deepcopy
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import earthkit.data as ekd
+import numpy as np
+from earthkit.data.core.temporary import temp_file
+from earthkit.data.readers.grib.output import new_grib_output
+
+from anemoi.datasets.create.sources.mars import mars
+from anemoi.datasets.create.utils import to_datetime_list
+
+from .legacy import legacy_source
+
+LOG = logging.getLogger(__name__)
+
+
+xprint = print
+
+
+def _member(field: Any) -> int:
+    """Retrieves the member number from the field metadata.
+
+    Parameters
+    ----------
+    field : Any
+        The field from which to retrieve the member number.
+
+    Returns
+    -------
+    int
+        The member number.
+    """
+    # Bug in eccodes has number=0 randomly
+    number = field.metadata("number", default=0)
+    if number is None:
+        number = 0
+    return number
+
+
+class Period:
+    value = None
+
+    def __init__(self, start_datetime, end_datetime, base_datetime):
+        assert isinstance(start_datetime, datetime.datetime)
+        assert isinstance(end_datetime, datetime.datetime)
+        assert isinstance(base_datetime, datetime.datetime)
+
+        self.start_datetime = start_datetime
+        self.end_datetime = end_datetime
+
+        self.base_datetime = base_datetime
+
+    @property
+    def time_request(self):
+        date = int(self.base_datetime.strftime("%Y%m%d"))
+        time = int(self.base_datetime.strftime("%H%M"))
+
+        end_step = self.end_datetime - self.base_datetime
+        assert end_step.total_seconds() % 3600 == 0, end_step  # only full hours supported
+        end_step = int(end_step.total_seconds() // 3600)
+
+        return (("date", date), ("time", time), ("step", end_step))
+
+    def field_to_key(self, field):
+        return (
+            ("date", field.metadata("date")),
+            ("time", field.metadata("time")),
+            ("step", field.metadata("step")),
+        )
+
+    def check(self, field):
+        stepType = field.metadata("stepType")
+        startStep = field.metadata("startStep")
+        endStep = field.metadata("endStep")
+        date = field.metadata("date")
+        time = field.metadata("time")
+
+        assert stepType == "accum", stepType
+
+        base_datetime = datetime.datetime.strptime(str(date) + str(time).zfill(4), "%Y%m%d%H%M")
+
+        start = base_datetime + datetime.timedelta(hours=startStep)
+        assert start == self.start_datetime, (start, self.start_datetime)
+
+        end = base_datetime + datetime.timedelta(hours=endStep)
+        assert end == self.end_datetime, (end, self.end_datetime)
+
+    def is_matching_field(self, field):
+        return self.field_to_key(field) == self.time_request
+
+    def __repr__(self):
+        return f"Period({self.start_datetime} to {self.end_datetime} -> {self.time_request})"
+
+    def length(self):
+        return self.end_datetime - self.start_datetime
+
+    def apply(self, accumulated, values):
+
+        if accumulated is None:
+            accumulated = np.zeros_like(values)
+
+        assert accumulated.shape == values.shape, (accumulated.shape, values.shape)
+
+        if not np.all(values >= 0):
+            warnings.warn(f"Negative values for {values}: {np.amin(values)} {np.amax(values)}")
+
+        return accumulated + self.sign * values
+
+
+class TodoList:
+    def __init__(self, keys):
+        self._todo = set(keys)
+        self._len = len(keys)
+        self._done = set()
+        assert self._len == len(self._todo), (self._len, len(self._todo))
+
+    def is_todo(self, key):
+        return key in self._todo
+
+    def is_done(self, key):
+        return key in self._done
+
+    def set_done(self, key):
+        self._done.add(key)
+        self._todo.remove(key)
+
+    def all_done(self):
+        if not self._todo:
+            assert len(self._done) == self._len, (len(self._done), self._len)
+            return True
+        return False
+
+
+class Periods:
+    _todo = None
+
+    def __init__(self, valid_date, accumulation_period, **kwargs):
+        # one Periods object for each accumulated field in the output
+
+        assert isinstance(valid_date, datetime.datetime), (valid_date, type(valid_date))
+        assert isinstance(accumulation_period, datetime.timedelta), (accumulation_period, type(accumulation_period))
+        self.valid_date = valid_date
+        self.accumulation_period = accumulation_period
+        self.kwargs = kwargs
+
+        self._periods = self.build_periods()
+        self.check_merged_interval()
+
+    def check_merged_interval(self):
+        global_start = self.valid_date - self.accumulation_period
+        global_end = self.valid_date
+        resolution = datetime.timedelta(hours=1)
+
+        timeline = np.arange(
+            np.datetime64(global_start, "s"), np.datetime64(global_end, "s"), np.timedelta64(resolution)
+        )
+
+        flags = np.zeros_like(timeline, dtype=int)
+        for p in self._periods:
+            segment = np.where((timeline >= p.start_datetime) & (timeline < p.end_datetime))
+            xprint(segment)
+            flags[segment] += p.sign
+        assert np.all(flags == 1), flags
+
+    def find_matching_period(self, field):
+        # Find a period that matches the field, or return None
+        found = [p for p in self._periods if p.is_matching_field(field)]
+        if len(found) == 1:
+            return found[0]
+        if len(found) > 1:
+            raise ValueError(f"Found more than one period for {field}")
+        return None
+
+    @property
+    def todo(self):
+        if self._todo is None:
+            self._todo = TodoList([p.time_request for p in self._periods])
+        return self._todo
+
+    def is_todo(self, period):
+        return self.todo.is_todo(period.time_request)
+
+    def is_done(self, period):
+        return self.todo.is_done(period.time_request)
+
+    def set_done(self, period):
+        self.todo.set_done(period.time_request)
+
+    def all_done(self):
+        return self.todo.all_done()
+
+    def __iter__(self):
+        return iter(self._periods)
+
+    @abstractmethod
+    def build_periods(self):
+        pass
+
+
+class EraPeriods(Periods):
+    def search_periods(self, start, end, debug=False):
+        # find candidate periods that can be used to accumulate the data
+        # to get the accumulation between the two dates 'start' and 'end'
+        found = []
+        if not end - start == datetime.timedelta(hours=1):
+            raise NotImplementedError("Only 1 hour period is supported")
+
+        for base_time, steps in self.available_steps(start, end).items():
+            for step1, step2 in steps:
+                if debug:
+                    xprint(f"❌ trying: {base_time=} {step1=} {step2=}")
+
+                if ((base_time + step1) % 24) != start.hour:
+                    continue
+
+                if ((base_time + step2) % 24) != end.hour:
+                    continue
+
+                base_datetime = start - datetime.timedelta(hours=step1)
+
+                period = Period(start, end, base_datetime)
+                found.append(period)
+
+                assert base_datetime.hour == base_time, (base_datetime, base_time)
+
+                assert period.start_datetime - period.base_datetime == datetime.timedelta(hours=step1), (
+                    period.start_datetime,
+                    period.base_datetime,
+                    step1,
+                )
+                assert period.end_datetime - period.base_datetime == datetime.timedelta(hours=step2), (
+                    period.end_datetime,
+                    period.base_datetime,
+                    step2,
+                )
+
+        return found
+
+    def build_periods(self):
+        # build the list of periods to accumulate the data
+
+        hours = self.accumulation_period.total_seconds() / 3600
+        assert int(hours) == hours, f"Only full hours accumulation is supported {hours}"
+        hours = int(hours)
+
+        lst = []
+        for wanted in [[i, i + 1] for i in range(0, hours, 1)]:
+
+            start = self.valid_date - datetime.timedelta(hours=wanted[1])
+            end = self.valid_date - datetime.timedelta(hours=wanted[0])
+
+            found = self.search_periods(start, end)
+            if not found:
+                xprint(f"❌❌❌ Cannot find accumulation for {start} {end}")
+                self.search_periods(start, end, debug=True)
+                raise ValueError(f"Cannot find accumulation for {start} {end}")
+
+            found = sorted(found, key=lambda x: x.base_datetime, reverse=True)
+            chosen = found[0]
+
+            if len(found) > 1:
+                xprint(f" Found more than one period for {start} {end}")
+                for f in found:
+                    xprint(f" {f}")
+                xprint(f" Choosing {chosen}")
+
+            chosen.sign = 1
+
+            lst.append(chosen)
+        return lst
+
+
+class EaOperPeriods(EraPeriods):
+    def available_steps(self, start, end):
+        return {
+            6: [[i, i + 1] for i in range(0, 18, 1)],
+            18: [[i, i + 1] for i in range(0, 18, 1)],
+        }
+
+
+class L5OperPeriods(EraPeriods):
+    def available_steps(self, start, end):
+        print("❌❌❌ untested")
+        x = 24  # need to check if 24 is the right value
+        return {
+            0: [[i, i + 1] for i in range(0, x, 1)],
+        }
+
+
+class EaEndaPeriods(EraPeriods):
+    def available_steps(self, start, end):
+        print("❌❌❌ untested")
+        return {
+            6: [[i, i + 3] for i in range(0, 18, 1)],
+            18: [[i, i + 3] for i in range(0, 18, 1)],
+        }
+
+
+class RrOperPeriods(Periods):
+    def available_steps(self, start, end):
+        raise NotImplementedError("need to implement diff")
+        x = 24  # todo: check if 24 is the right value
+        return {
+            0: [[0, i] for i in range(0, x, 1)],
+            3: [[0, i] for i in range(0, x, 1)],
+            6: [[0, i] for i in range(0, x, 1)],
+            9: [[0, i] for i in range(0, x, 1)],
+            12: [[0, i] for i in range(0, x, 1)],
+            15: [[0, i] for i in range(0, x, 1)],
+            18: [[0, i] for i in range(0, x, 1)],
+            21: [[0, i] for i in range(0, x, 1)],
+        }
+
+
+class OdEldaPeriods(EraPeriods):
+    def available_steps(self, start, end):
+        print("❌❌❌ untested")
+        x = 24  # need to check if 24 is the right value
+        return {
+            6: [[i, i + 1] for i in range(0, x, 1)],
+            18: [[i, i + 1] for i in range(0, x, 1)],
+        }
+
+
+class DiffPeriods(Periods):
+    pass
+
+
+class OdOperPeriods(DiffPeriods):
+    def available_steps(self, start, end):
+        raise NotImplementedError("need to implement diff and _scda patch")
+
+
+class OdEnfoPeriods(DiffPeriods):
+    def available_steps(self, start, end):
+        raise NotImplementedError("need to implement diff")
+
+
+def find_accumulator_class(class_: str, stream: str) -> Periods:
+    return {
+        ("ea", "oper"): EaOperPeriods,  # runs ok
+        ("ea", "enda"): EaEndaPeriods,
+        ("rr", "oper"): RrOperPeriods,
+        ("l5", "oper"): L5OperPeriods,
+        ("od", "oper"): OdOperPeriods,
+        ("od", "enfo"): OdEnfoPeriods,
+        ("od", "elda"): OdEldaPeriods,
+    }[class_, stream]
+
+
+class Accumulator:
+    values = None
+
+    def __init__(self, period_class, out, valid_date, user_accumulation_period, **kwargs):
+        self.valid_date = valid_date
+
+        # keep the reference to the output file to be able to write the result using an input field as template
+        self.out = out
+
+        # key contains the mars request parameters except the ones related to the time
+        # A mars request is a dictionary with three categories of keys:
+        # - the ones related to the time (date, time, step)
+        # - the ones related to the data (param, stream, levtype, expver, number, ...)
+        # - the ones related to the processing to be done (grid, area, ...)
+        self.kwargs = kwargs
+        for k in ["date", "time", "step"]:
+            if k in kwargs:
+                raise ValueError(f"Cannot use {k} in kwargs for accumulations")
+
+        self.key = {k: v for k, v in kwargs.items() if k in ["param", "level", "levelist", "number"]}
+
+        self.periods = period_class(self.valid_date, user_accumulation_period, **kwargs)
+
+    @property
+    def requests(self):
+        for period in self.periods:
+            # build the full data requests, merging the time requests with the key
+            yield {**self.kwargs.copy(), **dict(period.time_request)}
+
+    def is_field_needed(self, field):
+        for k, v in self.key.items():
+            if field.metadata(k) != v:
+                LOG.debug(f"{self} does not need field {field} because of {k}={field.metadata(k)} not {v}")
+                return False
+        return True
+
+    def compute(self, field, values):
+        if not self.is_field_needed(field):
+            return
+
+        period = self.periods.find_matching_period(field)
+        if not period:
+            return
+        assert self.periods.is_todo(period), (self.periods, period)
+        assert not self.periods.is_done(period), f"Field {field} for period {period} already done"
+
+        period.check(field)
+
+        xprint(f"{self} field ✅ ({period.sign}){field} for {period}")
+
+        self.values = period.apply(self.values, values)
+        self.periods.set_done(period)
+
+        if self.periods.all_done():
+            self.write(field)
+            xprint("accumulator", self, " : data written ✅ ")
+
+    def check(self, field: Any) -> None:
+        if self._check is None:
+            self._check = field.metadata(namespace="mars")
+
+            assert self.param == field.metadata("param"), (self.param, field.metadata("param"))
+            assert self.date == field.metadata("date"), (self.date, field.metadata("date"))
+            assert self.time == field.metadata("time"), (self.time, field.metadata("time"))
+            assert self.step == field.metadata("step"), (self.step, field.metadata("step"))
+            assert self.number == _member(field), (self.number, _member(field))
+            return
+
+        mars = field.metadata(namespace="mars")
+        keys1 = sorted(self._check.keys())
+        keys2 = sorted(mars.keys())
+
+        assert keys1 == keys2, (keys1, keys2)
+
+        for k in keys1:
+            if k not in ("step",):
+                assert self._check[k] == mars[k], (k, self._check[k], mars[k])
+
+    def write(self, template: Any) -> None:
+        assert self.periods.all_done(), self.periods
+
+        if np.all(self.values < 0):
+            LOG.warning(
+                f"Negative values when computing accumulation for {self}: min={np.amin(self.values)} max={np.amax(self.values)}"
+            )
+
+        startStep = 0
+        endStep = self.periods.accumulation_period.total_seconds() // 3600
+        assert int(endStep) == endStep, "only full hours accumulation is supported"
+        endStep = int(endStep)
+        fake_base_date = self.valid_date - self.periods.accumulation_period
+        date = int(fake_base_date.strftime("%Y%m%d"))
+        time = int(fake_base_date.strftime("%H%M"))
+
+        self.out.write(
+            self.values,
+            template=template,
+            stepType="accum",
+            startStep=startStep,
+            endStep=endStep,
+            date=date,
+            time=time,
+            check_nans=True,
+        )
+        self.values = None
+
+    def __repr__(self):
+        key = ", ".join(f"{k}={v}" for k, v in self.key.items())
+        return f"{self.__class__.__name__}({self.valid_date}, {key})"
+
+
+def _compute_accumulations(
+    context: Any,
+    dates: List[datetime.datetime],
+    request: Dict[str, Any],
+    user_accumulation_period: datetime.timedelta,
+    # data_accumulation_period: Optional[int] = None,
+    # patch: Any = _identity,
+) -> Any:
+
+    request = deepcopy(request)
+
+    param = request.pop("param")
+    assert isinstance(param, (list, tuple))
+
+    number = request.pop("number", [0])
+    if not isinstance(number, (list, tuple)):
+        number = [number]
+    assert isinstance(number, (list, tuple))
+
+    request["stream"] = request.get("stream", "oper")
+
+    type_ = request.get("type", "an")
+    if type_ == "an":
+        type_ = "fc"
+    request["type"] = type_
+
+    request["levtype"] = request.get("levtype", "sfc")
+    if request["levtype"] != "sfc":
+        # LOG.warning("'type' should be 'sfc', found %s", request['type'])
+        raise NotImplementedError("Only sfc leveltype is supported")
+
+    period_class = find_accumulator_class(request["class"], request["stream"])
+
+    tmp = temp_file()
+    path = tmp.path
+    out = new_grib_output(path)
+
+    # build one accumulator per output field
+    accumulators = []
+    for valid_date in dates:
+        for p in param:
+            for n in number:
+                accumulators.append(
+                    Accumulator(
+                        period_class,
+                        out,
+                        valid_date,
+                        user_accumulation_period=user_accumulation_period,
+                        param=p,
+                        number=n,
+                        **request,
+                    )
+                )
+
+    xprint("accumulators", len(accumulators))
+
+    # get all needed data requests (mars)
+    requests = []
+    for a in accumulators:
+        xprint("accumulator", a)
+        for r in a.requests:
+            xprint(" ", r)
+            requests.append(r)
+
+    # get the data (this will pack the requests to avoid duplicates and make a minimal number of requests)
+    ds = mars(context, dates, request_already_using_valid_datetime=True, *requests)
+
+    # send each field to each accumulator; the accumulator will use the field for the accumulation
+    # if the accumulator has requested it
+    for field in ds:
+        values = field.values  # optimisation
+        for a in accumulators:
+            a.compute(field, values)
+
+    out.close()
+
+    ds = ekd.from_source("file", path)
+
+    assert len(ds) / len(param) / len(number) == len(dates), (
+        len(ds),
+        len(param),
+        len(dates),
+    )
+
+    # keep a reference to the tmp file, or it gets deleted when the function returns
+    ds._tmp = tmp
+
+    return ds
+
+
+def _to_list(x: Union[List[Any], Tuple[Any], Any]) -> List[Any]:
+    """Converts the input to a list if it is not already a list or tuple.
+
+    Parameters
+    ----------
+    x : Union[List[Any], Tuple[Any], Any]
+        Input value.
+
+    Returns
+    -------
+    List[Any]
+        The input value as a list.
+    """
+    if isinstance(x, (list, tuple)):
+        return x
+    return [x]
+
+
+def _scda(request: Dict[str, Any]) -> Dict[str, Any]:
+    """Modifies the request stream based on the time.
+
+    Parameters
+    ----------
+    request : Dict[str, Any]
+        Request parameters.
+
+    Returns
+    -------
+    Dict[str, Any]
+        The modified request parameters.
+    """
+    if request["time"] in (6, 18, 600, 1800):
+        request["stream"] = "scda"
+    else:
+        request["stream"] = "oper"
+    return request
+
+
+@legacy_source(__file__)
+def accumulations(context, dates, **request):
+    _to_list(request["param"])
+    user_accumulation_period = request.pop("accumulation_period", 6)
+    user_accumulation_period = datetime.timedelta(hours=user_accumulation_period)
+
+    context.trace("🌧️", f"accumulations {request} {user_accumulation_period}")
+
+    return _compute_accumulations(
+        context,
+        dates,
+        request,
+        user_accumulation_period=user_accumulation_period,
+    )
+
+
+execute = accumulations
+
+if __name__ == "__main__":
+    import yaml
+
+    config = yaml.safe_load(
+        """
+        class: ea
+        expver: '0001'
+        grid: 20./20.
+        levtype: sfc
+        # number: [0, 1]
+        # stream: enda
+        param: [cp, tp]
+        # accumulation_period: 6h
+        accumulation_period: 2
+        """
+    )
+    dates = yaml.safe_load("[2022-12-31 00:00, 2022-12-31 06:00]")
+    # dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]")
+    dates = to_datetime_list(dates)
+
+    class Context:
+        use_grib_paramid = True
+
+        def trace(self, *args):
+            print(*args)
+
+    for f in accumulations(Context, dates, **config):
+        print(f, f.to_numpy().mean())
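
As a rough illustration of how the Periods machinery above decomposes an accumulation (an editor's sketch, not part of the package: it instantiates EaOperPeriods directly, which the module itself only does through Accumulator, and it assumes anemoi-datasets and its dependencies are installed):

import datetime

from anemoi.datasets.create.sources.accumulations2 import EaOperPeriods

# A 6-hour accumulation valid at 2022-12-31 06:00 is split into hourly Period objects,
# each mapped to an ERA5 (date, time, step) request from the 06/18 UTC forecast runs.
periods = EaOperPeriods(datetime.datetime(2022, 12, 31, 6), datetime.timedelta(hours=6))
for p in periods:
    print(p)
# e.g. Period(2022-12-31 00:00:00 to 2022-12-31 01:00:00 -> (('date', 20221230), ('time', 1800), ('step', 7)))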