anemoi-datasets 0.5.28__py3-none-any.whl → 0.5.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/create/__init__.py +4 -12
  3. anemoi/datasets/create/config.py +50 -53
  4. anemoi/datasets/create/input/result/field.py +1 -3
  5. anemoi/datasets/create/sources/accumulate.py +517 -0
  6. anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
  7. anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
  8. anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +153 -0
  9. anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
  10. anemoi/datasets/create/sources/grib_index.py +79 -51
  11. anemoi/datasets/create/sources/mars.py +56 -27
  12. anemoi/datasets/create/sources/xarray_support/__init__.py +1 -0
  13. anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
  14. anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
  15. anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
  16. anemoi/datasets/data/complement.py +26 -17
  17. anemoi/datasets/data/dataset.py +6 -0
  18. anemoi/datasets/data/masked.py +74 -13
  19. anemoi/datasets/data/missing.py +5 -0
  20. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/METADATA +8 -7
  21. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/RECORD +25 -23
  22. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/WHEEL +1 -1
  23. anemoi/datasets/create/sources/accumulations.py +0 -1042
  24. anemoi/datasets/create/sources/accumulations2.py +0 -618
  25. anemoi/datasets/create/sources/tendencies.py +0 -171
  26. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/entry_points.txt +0 -0
  27. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/licenses/LICENSE +0 -0
  28. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,517 @@
1
+ # (C) Copyright 2025 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+ import datetime
11
+ import hashlib
12
+ import json
13
+ import logging
14
+ from typing import Any
15
+
16
+ import earthkit.data
17
+ import numpy as np
18
+ from anemoi.utils.dates import frequency_to_string
19
+ from anemoi.utils.dates import frequency_to_timedelta
20
+ from earthkit.data.core.temporary import temp_file
21
+ from earthkit.data.readers.grib.output import new_grib_output
22
+ from numpy.typing import NDArray
23
+
24
+ from anemoi.datasets.create.sources import source_registry
25
+ from anemoi.datasets.create.sources.accumulate_utils.interval_generators import interval_generator_factory
26
+
27
+ from .accumulate_utils.covering_intervals import SignedInterval
28
+ from .accumulate_utils.field_to_interval import FieldToInterval
29
+ from .legacy import LegacySource
30
+
31
+ LOG = logging.getLogger(__name__)
32
+
33
+
34
+ def _adjust_request_to_interval(interval: Any, request: list[dict]) -> tuple[Any]:
35
+ # TODO:
36
+ # for od-oper: need to do this adjustment, should be in mars source itself?
37
+ # Modifies the request stream based on the time (so, not here).
38
+ # if request["time"] in (6, 18, 600, 1800):
39
+ # request["stream"] = "scda"
40
+ # else:
41
+ # request["stream"] = "oper"
42
+ r = request.copy()
43
+ if interval.base is None:
44
+ # for some sources, we may not have a base time (grib-index)
45
+ step = int((interval.end - interval.start).total_seconds() / 3600)
46
+ r["step"] = step
47
+ return interval.max, request, step
48
+ else:
49
+ step = int((interval.max - interval.base).total_seconds() / 3600)
50
+ r["date"] = interval.base.strftime("%Y%m%d")
51
+ r["time"] = interval.base.strftime("%H%M")
52
+ r["step"] = step
53
+ return interval.max, r, step
54
+
55
+
56
class IntervalsDatesProvider:
    """Sequence-like view over a list of valid dates plus, for each date,
    the intervals that cover it.

    Iterating or indexing the provider yields the dates themselves; the
    ``intervals`` property flattens everything into (date, interval) pairs.
    """

    def __init__(self, dates, coverages):
        self._dates = dates
        self.date_to_intervals = coverages

    def _adjust_request_to_interval(self, interval: Any, request: list[dict]) -> tuple[Any]:
        # Delegate to the module-level helper so callers can reach it via the provider.
        return _adjust_request_to_interval(interval, request)

    @property
    def intervals(self):
        # Every (date, interval) pair, dates kept in their original order.
        return ((date, iv) for date in self._dates for iv in self.date_to_intervals[date])

    def __len__(self):
        return len(self._dates)

    def __iter__(self):
        return iter(self._dates)

    def __getitem__(self, index):
        return self._dates[index]
78
+
79
+
80
+ class Accumulator:
81
+ values: NDArray | None = None
82
+ locked: bool = False
83
+
84
+ def __init__(self, valid_date: datetime.datetime, period: datetime.timedelta, key: dict[str, Any], coverage):
85
+ # The accumulator only accumulates fields and does not know about the rest
86
+ # Accumulator object for a given param/member/valid_date
87
+
88
+ self.valid_date = valid_date
89
+ self.period = period
90
+ self.key = key
91
+
92
+ self.coverage = coverage
93
+
94
+ self.todo = [v for v in coverage]
95
+ self.done = []
96
+
97
+ self.values = None # will hold accumulated values array
98
+
99
+ def is_complete(self, **kwargs) -> bool:
100
+ """Check whether the accumulation is complete (all intervals have been processed)"""
101
+ return not self.todo
102
+
103
+ def compute(self, values: NDArray, interval: SignedInterval) -> None:
104
+ """Perform accumulation with the values array on this interval and record the operation.
105
+ Note: values have been extracted from field before the call to `compute`,
106
+ so values are read from field only once.
107
+
108
+ Parameters:
109
+ ----------
110
+ field: Any
111
+ An earthkit-data-like field
112
+ values: NDArray
113
+ Values from the field, will be added to the held values array
114
+
115
+ Return
116
+ ------
117
+ None
118
+ """
119
+
120
+ def match_interval(interval: SignedInterval, lst: list[SignedInterval]) -> bool:
121
+ for i in lst:
122
+ if i.min == interval.min and i.max == interval.max and i.base == interval.base:
123
+ return i
124
+ if i.start == interval.start and i.end == interval.end and i.base is None:
125
+ return i
126
+ return None
127
+
128
+ matching = match_interval(interval, self.todo)
129
+
130
+ if not matching:
131
+ # interval not needed for this accumulator
132
+ # this happens when multiple accumulators have the same key but different valid_date
133
+ return False
134
+
135
+ def raise_error(msg):
136
+ LOG.error(f"Accumulator {self.__repr__(verbose=True)} state:")
137
+ LOG.error(f"Received interval: {interval}")
138
+ LOG.error(f"Matching interval: {matching}")
139
+ raise ValueError(msg)
140
+
141
+ if matching in self.done:
142
+ # this should not happen normally
143
+ raise_error(f"SignedInterval {matching} already done for accumulator")
144
+
145
+ if self.locked:
146
+ raise_error(f"Accumulator already used, cannot process interval {interval}")
147
+
148
+ assert isinstance(values, np.ndarray), type(values)
149
+
150
+ # actual accumulation computation
151
+ # negative accumulation if interval is reversed
152
+ # copy is mandatory since value is shared between accumulators
153
+ local_values = matching.sign * values.copy()
154
+ if self.values is None:
155
+ self.values = local_values
156
+ else:
157
+ self.values += local_values
158
+
159
+ self.todo.remove(matching)
160
+ self.done.append(matching)
161
+ return True
162
+
163
+ def write_to_output(self, output, template) -> None:
164
+ assert self.is_complete(), (self.todo, self.done, self)
165
+ assert not self.locked # prevent double writing
166
+
167
+ # negative values may be an anomaly (e.g precipitation), but this is user's choice
168
+ for k, v in self.key:
169
+ if k == "param" and v == "tp":
170
+ if np.any(self.values < 0):
171
+ LOG.warning(
172
+ f"Negative values when computing accumutation for {self}): min={np.nanmin(self.values)} max={np.nanmax(self.values)}"
173
+ )
174
+ write_accumulated_field_with_valid_time(
175
+ template=template,
176
+ values=self.values,
177
+ valid_date=self.valid_date,
178
+ period=self.period,
179
+ output=output,
180
+ )
181
+ # lock the accumulator to prevent further use
182
+ self.locked = True
183
+
184
+ def __repr__(self, verbose: bool = False) -> str:
185
+ key = ", ".join(f"{k}={v}" for k, v in self.key)
186
+ period = frequency_to_string(self.period)
187
+ default = f"{self.__class__.__name__}(valid_date={self.valid_date}, {period}, key={{ {key} }})"
188
+ if verbose:
189
+ extra = []
190
+ if self.locked:
191
+ extra.append("(locked)")
192
+ for i in self.done:
193
+ extra.append(f" done: {i}")
194
+ for i in self.todo:
195
+ extra.append(f" todo: {i}")
196
+ default += "\n" + "\n".join(extra)
197
+ return default
198
+
199
+
200
def write_accumulated_field_with_valid_time(
    template, values, valid_date: datetime.datetime, period: datetime.timedelta, output
) -> Any:
    """Encode an accumulated field into `output` with its validity metadata.

    The field is stamped with base date/time `valid_date - period` and a step
    equal to `period` in whole hours. GRIB edition 1 gets a special encoding
    (stepType "instant", steps limited to 254 hours); edition 2 uses the
    normal "accum" encoding with startStep=0 and endStep=<hours>.
    """
    MISSING_VALUE = 1e-38
    assert np.all(values != MISSING_VALUE)

    base = valid_date - period
    date = base.strftime("%Y%m%d")
    time = base.strftime("%H%M")

    hours = period.total_seconds() / 3600
    if hours != int(hours):
        raise ValueError(f"Accumulation period must be integer hours, got {hours}")
    hours = int(hours)

    # keyword arguments shared by both encodings
    common = dict(
        template=template,
        date=int(date),
        time=int(time),
        check_nans=True,
        missing_value=MISSING_VALUE,
    )

    if template.metadata("edition") == 1:
        # this is a special case for GRIB edition 1 which only supports integer hours up to 254
        assert hours <= 254, f"edition 1 accumulation period must be <=254 hours, got {hours}"
        output.write(values, stepType="instant", step=hours, **common)
    else:
        # this is the normal case for GRIB edition 2. And with edition 1 when hours are integer and <=254
        output.write(values, stepType="accum", startStep=0, endStep=hours, **common)
241
+
242
+
243
class Logs(list):
    """List of per-field log records with rich error reporting for accumulation failures.

    Each entry appended by the caller is a 5-item list:
    [str(field), mars-metadata string, field interval, dates the field was used for,
    verbose reprs of the accumulators that used it].
    """

    def __init__(self, *args, accumulators, source, source_object, field_to_interval, **kwargs):
        super().__init__(*args, **kwargs)
        # dict of (valid_date, key) -> Accumulator, shared with (and filled by) the caller
        self.accumulators = accumulators
        # the one-key source configuration dict (e.g. {"mars": {...}})
        self.source = source
        # the fields provided by the source, iterated only for diagnostics
        self.source_object = source_object
        # FieldToInterval instance, reported for its applied patches
        self.field_to_interval = field_to_interval

    def raise_error(self, msg, field=None, field_interval=None) -> str:
        """Log a detailed multi-section error report via LOG.error, then raise ValueError(msg).

        Note: this method never returns normally (the ``-> str`` annotation is historical).
        """
        # ANSI escape codes to colour the report when viewed in a terminal
        INTERVAL_COLOR = "\033[93m"
        FIELD_COLOR = "\033[92m"
        KEY_COLOR = "\033[95m"
        RESET_COLOR = "\033[0m"

        # Section 1: the failing field and the current state of all accumulators
        res = [""]
        res.append(f"❌ {msg}")
        res.append(f"💬 Patches applied: {self.field_to_interval.patches}")
        res.append("💬 Current field:")
        res.append(f" {FIELD_COLOR}{field}{RESET_COLOR}")
        res.append(f" {INTERVAL_COLOR}{field_interval}{RESET_COLOR}")
        if self.accumulators:
            res.append(f"💬 Existing accumulators ({len(self.accumulators)}) :")
            for a in self.accumulators.values():
                res.append(f" {a.__repr__(verbose=True)}")
        res.append(f"💬 Received fields ({len(self)}):")
        for log in self:
            res.append(f" {KEY_COLOR}{log[0]}{RESET_COLOR} {INTERVAL_COLOR}{log[2]}{RESET_COLOR}")
            res.append(f" {KEY_COLOR}{log[1]}{RESET_COLOR}")
            for d, acc_repr in zip(log[3], log[4]):
                res.append(f" used for date {d}: {acc_repr}")

        LOG.error("\n".join(res))

        # Section 2: every field returned by the source with its step metadata and mean
        res = ["More details below:"]

        res.append(f"💬 Fields returned to be accumulated ({len(self.source_object)}):")
        for field in self.source_object:
            res.append(
                f" {field}, startStep={field.metadata('startStep')}, endStep={field.metadata('endStep')} mean={np.nanmean(field.values, axis=0)}"
            )

        LOG.error("\n".join(res))

        # Section 3 (mars sources only): a copy-pastable snippet the user can run
        # to fetch and inspect the fields available around the failing one
        res = ["Even more details below:"]

        if "mars" in self.source:
            res.append("💬 Example of code fetching some available fields and inspect them:")
            res.append("# --------------------------------------------------")
            code = []
            code.append("from earthkit.data import from_source")
            code.append("import numpy as np")
            code.append('ds = from_source("mars", **{')
            for k, v in self.source["mars"].items():
                code.append(f" {k!r}: {v!r},")
            # pin the failing field's date, suggest widening time/step to "ALL"
            code.append(f' "date": {field.metadata("date")!r},')
            code.append(f' "time": {field.metadata("time")!r}, # "ALL"')
            code.append(f' "step": "ALL", # {field.metadata("step")!r},')
            code.append("})")
            code.append('print(f"Got {len(ds)} fields:")')
            code.append("prev_m = None")
            code.append("for field in ds[:50]: # limit to first 50 for brevity")
            code.append(
                ' print(f"{field} startStep={field.metadata("startStep")}, endStep={field.metadata("endStep")} mean={np.nanmean(field.values)}")'
            )
            res.append("# --------------------------------------------------")
            code.append("")
            res += code

        # now execute the code to show actual field values
        LOG.error("\n".join(res))

        raise ValueError(msg)
313
+
314
+
315
def _compute_accumulations(
    context: Any,
    dates: list[datetime.datetime],
    period: datetime.timedelta,
    source: dict,
    availability: dict[str, Any] | None = None,
    patch: dict | None = None,
    **kwargs,
) -> Any:
    """Concrete accumulation logic.

    - identify the needed intervals for each date/parameter/member defined in recipe
    - fetch the source data via a database (mars or grib-index)
    - create Accumulator objects and fill them with accumulated values from source data
    - return the datasource with accumulated values

    Parameters:
    ----------
    context: Any,
        The dataset building context (will be updated with trace of accumulation)
    dates: list[datetime.datetime]
        The list of valid dates on which to perform accumulations.
    source: dict
        The source configuration to request fields from (must have exactly one key)
    period: datetime.timedelta,
        The interval over which to accumulate (user-defined)
    availability: Any, optional
        A description of the available periods in the data source. See documentation.
    patch: list[dict] | None, optional
        A description of patches to apply to fields returned by the source to fix metadata issues.

    Return
    ------
    The accumulated datasource for all dates, parameters, members.

    """

    LOG.debug("💬 source for accumulations: %s", source)
    # maps each incoming field to the SignedInterval it covers (after metadata patches)
    field_to_interval = FieldToInterval(patch)

    # building the source objects
    assert isinstance(source, dict)
    assert len(source) == 1, f"Source must have exactly one key, got {list(source.keys())}"
    source_name, _ = next(iter(source.items()))
    if source_name == "mars":
        # fill in commonly-omitted mars keys with defaults (note: mutates the recipe dict in place)
        if "type" not in source[source_name]:
            source[source_name]["type"] = "fc"
            LOG.warning("Assuming 'type: fc' for mars source as it was not specified in the recipe")
        if "levtype" not in source[source_name]:
            source[source_name]["levtype"] = "sfc"
            LOG.warning("Assuming 'levtype: sfc' for mars source as it was not specified in the recipe")

    # deterministic id for this (period, source) combination, used to name the data source
    h = hashlib.md5(json.dumps((str(period), source), sort_keys=True).encode()).hexdigest()
    source_object = context.create_source(source, "data_sources", h)

    interval_generator = interval_generator_factory(availability, source_name, source[source_name])

    # generate the interval coverage for every date
    coverages = {}
    for d in dates:
        if not isinstance(d, datetime.datetime):
            raise TypeError("valid_date must be a datetime.datetime instance")
        coverages[d] = interval_generator.covering_intervals(d - period, d)
        LOG.debug(f" Found covering intervals: for {d - period} to {d}:")
        for c in coverages[d]:
            LOG.debug(f" {c}")

    intervals = IntervalsDatesProvider(dates, coverages)

    # need a temporary file to store the accumulated fields for now, because earthkit-data
    # does not completely support in-memory fieldlists yet (metadata consistency is not fully ensured)
    tmp = temp_file()
    path = tmp.path
    output = new_grib_output(path)

    accumulators = {}
    # NOTE(review): source_object(context, intervals) is invoked here and again in the
    # loop below — presumably it is cheap/lazy, but verify it does not fetch data twice.
    logs = Logs(
        accumulators=accumulators,
        source=source,
        source_object=source_object(context, intervals),
        field_to_interval=field_to_interval,
    )
    for field in source_object(context, intervals):
        # for each field provided by the catalogue, find which accumulators need it and perform accumulation

        # copy once: the same values array is shared across all accumulators below
        values = field.values.copy()

        # the accumulator key is the mars metadata minus the temporal keys
        key = field.metadata(namespace="mars")
        key = {k: v for k, v in key.items() if k not in ["date", "time", "step"]}
        key = tuple(sorted(key.items()))
        log = " ".join(f"{k}={v}" for k, v in field.metadata(namespace="mars").items())

        field_interval = field_to_interval(field)

        # record [field, metadata, interval, dates-used, accumulator-reprs] for diagnostics
        logs.append([str(field), log, field_interval, [], []])

        if field_interval.end <= field_interval.start:
            logs.raise_error("Invalid field interval with end <= start", field=field, field_interval=field_interval)

        field_used = False
        for date in dates:
            # build accumulator if it does not exist yet
            if (date, key) not in accumulators:
                accumulators[(date, key)] = Accumulator(date, period=period, key=key, coverage=coverages[date])

            # find the accumulator for this valid date and key
            acc = accumulators[(date, key)]

            # perform accumulation if needed
            if acc.compute(values, field_interval):
                # .compute() returned True, meaning the field was used for accumulation
                field_used = True
                logs[-1][3].append(date)
                logs[-1][4].append(acc.__repr__(verbose=True))

                if acc.is_complete():
                    # all intervals for accumulation have been processed, write the accumulated field to output
                    acc.write_to_output(output, template=field)

        if not field_used:
            logs.raise_error("Field not used for any accumulation", field=field, field_interval=field_interval)

    # Final checks
    def check_missing_accumulators():
        # every date must have the same number of accumulators (one per key)
        for date in dates:
            count = sum(1 for (d, k) in accumulators.keys() if d == date)
            LOG.debug(f"Date {date} has {count} accumulators")
            if count != len(accumulators) // len(dates):
                LOG.error(f"All requested dates: {dates}")
                LOG.error(f"Date {date} has {count} accumulators, expected {len(accumulators) // len(dates)}")
                for d, k in accumulators.keys():
                    if d == date:
                        LOG.error(f" Accumulator for date {d}, key {k}")
                raise ValueError(f"Date {date} has {count} accumulators, expected {len(accumulators) // len(dates)}")

    check_missing_accumulators()

    for acc in accumulators.values():
        if not acc.is_complete():
            raise ValueError(f"Accumulator not complete: {acc.__repr__(verbose=True)}")

    LOG.info(f"Created {len(accumulators)} accumulated fields")

    if not accumulators:
        raise ValueError("No accumulators were created, cannot produce accumulated datasource")

    output.close()
    # reopen the written GRIB file as the returned datasource
    ds = earthkit.data.from_source("file", path)
    ds._keep_file = tmp  # prevent deletion of temp file until ds is deleted

    LOG.debug(f"Created {len(ds)} accumulated fields:")
    for i in ds:
        LOG.debug(" %s", i)
    return ds
469
+
470
+
471
+ @source_registry.register("accumulate")
472
+ class AccumulateSource(LegacySource):
473
+
474
+ @staticmethod
475
+ def _execute(
476
+ context: Any,
477
+ dates: list[datetime.datetime],
478
+ source: Any,
479
+ period: str | int | datetime.timedelta,
480
+ availability=None,
481
+ patch: Any = None,
482
+ ) -> Any:
483
+ """Accumulation source callable function.
484
+ Read the recipe for accumulation in the request dictionary, check main arguments and call computation.
485
+
486
+ Parameters:
487
+ ----------
488
+ context: Any,
489
+ The dataset building context (will be updated with trace of accumulation)
490
+ dates: list[datetime.datetime]
491
+ The list of valid dates on which to perform accumulations.
492
+ source: Any,
493
+ The accumulation source
494
+ period: str | int | datetime.timedelta,
495
+ The interval over which to accumulate (user-defined)
496
+ availability: Any, optional
497
+ A description of the available periods in the data source. See documentation.
498
+ patch: Any, optional
499
+ A description of patches to apply to fields returned by the source to fix metadata issues.
500
+
501
+ Return
502
+ ------
503
+ The accumulated data source.
504
+
505
+ """
506
+ if availability is None:
507
+ raise ValueError(
508
+ "Argument 'availability' must be specified for accumulate source. See https://anemoi.readthedocs.io/projects/datasets/en/latest/building/sources/accumulate.html"
509
+ )
510
+
511
+ if "accumulation_period" in source:
512
+ raise ValueError("'accumulation_period' should be define outside source for accumulate action as 'period'")
513
+
514
+ period = frequency_to_timedelta(period)
515
+ return _compute_accumulations(
516
+ context, dates, source=source, period=period, availability=availability, patch=patch
517
+ )
@@ -0,0 +1,8 @@
1
+ # (C) Copyright 2025 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.