anemoi-datasets 0.5.28__py3-none-any.whl → 0.5.29__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (28)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/create/__init__.py +4 -12
  3. anemoi/datasets/create/config.py +50 -53
  4. anemoi/datasets/create/input/result/field.py +1 -3
  5. anemoi/datasets/create/sources/accumulate.py +517 -0
  6. anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
  7. anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
  8. anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +149 -0
  9. anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
  10. anemoi/datasets/create/sources/grib_index.py +64 -20
  11. anemoi/datasets/create/sources/mars.py +56 -27
  12. anemoi/datasets/create/sources/xarray_support/__init__.py +1 -0
  13. anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
  14. anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
  15. anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
  16. anemoi/datasets/data/complement.py +26 -17
  17. anemoi/datasets/data/dataset.py +6 -0
  18. anemoi/datasets/data/masked.py +74 -13
  19. anemoi/datasets/data/missing.py +5 -0
  20. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/METADATA +7 -7
  21. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/RECORD +25 -23
  22. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/WHEEL +1 -1
  23. anemoi/datasets/create/sources/accumulations.py +0 -1042
  24. anemoi/datasets/create/sources/accumulations2.py +0 -618
  25. anemoi/datasets/create/sources/tendencies.py +0 -171
  26. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/entry_points.txt +0 -0
  27. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/licenses/LICENSE +0 -0
  28. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/sources/accumulations2.py (deleted)
@@ -1,618 +0,0 @@
- # (C) Copyright 2024 Anemoi contributors.
- #
- # This software is licensed under the terms of the Apache Licence Version 2.0
- # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
- #
- # In applying this licence, ECMWF does not waive the privileges and immunities
- # granted to it by virtue of its status as an intergovernmental organisation
- # nor does it submit to any jurisdiction.
-
- import datetime
- import logging
- from abc import abstractmethod
- from copy import deepcopy
- from typing import Any
-
- import earthkit.data as ekd
- import numpy as np
- from earthkit.data.core.temporary import temp_file
- from earthkit.data.readers.grib.output import new_grib_output
-
- from anemoi.datasets.create.sources import source_registry
- from anemoi.datasets.create.sources.mars import mars
-
- from .legacy import LegacySource
-
- LOG = logging.getLogger(__name__)
-
-
- xprint = print
-
-
- def _member(field: Any) -> int:
-     """Retrieves the member number from the field metadata.
-
-     Parameters
-     ----------
-     field : Any
-         The field from which to retrieve the member number.
-
-     Returns
-     -------
-     int
-         The member number.
-     """
-     # Bug in eccodes has number=0 randomly
-     number = field.metadata("number", default=0)
-     if number is None:
-         number = 0
-     return number
-
-
- class Period:
-     value = None
-
-     def __init__(self, start_datetime, end_datetime, base_datetime):
-         assert isinstance(start_datetime, datetime.datetime)
-         assert isinstance(end_datetime, datetime.datetime)
-         assert isinstance(base_datetime, datetime.datetime)
-
-         self.start_datetime = start_datetime
-         self.end_datetime = end_datetime
-
-         self.base_datetime = base_datetime
-
-     @property
-     def time_request(self):
-         date = int(self.base_datetime.strftime("%Y%m%d"))
-         time = int(self.base_datetime.strftime("%H%M"))
-
-         end_step = self.end_datetime - self.base_datetime
-         assert end_step.total_seconds() % 3600 == 0, end_step  # only full hours supported
-         end_step = int(end_step.total_seconds() // 3600)
-
-         return (("date", date), ("time", time), ("step", end_step))
-
-     def field_to_key(self, field):
-         return (
-             ("date", field.metadata("date")),
-             ("time", field.metadata("time")),
-             ("step", field.metadata("step")),
-         )
-
-     def check(self, field):
-         stepType = field.metadata("stepType")
-         startStep = field.metadata("startStep")
-         endStep = field.metadata("endStep")
-         date = field.metadata("date")
-         time = field.metadata("time")
-
-         assert stepType == "accum", stepType
-
-         base_datetime = datetime.datetime.strptime(str(date) + str(time).zfill(4), "%Y%m%d%H%M")
-
-         start = base_datetime + datetime.timedelta(hours=startStep)
-         assert start == self.start_datetime, (start, self.start_datetime)
-
-         end = base_datetime + datetime.timedelta(hours=endStep)
-         assert end == self.end_datetime, (end, self.end_datetime)
-
-     def is_matching_field(self, field):
-         return self.field_to_key(field) == self.time_request
-
-     def __repr__(self):
-         return f"Period({self.start_datetime} to {self.end_datetime} -> {self.time_request})"
-
-     def length(self):
-         return self.end_datetime - self.start_datetime
-
-     def apply(self, accumulated, values):
-
-         if accumulated is None:
-             accumulated = np.zeros_like(values)
-
-         assert accumulated.shape == values.shape, (accumulated.shape, values.shape)
-
-         # if not np.all(values >= 0):
-         #     warnings.warn(f"Negative values for {values}: {np.amin(values)} {np.amax(values)}")
-
-         return accumulated + self.sign * values
-
-
- class TodoList:
-     def __init__(self, keys):
-         self._todo = set(keys)
-         self._len = len(keys)
-         self._done = set()
-         assert self._len == len(self._todo), (self._len, len(self._todo))
-
-     def is_todo(self, key):
-         return key in self._todo
-
-     def is_done(self, key):
-         return key in self._done
-
-     def set_done(self, key):
-         self._done.add(key)
-         self._todo.remove(key)
-
-     def all_done(self):
-         if not self._todo:
-             assert len(self._done) == self._len, (len(self._done), self._len)
-             return True
-         return False
-
-
- class Periods:
-     _todo = None
-
-     def __init__(self, valid_date, accumulation_period, **kwargs):
-         # one Periods object for each accumulated field in the output
-
-         assert isinstance(valid_date, datetime.datetime), (valid_date, type(valid_date))
-         assert isinstance(accumulation_period, datetime.timedelta), (accumulation_period, type(accumulation_period))
-         self.valid_date = valid_date
-         self.accumulation_period = accumulation_period
-         self.kwargs = kwargs
-
-         self._periods = self.build_periods()
-         self.check_merged_interval()
-
-     def check_merged_interval(self):
-         global_start = self.valid_date - self.accumulation_period
-         global_end = self.valid_date
-         resolution = datetime.timedelta(hours=1)
-
-         timeline = np.arange(
-             np.datetime64(global_start, "s"), np.datetime64(global_end, "s"), np.timedelta64(resolution)
-         )
-
-         flags = np.zeros_like(timeline, dtype=int)
-         for p in self._periods:
-             segment = np.where((timeline >= p.start_datetime) & (timeline < p.end_datetime))
-             xprint(segment)
-             flags[segment] += p.sign
-         assert np.all(flags == 1), flags
-
-     def find_matching_period(self, field):
-         # Find a period that matches the field, or return None
-         found = [p for p in self._periods if p.is_matching_field(field)]
-         if len(found) == 1:
-             return found[0]
-         if len(found) > 1:
-             raise ValueError(f"Found more than one period for {field}")
-         return None
-
-     @property
-     def todo(self):
-         if self._todo is None:
-             self._todo = TodoList([p.time_request for p in self._periods])
-         return self._todo
-
-     def is_todo(self, period):
-         return self.todo.is_todo(period.time_request)
-
-     def is_done(self, period):
-         return self.todo.is_done(period.time_request)
-
-     def set_done(self, period):
-         self.todo.set_done(period.time_request)
-
-     def all_done(self):
-         return self.todo.all_done()
-
-     def __iter__(self):
-         return iter(self._periods)
-
-     @abstractmethod
-     def build_periods(self):
-         pass
-
-
- class EraPeriods(Periods):
-     def search_periods(self, start, end, debug=False):
-         # find candidate periods that can be used to accumulate the data
-         # to get the accumulation between the two dates 'start' and 'end'
-         found = []
-         if not end - start == datetime.timedelta(hours=1):
-             raise NotImplementedError("Only 1 hour period is supported")
-
-         for base_time, steps in self.available_steps(start, end).items():
-             for step1, step2 in steps:
-                 if debug:
-                     xprint(f"❌ trying: {base_time=} {step1=} {step2=}")
-
-                 if ((base_time + step1) % 24) != start.hour:
-                     continue
-
-                 if ((base_time + step2) % 24) != end.hour:
-                     continue
-
-                 base_datetime = start - datetime.timedelta(hours=step1)
-
-                 period = Period(start, end, base_datetime)
-                 found.append(period)
-
-                 assert base_datetime.hour == base_time, (base_datetime, base_time)
-
-                 assert period.start_datetime - period.base_datetime == datetime.timedelta(hours=step1), (
-                     period.start_datetime,
-                     period.base_datetime,
-                     step1,
-                 )
-                 assert period.end_datetime - period.base_datetime == datetime.timedelta(hours=step2), (
-                     period.end_datetime,
-                     period.base_datetime,
-                     step2,
-                 )
-
-         return found
-
-     def build_periods(self):
-         # build the list of periods to accumulate the data
-
-         hours = self.accumulation_period.total_seconds() / 3600
-         assert int(hours) == hours, f"Only full hours accumulation is supported {hours}"
-         hours = int(hours)
-
-         lst = []
-         for wanted in [[i, i + 1] for i in range(0, hours, 1)]:
-
-             start = self.valid_date - datetime.timedelta(hours=wanted[1])
-             end = self.valid_date - datetime.timedelta(hours=wanted[0])
-
-             found = self.search_periods(start, end)
-             if not found:
-                 xprint(f"❌❌❌ Cannot find accumulation for {start} {end}")
-                 self.search_periods(start, end, debug=True)
-                 raise ValueError(f"Cannot find accumulation for {start} {end}")
-
-             found = sorted(found, key=lambda x: x.base_datetime, reverse=True)
-             chosen = found[0]
-
-             if len(found) > 1:
-                 xprint(f" Found more than one period for {start} {end}")
-                 for f in found:
-                     xprint(f" {f}")
-                 xprint(f" Choosing {chosen}")
-
-             chosen.sign = 1
-
-             lst.append(chosen)
-         return lst
-
-
- class EaOperPeriods(EraPeriods):
-     def available_steps(self, start, end):
-         return {
-             6: [[i, i + 1] for i in range(0, 18, 1)],
-             18: [[i, i + 1] for i in range(0, 18, 1)],
-         }
-
-
- class L5OperPeriods(EraPeriods):
-     def available_steps(self, start, end):
-         print("❌❌❌ untested")
-         x = 24  # need to check if 24 is the right value
-         return {
-             0: [[i, i + 1] for i in range(0, x, 1)],
-         }
-
-
- class EaEndaPeriods(EraPeriods):
-     def available_steps(self, start, end):
-         print("❌❌❌ untested")
-         return {
-             6: [[i, i + 3] for i in range(0, 18, 1)],
-             18: [[i, i + 3] for i in range(0, 18, 1)],
-         }
-
-
- class RrOperPeriods(Periods):
-     def available_steps(self, start, end):
-         raise NotImplementedError("need to implement diff")
-         x = 24  # todo: check if 24 is the right value
-         return {
-             0: [[0, i] for i in range(0, x, 1)],
-             3: [[0, i] for i in range(0, x, 1)],
-             6: [[0, i] for i in range(0, x, 1)],
-             9: [[0, i] for i in range(0, x, 1)],
-             12: [[0, i] for i in range(0, x, 1)],
-             15: [[0, i] for i in range(0, x, 1)],
-             18: [[0, i] for i in range(0, x, 1)],
-             21: [[0, i] for i in range(0, x, 1)],
-         }
-
-
- class OdEldaPeriods(EraPeriods):
-     def available_steps(self, start, end):
-         print("❌❌❌ untested")
-         x = 24  # need to check if 24 is the right value
-         return {
-             6: [[i, i + 1] for i in range(0, x, 1)],
-             18: [[i, i + 1] for i in range(0, x, 1)],
-         }
-
-
- class DiffPeriods(Periods):
-     pass
-
-
- class OdOperPeriods(DiffPeriods):
-     def available_steps(self, start, end):
-         raise NotImplementedError("need to implement diff and _scda patch")
-
-
- class OdEnfoPeriods(DiffPeriods):
-     def available_steps(self, start, end):
-         raise NotImplementedError("need to implement diff")
-
-
- def find_accumulator_class(class_: str, stream: str) -> Periods:
-     return {
-         ("ea", "oper"): EaOperPeriods,  # runs ok
-         ("ea", "enda"): EaEndaPeriods,
-         ("rr", "oper"): RrOperPeriods,
-         ("l5", "oper"): L5OperPeriods,
-         ("od", "oper"): OdOperPeriods,
-         ("od", "enfo"): OdEnfoPeriods,
-         ("od", "elda"): OdEldaPeriods,
-     }[class_, stream]
-
-
- class Accumulator:
-     values = None
-
-     def __init__(self, period_class, out, valid_date, user_accumulation_period, **kwargs):
-         self.valid_date = valid_date
-
-         # keep the reference to the output file to be able to write the result using an input field as template
-         self.out = out
-
-         # key contains the mars request parameters except the one related to the time
-         # A mars request is a dictionary with three categories of keys:
-         # - the ones related to the time (date, time, step)
-         # - the ones related to the data (param, stream, levtype, expver, number, ...)
-         # - the ones related to the processing to be done (grid, area, ...)
-         self.kwargs = kwargs
-         for k in ["date", "time", "step"]:
-             if k in kwargs:
-                 raise ValueError(f"Cannot use {k} in kwargs for accumulations")
-
-         self.key = {k: v for k, v in kwargs.items() if k in ["param", "level", "levelist", "number"]}
-
-         self.periods = period_class(self.valid_date, user_accumulation_period, **kwargs)
-
-     @property
-     def requests(self):
-         for period in self.periods:
-             # build the full data requests, merging the time requests with the key
-             yield {**self.kwargs.copy(), **dict(period.time_request)}
-
-     def is_field_needed(self, field):
-         for k, v in self.key.items():
-             if field.metadata(k) != v:
-                 LOG.debug(f"{self} does not need field {field} because of {k}={field.metadata(k)} not {v}")
-                 return False
-         return True
-
-     def compute(self, field, values):
-         if not self.is_field_needed(field):
-             return
-
-         period = self.periods.find_matching_period(field)
-         if not period:
-             return
-         assert self.periods.is_todo(period), (self.periods, period)
-         assert not self.periods.is_done(period), f"Field {field} for period {period} already done"
-
-         period.check(field)
-
-         xprint(f"{self} field ✅ ({period.sign}){field} for {period}")
-
-         self.values = period.apply(self.values, values)
-         self.periods.set_done(period)
-
-         if self.periods.all_done():
-             self.write(field)
-             xprint("accumulator", self, " : data written ✅ ")
-
-     def check(self, field: Any) -> None:
-         if self._check is None:
-             self._check = field.metadata(namespace="mars")
-
-             assert self.param == field.metadata("param"), (self.param, field.metadata("param"))
-             assert self.date == field.metadata("date"), (self.date, field.metadata("date"))
-             assert self.time == field.metadata("time"), (self.time, field.metadata("time"))
-             assert self.step == field.metadata("step"), (self.step, field.metadata("step"))
-             assert self.number == _member(field), (self.number, _member(field))
-             return
-
-         mars = field.metadata(namespace="mars")
-         keys1 = sorted(self._check.keys())
-         keys2 = sorted(mars.keys())
-
-         assert keys1 == keys2, (keys1, keys2)
-
-         for k in keys1:
-             if k not in ("step",):
-                 assert self._check[k] == mars[k], (k, self._check[k], mars[k])
-
-     def write(self, template: Any) -> None:
-         assert self.periods.all_done(), self.periods
-
-         if np.all(self.values < 0):
-             LOG.warning(
-                 f"Negative values when computing accumulation for {self}: min={np.amin(self.values)} max={np.amax(self.values)}"
-             )
-
-         startStep = 0
-         endStep = self.periods.accumulation_period.total_seconds() // 3600
-         assert int(endStep) == endStep, "only full hours accumulation is supported"
-         endStep = int(endStep)
-         fake_base_date = self.valid_date - self.periods.accumulation_period
-         date = int(fake_base_date.strftime("%Y%m%d"))
-         time = int(fake_base_date.strftime("%H%M"))
-
-         self.out.write(
-             self.values,
-             template=template,
-             stepType="accum",
-             startStep=startStep,
-             endStep=endStep,
-             date=date,
-             time=time,
-             check_nans=True,
-         )
-         self.values = None
-
-     def __repr__(self):
-         key = ", ".join(f"{k}={v}" for k, v in self.key.items())
-         return f"{self.__class__.__name__}({self.valid_date}, {key})"
-
-
- def _compute_accumulations(
-     context: Any,
-     dates: list[datetime.datetime],
-     request: dict[str, Any],
-     user_accumulation_period: datetime.timedelta,
-     # data_accumulation_period: Optional[int] = None,
-     # patch: Any = _identity,
- ) -> Any:
-
-     request = deepcopy(request)
-
-     param = request.pop("param")
-     assert isinstance(param, (list, tuple))
-
-     number = request.pop("number", [0])
-     if not isinstance(number, (list, tuple)):
-         number = [number]
-     assert isinstance(number, (list, tuple))
-
-     request["stream"] = request.get("stream", "oper")
-
-     type_ = request.get("type", "an")
-     if type_ == "an":
-         type_ = "fc"
-     request["type"] = type_
-
-     request["levtype"] = request.get("levtype", "sfc")
-     if request["levtype"] != "sfc":
-         # LOG.warning("'type' should be 'sfc', found %s", request['type'])
-         raise NotImplementedError("Only sfc leveltype is supported")
-
-     period_class = find_accumulator_class(request["class"], request["stream"])
-
-     tmp = temp_file()
-     path = tmp.path
-     out = new_grib_output(path)
-
-     # build one accumulator per output field
-     accumulators = []
-     for valid_date in dates:
-         for p in param:
-             for n in number:
-                 accumulators.append(
-                     Accumulator(
-                         period_class,
-                         out,
-                         valid_date,
-                         user_accumulation_period=user_accumulation_period,
-                         param=p,
-                         number=n,
-                         **request,
-                     )
-                 )
-
-     xprint("accumulators", len(accumulators))
-
-     # get all needed data requests (mars)
-     requests = []
-     for a in accumulators:
-         xprint("accumulator", a)
-         for r in a.requests:
-             xprint(" ", r)
-             requests.append(r)
-
-     # get the data (this will pack the requests to avoid duplicates and make a minimal number of requests)
-     ds = mars(context, dates, request_already_using_valid_datetime=True, *requests)
-
-     # send each field to each accumulator; the accumulator will use the field for the accumulation
-     # if it has requested it
-     for field in ds:
-         values = field.values  # optimisation
-         for a in accumulators:
-             a.compute(field, values)
-
-     out.close()
-
-     ds = ekd.from_source("file", path)
-
-     assert len(ds) / len(param) / len(number) == len(dates), (
-         len(ds),
-         len(param),
-         len(dates),
-     )
-
-     # keep a reference to the tmp file, or it gets deleted when the function returns
-     ds._tmp = tmp
-
-     return ds
-
-
- def _to_list(x: list[Any] | tuple[Any] | Any) -> list[Any]:
-     """Converts the input to a list if it is not already a list or tuple.
-
-     Parameters
-     ----------
-     x : Union[List[Any], Tuple[Any], Any]
-         Input value.
-
-     Returns
-     -------
-     List[Any]
-         The input value as a list.
-     """
-     if isinstance(x, (list, tuple)):
-         return x
-     return [x]
-
-
- def _scda(request: dict[str, Any]) -> dict[str, Any]:
-     """Modifies the request stream based on the time.
-
-     Parameters
-     ----------
-     request : Dict[str, Any]
-         Request parameters.
-
-     Returns
-     -------
-     Dict[str, Any]
-         The modified request parameters.
-     """
-     if request["time"] in (6, 18, 600, 1800):
-         request["stream"] = "scda"
-     else:
-         request["stream"] = "oper"
-     return request
-
-
- @source_registry.register("accumulations2")
- class Accumulations2Source(LegacySource):
-
-     @staticmethod
-     def _execute(context, dates, **request):
-         _to_list(request["param"])
-         user_accumulation_period = request.pop("accumulation_period", 6)
-         user_accumulation_period = datetime.timedelta(hours=user_accumulation_period)
-
-         context.trace("🌧️", f"accumulations {request} {user_accumulation_period}")
-
-         return _compute_accumulations(
-             context,
-             dates,
-             request,
-             user_accumulation_period=user_accumulation_period,
-         )
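
Note on the removed logic: accumulations2.py built each accumulated output field from a set of 1-hour slices, each mapped back to a MARS (date, time, step) request via the available_steps tables shown above. The standalone sketch below re-derives those requests for the ERA5 deterministic case (class=ea, stream=oper), using the same step table and the same "prefer the most recent forecast base time" rule as the deleted EaOperPeriods / EraPeriods code. The helper name era5_oper_requests and the demo at the bottom are invented for this illustration; they are not part of the anemoi-datasets API.

import datetime

# Step table copied from the deleted EaOperPeriods.available_steps():
# ERA5 forecasts start at 06 and 18 UTC, and each [i, i + 1] pair is a 1-hour slice.
AVAILABLE_STEPS = {
    6: [[i, i + 1] for i in range(0, 18)],
    18: [[i, i + 1] for i in range(0, 18)],
}


def era5_oper_requests(valid_date: datetime.datetime, accumulation_hours: int):
    """Yield (date, time, step) triplets covering [valid_date - accumulation, valid_date]."""
    for offset in range(accumulation_hours):
        start = valid_date - datetime.timedelta(hours=offset + 1)
        end = valid_date - datetime.timedelta(hours=offset)
        candidates = []
        for base_hour, steps in AVAILABLE_STEPS.items():
            for step1, step2 in steps:
                # Keep only step pairs whose start/end hours line up with this 1-hour slice.
                if (base_hour + step1) % 24 != start.hour:
                    continue
                if (base_hour + step2) % 24 != end.hour:
                    continue
                candidates.append((start - datetime.timedelta(hours=step1), step2))
        # Like the deleted build_periods(), prefer the most recent forecast base time.
        base, end_step = max(candidates, key=lambda c: c[0])
        yield int(base.strftime("%Y%m%d")), int(base.strftime("%H%M")), end_step


if __name__ == "__main__":
    # The six 1-hour slices ending at 2020-01-01 12:00 UTC, as the removed source would request them.
    for req in era5_oper_requests(datetime.datetime(2020, 1, 1, 12), 6):
        print(req)

In 0.5.29 this responsibility appears to move to the new accumulate source and its accumulate_utils interval helpers listed in the files changed above.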