anemoi-datasets 0.5.17__py3-none-any.whl → 0.5.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.5.17'
21
- __version_tuple__ = version_tuple = (0, 5, 17)
20
+ __version__ = version = '0.5.19'
21
+ __version_tuple__ = version_tuple = (0, 5, 19)
@@ -294,7 +294,14 @@ class Dataset:
294
294
  import zarr
295
295
 
296
296
  z = zarr.open(self.path, mode="r")
297
- return loader_config(z.attrs.get("_create_yaml_config"))
297
+ config = loader_config(z.attrs.get("_create_yaml_config"))
298
+
299
+ if "env" in config:
300
+ for k, v in config["env"].items():
301
+ LOG.info(f"Setting env variable {k}={v}")
302
+ os.environ[k] = str(v)
303
+
304
+ return config
298
305
 
299
306
 
300
307
  class WritableDataset(Dataset):
@@ -420,6 +420,11 @@ def loader_config(config: dict, is_test: bool = False) -> LoadersConfig:
420
420
  print(b)
421
421
  raise ValueError("Serialisation failed")
422
422
 
423
+ if "env" in copy:
424
+ for k, v in copy["env"].items():
425
+ LOG.info(f"Setting env variable {k}={v}")
426
+ os.environ[k] = str(v)
427
+
423
428
  return copy
424
429
 
425
430
 
@@ -33,15 +33,13 @@ class TransformFilter(Filter):
33
33
  from anemoi.transform.filters import create_filter
34
34
 
35
35
  self.name = name
36
- self.transform_filter = create_filter(self, config)
36
+ self.transform_filter = create_filter(context, config)
37
37
 
38
- def execute(self, context: Any, input: ekd.FieldList) -> ekd.FieldList:
38
+ def execute(self, input: ekd.FieldList) -> ekd.FieldList:
39
39
  """Execute the transformation filter.
40
40
 
41
41
  Parameters
42
42
  ----------
43
- context : Any
44
- The context in which the execution occurs.
45
43
  input : ekd.FieldList
46
44
  The input data to be transformed.
47
45
 
@@ -17,6 +17,7 @@ from earthkit.data.core.order import build_remapping
17
17
 
18
18
  from ...dates.groups import GroupOfDates
19
19
  from .context import Context
20
+ from .template import substitute
20
21
 
21
22
  LOG = logging.getLogger(__name__)
22
23
 
@@ -248,7 +249,7 @@ def action_factory(config: Dict[str, Any], context: ActionContext, action_path:
248
249
  if cls is None:
249
250
  from ..sources import create_source
250
251
 
251
- source = create_source(None, config)
252
+ source = create_source(None, substitute(context, config))
252
253
  return FunctionAction(context, action_path + [key], key, source)
253
254
 
254
255
  return cls(context, action_path + [key], *args, **kwargs)
@@ -20,7 +20,6 @@ from .misc import _tidy
20
20
  from .misc import assert_fieldlist
21
21
  from .result import Result
22
22
  from .template import notify_result
23
- from .template import resolve
24
23
  from .template import substitute
25
24
  from .trace import trace
26
25
  from .trace import trace_datasource
@@ -79,6 +78,9 @@ class FunctionContext:
79
78
  """Returns whether partial results are acceptable."""
80
79
  return self.owner.group_of_dates.partial_ok
81
80
 
81
+ def get_result(self, *args, **kwargs) -> Any:
82
+ return self.owner.context.get_result(*args, **kwargs)
83
+
82
84
 
83
85
  class FunctionAction(Action):
84
86
  """Represents an action that executes a function.
@@ -203,14 +205,12 @@ class FunctionResult(Result):
203
205
  @trace_datasource
204
206
  def datasource(self) -> FieldList:
205
207
  """Returns the datasource for the function result."""
206
- args, kwargs = resolve(self.context, (self.args, self.kwargs))
208
+ # args, kwargs = resolve(self.context, (self.args, self.kwargs))
207
209
  self.action.source.context = FunctionContext(self)
208
210
 
209
211
  return _tidy(
210
212
  self.action.source.execute(
211
- self.group_of_dates, # Will provide a list of datetime objects
212
- *args,
213
- **kwargs,
213
+ list(self.group_of_dates), # Will provide a list of datetime objects
214
214
  )
215
215
  )
216
216
 
@@ -636,6 +636,19 @@ def accumulations(
636
636
  Any
637
637
  The computed accumulations.
638
638
  """
639
+
640
+ if (
641
+ request.get("class") == "ea"
642
+ and request.get("stream", "oper") == "oper"
643
+ and request.get("accumulation_period") == 24
644
+ ):
645
+ from .accumulations2 import accumulations as accumulations2
646
+
647
+ LOG.warning(
648
+ "🧪️ Experimental features: Using accumulations2, because class=ea stream=oper and accumulation_period=24"
649
+ )
650
+ return accumulations2(context, dates, **request)
651
+
639
652
  _to_list(request["param"])
640
653
  class_ = request.get("class", "od")
641
654
  stream = request.get("stream", "oper")
@@ -0,0 +1,652 @@
1
+ # (C) Copyright 2024 Anemoi contributors.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ #
6
+ # In applying this licence, ECMWF does not waive the privileges and immunities
7
+ # granted to it by virtue of its status as an intergovernmental organisation
8
+ # nor does it submit to any jurisdiction.
9
+
10
+ import datetime
11
+ import logging
12
+ import warnings
13
+ from abc import abstractmethod
14
+ from copy import deepcopy
15
+ from typing import Any
16
+ from typing import Dict
17
+ from typing import List
18
+ from typing import Tuple
19
+ from typing import Union
20
+
21
+ import earthkit.data as ekd
22
+ import numpy as np
23
+ from earthkit.data.core.temporary import temp_file
24
+ from earthkit.data.readers.grib.output import new_grib_output
25
+
26
+ from anemoi.datasets.create.sources.mars import mars
27
+ from anemoi.datasets.create.utils import to_datetime_list
28
+
29
+ from .legacy import legacy_source
30
+
31
+ LOG = logging.getLogger(__name__)
32
+
33
+
34
+ xprint = print
35
+
36
+
37
+ def _member(field: Any) -> int:
38
+ """Retrieves the member number from the field metadata.
39
+
40
+ Parameters
41
+ ----------
42
+ field : Any
43
+ The field from which to retrieve the member number.
44
+
45
+ Returns
46
+ -------
47
+ int
48
+ The member number.
49
+ """
50
+ # Bug in eccodes has number=0 randomly
51
+ number = field.metadata("number", default=0)
52
+ if number is None:
53
+ number = 0
54
+ return number
55
+
56
+
57
class Period:
    """A single accumulation interval together with the request that fetches it.

    The interval runs from ``start_datetime`` to ``end_datetime``; the
    ``base_datetime`` is what the ``date``/``time`` request keys are derived
    from, and ``step`` is the number of hours from ``base_datetime`` to
    ``end_datetime``.
    """

    value = None

    # Factor applied to the values in ``apply``. Code that builds periods may
    # overwrite it per instance; the default keeps ``apply`` usable on a
    # freshly constructed Period instead of raising AttributeError.
    sign = 1

    def __init__(self, start_datetime, end_datetime, base_datetime):
        """Store the three datetimes after checking their types.

        Parameters
        ----------
        start_datetime : datetime.datetime
            Start of the interval.
        end_datetime : datetime.datetime
            End of the interval.
        base_datetime : datetime.datetime
            Datetime used to build the ``date`` and ``time`` request keys.
        """
        assert isinstance(start_datetime, datetime.datetime)
        assert isinstance(end_datetime, datetime.datetime)
        assert isinstance(base_datetime, datetime.datetime)

        self.start_datetime = start_datetime
        self.end_datetime = end_datetime

        self.base_datetime = base_datetime

    @property
    def time_request(self):
        """Return ``(("date", YYYYMMDD), ("time", HHMM), ("step", hours))`` as ints."""
        date = int(self.base_datetime.strftime("%Y%m%d"))
        time = int(self.base_datetime.strftime("%H%M"))

        end_step = self.end_datetime - self.base_datetime
        assert end_step.total_seconds() % 3600 == 0, end_step  # only full hours supported
        end_step = int(end_step.total_seconds() // 3600)

        return (("date", date), ("time", time), ("step", end_step))

    def field_to_key(self, field):
        """Build, from a field's metadata, a tuple shaped like ``time_request``."""
        return (
            ("date", field.metadata("date")),
            ("time", field.metadata("time")),
            ("step", field.metadata("step")),
        )

    def check(self, field):
        """Assert that the field is an accumulation covering exactly this interval.

        Raises
        ------
        AssertionError
            If ``stepType`` is not ``"accum"`` or the start/end derived from the
            field's date, time, startStep and endStep differ from this period.
        """
        stepType = field.metadata("stepType")
        startStep = field.metadata("startStep")
        endStep = field.metadata("endStep")
        date = field.metadata("date")
        time = field.metadata("time")

        assert stepType == "accum", stepType

        # The time is zero-padded to 4 digits so that e.g. 600 parses as 06:00.
        base_datetime = datetime.datetime.strptime(str(date) + str(time).zfill(4), "%Y%m%d%H%M")

        start = base_datetime + datetime.timedelta(hours=startStep)
        assert start == self.start_datetime, (start, self.start_datetime)

        end = base_datetime + datetime.timedelta(hours=endStep)
        assert end == self.end_datetime, (end, self.end_datetime)

    def is_matching_field(self, field):
        """Return True when the field's date/time/step equal ``time_request``."""
        return self.field_to_key(field) == self.time_request

    def __repr__(self):
        return f"Period({self.start_datetime} to {self.end_datetime} -> {self.time_request})"

    def length(self):
        """Return the duration of the interval as a ``datetime.timedelta``."""
        return self.end_datetime - self.start_datetime

    def apply(self, accumulated, values):
        """Return ``accumulated + sign * values``.

        Parameters
        ----------
        accumulated : numpy.ndarray or None
            Running total; ``None`` starts from zeros shaped like ``values``.
        values : numpy.ndarray
            Values of the field for this period.

        Returns
        -------
        numpy.ndarray
            The updated running total (a new array).
        """
        if accumulated is None:
            accumulated = np.zeros_like(values)

        assert accumulated.shape == values.shape, (accumulated.shape, values.shape)

        if not np.all(values >= 0):
            warnings.warn(f"Negative values for {values}: {np.amin(values)} {np.amax(values)}")

        return accumulated + self.sign * values
125
+
126
+
127
class TodoList:
    """Track which keys of a fixed collection are still pending and which are done."""

    def __init__(self, keys):
        """Register ``keys`` as pending.

        Parameters
        ----------
        keys : sized iterable of hashable
            The keys to track. Duplicates are rejected by the assertion below.
        """
        self._todo = set(keys)
        self._len = len(keys)
        self._done = set()
        assert self._len == len(self._todo), (self._len, len(self._todo))

    def is_todo(self, key):
        """Return True when ``key`` is still pending."""
        return key in self._todo

    def is_done(self, key):
        """Return True when ``key`` has been marked done."""
        return key in self._done

    def set_done(self, key):
        """Move ``key`` from pending to done.

        Raises
        ------
        KeyError
            If ``key`` is not pending. The removal happens first so that a
            failing call leaves both sets untouched.
        """
        self._todo.remove(key)
        self._done.add(key)

    def all_done(self):
        """Return True when nothing is pending any more."""
        if not self._todo:
            assert len(self._done) == self._len, (len(self._done), self._len)
            return True
        return False
149
+
150
+
151
class Periods:
    """The set of periods whose signed sum gives one accumulated output field.

    One ``Periods`` object is created for each accumulated field in the
    output. Subclasses provide ``build_periods``.
    """

    _todo = None

    def __init__(self, valid_date, accumulation_period, **kwargs):
        """Build the periods and verify they cover the requested interval.

        Parameters
        ----------
        valid_date : datetime.datetime
            End of the accumulation interval.
        accumulation_period : datetime.timedelta
            Length of the accumulation interval.
        **kwargs : Any
            Stored on ``self.kwargs``; not otherwise used by this class.
        """
        assert isinstance(valid_date, datetime.datetime), (valid_date, type(valid_date))
        assert isinstance(accumulation_period, datetime.timedelta), (accumulation_period, type(accumulation_period))
        self.valid_date = valid_date
        self.accumulation_period = accumulation_period
        self.kwargs = kwargs

        self._periods = self.build_periods()
        self.check_merged_interval()

    def check_merged_interval(self):
        """Assert that each hour of the interval is covered with a net sign of 1.

        An hourly timeline is built over
        ``[valid_date - accumulation_period, valid_date)``; each period adds
        its ``sign`` to the hours it spans.
        """
        global_start = self.valid_date - self.accumulation_period
        global_end = self.valid_date
        resolution = datetime.timedelta(hours=1)

        timeline = np.arange(
            np.datetime64(global_start, "s"), np.datetime64(global_end, "s"), np.timedelta64(resolution)
        )

        flags = np.zeros_like(timeline, dtype=int)
        for p in self._periods:
            segment = np.where((timeline >= p.start_datetime) & (timeline < p.end_datetime))
            xprint(segment)
            flags[segment] += p.sign
        assert np.all(flags == 1), flags

    def find_matching_period(self, field):
        """Return the single period matching ``field``, or None if there is none.

        Raises
        ------
        ValueError
            If more than one period matches.
        """
        found = [p for p in self._periods if p.is_matching_field(field)]
        if len(found) == 1:
            return found[0]
        if len(found) > 1:
            raise ValueError(f"Found more than one period for {field}")
        return None

    @property
    def todo(self):
        """Lazily created ``TodoList`` keyed by each period's ``time_request``."""
        if self._todo is None:
            self._todo = TodoList([p.time_request for p in self._periods])
        return self._todo

    def is_todo(self, period):
        """Return True when ``period`` has not been processed yet."""
        return self.todo.is_todo(period.time_request)

    def is_done(self, period):
        """Return True when ``period`` has already been processed."""
        return self.todo.is_done(period.time_request)

    def set_done(self, period):
        """Mark ``period`` as processed."""
        self.todo.set_done(period.time_request)

    def all_done(self):
        """Return True when every period has been processed."""
        return self.todo.all_done()

    def __iter__(self):
        return iter(self._periods)

    @abstractmethod
    def build_periods(self):
        """Return the list of periods; must be overridden by subclasses.

        The class has no ABC base, so the decorator alone does not prevent
        instantiation; raising here gives an explicit error instead of a
        ``None`` return that fails later while iterating.
        """
        raise NotImplementedError(f"{type(self).__name__} does not implement build_periods()")
215
+
216
+
217
class EraPeriods(Periods):
    """Periods assembled from one-hour pieces listed by ``available_steps``.

    Subclasses supply ``available_steps(start, end)``, a mapping from a base
    hour to a list of ``[step1, step2]`` pairs.
    """

    def search_periods(self, start, end, debug=False):
        """Return every candidate ``Period`` covering ``start`` to ``end``.

        Raises
        ------
        NotImplementedError
            If ``end - start`` is not exactly one hour.
        """
        if end - start != datetime.timedelta(hours=1):
            raise NotImplementedError("Only 1 hour period is supported")

        candidates = []
        for base_time, steps in self.available_steps(start, end).items():
            for step1, step2 in steps:
                if debug:
                    xprint(f"❌ tring: {base_time=} {step1=} {step2=}")

                # Both ends of the step pair must land on the wanted hours.
                starts_on_hour = (base_time + step1) % 24 == start.hour
                ends_on_hour = (base_time + step2) % 24 == end.hour
                if not (starts_on_hour and ends_on_hour):
                    continue

                base_datetime = start - datetime.timedelta(hours=step1)

                period = Period(start, end, base_datetime)
                candidates.append(period)

                # Sanity checks on the candidate that was just recorded.
                assert base_datetime.hour == base_time, (base_datetime, base_time)

                assert period.start_datetime - period.base_datetime == datetime.timedelta(hours=step1), (
                    period.start_datetime,
                    period.base_datetime,
                    step1,
                )
                assert period.end_datetime - period.base_datetime == datetime.timedelta(hours=step2), (
                    period.end_datetime,
                    period.base_datetime,
                    step2,
                )

        return candidates

    def build_periods(self):
        """Pick one period for each hour of the accumulation interval.

        When several candidates exist, the one with the latest
        ``base_datetime`` is chosen. Every chosen period gets ``sign = 1``.

        Raises
        ------
        ValueError
            If no candidate is found for one of the hours.
        """
        hours = self.accumulation_period.total_seconds() / 3600
        assert int(hours) == hours, f"Only full hours accumulation is supported {hours}"
        hours = int(hours)

        selected = []
        for offset in range(hours):
            start = self.valid_date - datetime.timedelta(hours=offset + 1)
            end = self.valid_date - datetime.timedelta(hours=offset)

            found = self.search_periods(start, end)
            if not found:
                xprint(f"❌❌❌ Cannot find accumulation for {start} {end}")
                # Re-run with tracing so the rejected candidates are printed.
                self.search_periods(start, end, debug=True)
                raise ValueError(f"Cannot find accumulation for {start} {end}")

            found = sorted(found, key=lambda x: x.base_datetime, reverse=True)
            chosen = found[0]

            if len(found) > 1:
                xprint(f" Found more than one period for {start} {end}")
                for candidate in found:
                    xprint(f" {candidate}")
                xprint(f" Chosing {chosen}")

            chosen.sign = 1
            selected.append(chosen)

        return selected
288
+
289
+
290
class EaOperPeriods(EraPeriods):
    """Step table selected for ``("ea", "oper")`` by ``find_accumulator_class``."""

    def available_steps(self, start, end):
        """Return hourly ``[i, i + 1]`` pairs for i in 0..17, for base hours 6 and 18."""
        return {base_hour: [[step, step + 1] for step in range(18)] for base_hour in (6, 18)}
296
+
297
+
298
class L5OperPeriods(EraPeriods):
    """Step table selected for ``("l5", "oper")``; flagged as untested by its author."""

    def available_steps(self, start, end):
        """Return hourly ``[i, i + 1]`` pairs for i in 0..23, for base hour 0."""
        print("❌❌❌ untested")
        step_count = 24  # need to check if 24 is the right value
        return {0: [[step, step + 1] for step in range(step_count)]}
305
+
306
+
307
class EaEndaPeriods(EraPeriods):
    """Step table selected for ``("ea", "enda")``; flagged as untested by its author."""

    def available_steps(self, start, end):
        """Return ``[i, i + 3]`` pairs for i in 0..17, for base hours 6 and 18."""
        print("❌❌❌ untested")
        # NOTE(review): these pairs span 3 hours while EraPeriods.search_periods
        # only accepts 1-hour windows — confirm this table can ever match.
        return {base_hour: [[step, step + 3] for step in range(18)] for base_hour in (6, 18)}
314
+
315
+
316
class RrOperPeriods(Periods):
    """Placeholder selected for ``("rr", "oper")``; not implemented yet."""

    def available_steps(self, start, end):
        """Always raise: this combination still needs a differencing implementation.

        Raises
        ------
        NotImplementedError
            Unconditionally.
        """
        # TODO: the draft table that used to follow the raise (and could never
        # run) mapped base hours 0, 3, ..., 21 to [[0, i] for i in range(24)];
        # the value 24 was itself marked as unconfirmed.
        raise NotImplementedError("need to implement diff")
330
+
331
+
332
class OdEldaPeriods(EraPeriods):
    """Step table selected for ``("od", "elda")``; flagged as untested by its author."""

    def available_steps(self, start, end):
        """Return hourly ``[i, i + 1]`` pairs for i in 0..23, for base hours 6 and 18."""
        print("❌❌❌ untested")
        step_count = 24  # need to check if 24 is the right value
        return {base_hour: [[step, step + 1] for step in range(step_count)] for base_hour in (6, 18)}
340
+
341
+
342
class DiffPeriods(Periods):
    """Marker base class; adds nothing to ``Periods`` in the visible code."""
344
+
345
+
346
class OdOperPeriods(DiffPeriods):
    """Placeholder selected for ``("od", "oper")``; not implemented yet."""

    def available_steps(self, start, end):
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError("need to implement diff and _scda patch")
349
+
350
+
351
class OdEnfoPeriods(DiffPeriods):
    """Placeholder selected for ``("od", "enfo")``; not implemented yet."""

    def available_steps(self, start, end):
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError("need to implement diff")
354
+
355
+
356
def find_accumulator_class(class_: str, stream: str) -> Periods:
    """Look up the ``Periods`` subclass registered for a class/stream pair.

    Parameters
    ----------
    class_ : str
        Value of the request's ``class`` key.
    stream : str
        Value of the request's ``stream`` key.

    Returns
    -------
    Periods
        The matching subclass (the class itself, not an instance).

    Raises
    ------
    KeyError
        If the pair is not in the table.
    """
    period_classes = {
        ("ea", "oper"): EaOperPeriods,  # runs ok
        ("ea", "enda"): EaEndaPeriods,
        ("rr", "oper"): RrOperPeriods,
        ("l5", "oper"): L5OperPeriods,
        ("od", "oper"): OdOperPeriods,
        ("od", "enfo"): OdEnfoPeriods,
        ("od", "elda"): OdEldaPeriods,
    }
    return period_classes[(class_, stream)]
366
+
367
+
368
class Accumulator:
    """Accumulate the fields needed for one output field and write the result.

    Each instance handles one ``valid_date`` and one combination of the
    identifying request keys (``param``, ``number``, ...). Fields are offered
    through ``compute``; once every period has been seen the total is written
    to ``out``.
    """

    values = None

    def __init__(self, period_class, out, valid_date, user_accumulation_period, **kwargs):
        """Set up the accumulator and build its periods.

        Parameters
        ----------
        period_class : type
            Called as ``period_class(valid_date, user_accumulation_period, **kwargs)``.
        out : Any
            Output object; its ``write`` method is called by ``self.write``.
        valid_date : datetime.datetime
            End of the accumulation interval.
        user_accumulation_period : datetime.timedelta
            Length of the accumulation interval.
        **kwargs : Any
            Request keys other than ``date``, ``time`` and ``step``.

        Raises
        ------
        ValueError
            If ``date``, ``time`` or ``step`` is present in ``kwargs``.
        """
        self.valid_date = valid_date

        # keep the reference to the output file to be able to write the result using an input field as template
        self.out = out

        # key contains the mars request parameters except the one related to the time
        # A mars request is a dictionary with three categories of keys:
        #   - the ones related to the time (date, time, step)
        #   - the ones related to the data (param, stream, levtype, expver, number, ...)
        #   - the ones related to the processing to be done (grid, area, ...)
        self.kwargs = kwargs
        for k in ["date", "time", "step"]:
            if k in kwargs:
                raise ValueError(f"Cannot use {k} in kwargs for accumulations")

        self.key = {k: v for k, v in kwargs.items() if k in ["param", "level", "levelist", "number"]}

        self.periods = period_class(self.valid_date, user_accumulation_period, **kwargs)

    @property
    def requests(self):
        """Yield one request dict per period: ``kwargs`` merged with the period's time keys."""
        for period in self.periods:
            # build the full data requests, merging the time requests with the key
            yield {**self.kwargs.copy(), **dict(period.time_request)}

    def is_field_needed(self, field):
        """Return True when the field's metadata matches every entry of ``self.key``."""
        for k, v in self.key.items():
            if field.metadata(k) != v:
                LOG.debug(f"{self} does not need field {field} because of {k}={field.metadata(k)} not {v}")
                return False
        return True

    def compute(self, field, values):
        """Add ``values`` to the running total if ``field`` belongs to this accumulator.

        Fields that do not match ``self.key`` or any period are ignored. When
        the last pending period has been processed, the result is written
        using ``field`` as template.
        """
        if not self.is_field_needed(field):
            return

        period = self.periods.find_matching_period(field)
        if not period:
            return
        assert self.periods.is_todo(period), (self.periods, period)
        assert not self.periods.is_done(period), f"Field {field} for period {period} already done"

        period.check(field)

        xprint(f"{self} field ✅ ({period.sign}){field} for {period}")

        self.values = period.apply(self.values, values)
        self.periods.set_done(period)

        if self.periods.all_done():
            self.write(field)
            xprint("accumulator", self, " : data written ✅ ")

    def check(self, field: Any) -> None:
        """Compare a field's ``mars`` metadata namespace with the first one seen.

        NOTE(review): this method reads ``self._check``, ``self.param``,
        ``self.date``, ``self.time``, ``self.step`` and ``self.number``, none
        of which are set by ``__init__`` above, and nothing in this class
        calls it — confirm whether it is still needed.
        """
        if self._check is None:
            self._check = field.metadata(namespace="mars")

            assert self.param == field.metadata("param"), (self.param, field.metadata("param"))
            assert self.date == field.metadata("date"), (self.date, field.metadata("date"))
            assert self.time == field.metadata("time"), (self.time, field.metadata("time"))
            assert self.step == field.metadata("step"), (self.step, field.metadata("step"))
            assert self.number == _member(field), (self.number, _member(field))
            return

        mars = field.metadata(namespace="mars")
        keys1 = sorted(self._check.keys())
        keys2 = sorted(mars.keys())

        assert keys1 == keys2, (keys1, keys2)

        for k in keys1:
            if k not in ("step",):
                assert self._check[k] == mars[k], (k, self._check[k], mars[k])

    def write(self, template: Any) -> None:
        """Write the accumulated values to ``self.out`` and reset the total.

        The output is stamped with ``startStep=0``, ``endStep`` equal to the
        accumulation length in hours, and a date/time equal to
        ``valid_date - accumulation_period``.

        Parameters
        ----------
        template : Any
            Field passed through to ``out.write`` as ``template``.
        """
        assert self.periods.all_done(), self.periods

        # Warn as soon as any value is negative (not only when all of them are).
        if np.any(self.values < 0):
            LOG.warning(
                f"Negative values when computing accumutation for {self}): min={np.amin(self.values)} max={np.amax(self.values)}"
            )

        startStep = 0
        # True division, so that the whole-hour assertion below can actually fail.
        endStep = self.periods.accumulation_period.total_seconds() / 3600
        assert int(endStep) == endStep, "only full hours accumulation is supported"
        endStep = int(endStep)
        fake_base_date = self.valid_date - self.periods.accumulation_period
        date = int(fake_base_date.strftime("%Y%m%d"))
        time = int(fake_base_date.strftime("%H%M"))

        self.out.write(
            self.values,
            template=template,
            stepType="accum",
            startStep=startStep,
            endStep=endStep,
            date=date,
            time=time,
            check_nans=True,
        )
        self.values = None

    def __repr__(self):
        key = ", ".join(f"{k}={v}" for k, v in self.key.items())
        return f"{self.__class__.__name__}({self.valid_date}, {key})"
477
+
478
+
479
def _compute_accumulations(
    context: Any,
    dates: List[datetime.datetime],
    request: Dict[str, Any],
    user_accumulation_period: datetime.timedelta,
) -> Any:
    """Retrieve the needed fields and accumulate them for every date/param/number.

    Parameters
    ----------
    context : Any
        Passed through to ``mars``.
    dates : List[datetime.datetime]
        Valid dates; one output field is produced per date, param and number.
    request : Dict[str, Any]
        Request keys. Must contain ``param`` and ``class``. ``number``
        defaults to ``[0]``, ``stream`` to ``"oper"``, ``levtype`` to
        ``"sfc"``; a ``type`` of ``"an"`` (also the default) is replaced by
        ``"fc"``. The caller's dict is not modified.
    user_accumulation_period : datetime.timedelta
        Length of the accumulation interval.

    Returns
    -------
    Any
        The data source read back from the temporary GRIB file.

    Raises
    ------
    NotImplementedError
        If ``levtype`` is not ``"sfc"``.
    """
    request = deepcopy(request)

    # Accept scalars for both ``param`` and ``number``.
    param = request.pop("param")
    if not isinstance(param, (list, tuple)):
        param = [param]

    number = request.pop("number", [0])
    if not isinstance(number, (list, tuple)):
        number = [number]

    request["stream"] = request.get("stream", "oper")

    type_ = request.get("type", "an")
    if type_ == "an":
        type_ = "fc"
    request["type"] = type_

    request["levtype"] = request.get("levtype", "sfc")
    if request["levtype"] != "sfc":
        raise NotImplementedError("Only sfc leveltype is supported")

    period_class = find_accumulator_class(request["class"], request["stream"])

    tmp = temp_file()
    path = tmp.path
    out = new_grib_output(path)

    # build one accumulator per output field
    accumulators = [
        Accumulator(
            period_class,
            out,
            valid_date,
            user_accumulation_period=user_accumulation_period,
            param=p,
            number=n,
            **request,
        )
        for valid_date in dates
        for p in param
        for n in number
    ]

    xprint("accumulators", len(accumulators))

    # get all needed data requests (mars)
    requests = []
    for a in accumulators:
        xprint("accumulator", a)
        for r in a.requests:
            xprint("  ", r)
            requests.append(r)

    # get the data (this will pack the requests to avoid duplicates and make a minimal number of requests)
    ds = mars(context, dates, *requests, request_already_using_valid_datetime=True)

    # Offer every field to every accumulator; each one uses only the fields
    # it has requested.
    for field in ds:
        values = field.values  # optimisation
        for a in accumulators:
            a.compute(field, values)

    out.close()

    ds = ekd.from_source("file", path)

    # Exactly one output field per (date, param, number) combination.
    assert len(ds) == len(param) * len(number) * len(dates), (
        len(ds),
        len(param),
        len(number),
        len(dates),
    )

    # keep a reference to the tmp file, or it gets deleted when the function returns
    ds._tmp = tmp

    return ds
567
+
568
+
569
+ def _to_list(x: Union[List[Any], Tuple[Any], Any]) -> List[Any]:
570
+ """Converts the input to a list if it is not already a list or tuple.
571
+
572
+ Parameters
573
+ ----------
574
+ x : Union[List[Any], Tuple[Any], Any]
575
+ Input value.
576
+
577
+ Returns
578
+ -------
579
+ List[Any]
580
+ The input value as a list.
581
+ """
582
+ if isinstance(x, (list, tuple)):
583
+ return x
584
+ return [x]
585
+
586
+
587
+ def _scda(request: Dict[str, Any]) -> Dict[str, Any]:
588
+ """Modifies the request stream based on the time.
589
+
590
+ Parameters
591
+ ----------
592
+ request : Dict[str, Any]
593
+ Request parameters.
594
+
595
+ Returns
596
+ -------
597
+ Dict[str, Any]
598
+ The modified request parameters.
599
+ """
600
+ if request["time"] in (6, 18, 600, 1800):
601
+ request["stream"] = "scda"
602
+ else:
603
+ request["stream"] = "oper"
604
+ return request
605
+
606
+
607
@legacy_source(__file__)
def accumulations(context, dates, **request):
    """Entry point of the source: compute accumulations for ``dates``.

    Parameters
    ----------
    context : Any
        Object providing ``trace``; also passed on to the computation.
    dates : list of datetime.datetime
        Valid dates of the accumulated fields.
    **request : Any
        Request keys. ``param`` is required; ``accumulation_period`` is a
        number of hours and defaults to 6.

    Returns
    -------
    Any
        Whatever ``_compute_accumulations`` returns.
    """
    # Store the normalised value: the computation expects a list or tuple,
    # and the previous call discarded the result of ``_to_list``.
    request["param"] = _to_list(request["param"])

    hours = request.pop("accumulation_period", 6)
    user_accumulation_period = datetime.timedelta(hours=hours)

    context.trace("🌧️", f"accumulations {request} {user_accumulation_period}")

    return _compute_accumulations(
        context,
        dates,
        request,
        user_accumulation_period=user_accumulation_period,
    )
621
+
622
+
623
# Module-level alias for the source entry point.
execute = accumulations

# Manual smoke test: retrieves and accumulates a small request when the module
# is run directly.
if __name__ == "__main__":
    import yaml

    config = yaml.safe_load(
        """
class: ea
expver: '0001'
grid: 20./20.
levtype: sfc
# number: [0, 1]
# stream: enda
param: [cp, tp]
# accumulation_period: 6h
accumulation_period: 2
"""
    )
    dates = yaml.safe_load("[2022-12-31 00:00, 2022-12-31 06:00]")
    # dates = yaml.safe_load("[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]")
    dates = to_datetime_list(dates)

    class Context:
        """Minimal stand-in exposing what the demo run reads from a context."""

        use_grib_paramid = True

        def trace(self, *args):
            print(*args)

    # Pass an instance: with the bare class, ``trace`` would bind its first
    # argument as ``self`` and silently drop it from the output.
    for f in accumulations(Context(), dates, **config):
        print(f, f.to_numpy().mean())
@@ -14,6 +14,8 @@ import os
14
14
  from typing import Any
15
15
  from typing import Callable
16
16
 
17
+ from anemoi.datasets.create.input.template import resolve
18
+
17
19
  from ..source import Source
18
20
  from . import source_registry
19
21
 
@@ -71,12 +73,15 @@ class legacy_source:
71
73
 
72
74
  def execute_wrapper(self, dates) -> Any:
73
75
  """Wrapper method to call the execute function."""
76
+
77
+ args, kwargs = resolve(self.context, (self.args, self.kwargs))
78
+
74
79
  try:
75
- return execute(self.context, dates, *self.args, **self.kwargs)
80
+ return execute(self.context, dates, *args, **kwargs)
76
81
  except TypeError:
77
82
  LOG.error(f"Error executing source {this.name} from {source}")
78
83
  LOG.error(f"Function signature is: {inspect.signature(execute)}")
79
- LOG.error(f"Arguments are: {self.args=}, {self.kwargs=}")
84
+ LOG.error(f"Arguments are: {args=}, {kwargs=}")
80
85
  raise
81
86
 
82
87
  klass = type(
@@ -61,6 +61,7 @@ class LatLonGrid(Grid):
61
61
  super().__init__()
62
62
  self.lat = lat
63
63
  self.lon = lon
64
+ self.variable_dims = variable_dims
64
65
 
65
66
 
66
67
  class XYGrid(Grid):
@@ -86,10 +87,20 @@ class MeshedGrid(LatLonGrid):
86
87
  @cached_property
87
88
  def grid_points(self) -> Tuple[Any, Any]:
88
89
  """Get the grid points for the meshed grid."""
89
- lat, lon = np.meshgrid(
90
- self.lat.variable.values,
91
- self.lon.variable.values,
92
- )
90
+
91
+ if self.variable_dims == (self.lon.variable.name, self.lat.variable.name):
92
+ lat, lon = np.meshgrid(
93
+ self.lat.variable.values,
94
+ self.lon.variable.values,
95
+ )
96
+ elif self.variable_dims == (self.lat.variable.name, self.lon.variable.name):
97
+ lon, lat = np.meshgrid(
98
+ self.lon.variable.values,
99
+ self.lat.variable.values,
100
+ )
101
+
102
+ else:
103
+ raise NotImplementedError(f"MeshedGrid.grid_points: unrecognized variable_dims {self.variable_dims}")
93
104
 
94
105
  return lat.flatten(), lon.flatten()
95
106
 
@@ -310,7 +310,12 @@ class Dataset(ABC, Sized):
310
310
  """
311
311
  requested_frequency = frequency_to_seconds(frequency)
312
312
  dataset_frequency = frequency_to_seconds(self.frequency)
313
- assert requested_frequency % dataset_frequency == 0
313
+
314
+ if requested_frequency % dataset_frequency != 0:
315
+ raise ValueError(
316
+ f"Requested frequency {frequency} is not a multiple of the dataset frequency {self.frequency}. Did you mean to use `interpolate_frequency`?"
317
+ )
318
+
314
319
  # Question: where do we start? first date, or first date that is a multiple of the frequency?
315
320
  step = requested_frequency // dataset_frequency
316
321
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: anemoi-datasets
3
- Version: 0.5.17
3
+ Version: 0.5.19
4
4
  Summary: A package to hold various functions to support training of ML models on ECMWF data.
5
5
  Author-email: "European Centre for Medium-Range Weather Forecasts (ECMWF)" <software.support@ecmwf.int>
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  anemoi/datasets/__init__.py,sha256=i_wsAT3ezEYF7o5dpqGrpoG4wmLS-QIBug18uJbSYMs,1065
2
2
  anemoi/datasets/__main__.py,sha256=ErwAqE3rBc7OaNO2JRsEOhWpB8ldjAt7BFSuRhbnlqQ,936
3
- anemoi/datasets/_version.py,sha256=G3jgIvqAOb3RgYTFryPuF7LF2dSXviBKmqCnS1RQzaU,513
3
+ anemoi/datasets/_version.py,sha256=c9HpQ99YdGTjwGSu9DP-RLCHGpdS8jfmCirVxrPm0i4,513
4
4
  anemoi/datasets/grids.py,sha256=ALvRRMvu0GaDCnNlOO-cRCfbpywA-1w_wzSylqpqgNY,17795
5
5
  anemoi/datasets/testing.py,sha256=fy_JzavUwLlK_2rtXAT-UGUyo5gjyQW2y826zf334Wg,2645
6
6
  anemoi/datasets/commands/__init__.py,sha256=O5W3yHZywRoAqmRUioAr3zMCh0hGVV18wZYGvc00ioM,698
@@ -20,10 +20,10 @@ anemoi/datasets/commands/publish.py,sha256=7YusLCWYdVLuexZzvyh8ztYoBOBzVmve3uJs-
20
20
  anemoi/datasets/commands/scan.py,sha256=e5t_oxSi-II38TVQiMlWMJ8AZhDEBk5PcPD22DDbHfU,4008
21
21
  anemoi/datasets/compute/__init__.py,sha256=hCW0QcLHJmE-C1r38P27_ZOvCLNewex5iQEtZqx2ckI,393
22
22
  anemoi/datasets/compute/recentre.py,sha256=kwxDB8qpgOCFZSQJvjAmVcpH5zWsfk5FSoIureqNHd4,5915
23
- anemoi/datasets/create/__init__.py,sha256=1Y7IcMPXjCZkNbezhSDiurG-qvGFGgPqvgOMpQH_e0k,50501
23
+ anemoi/datasets/create/__init__.py,sha256=D0a5Q-xv5_mtBTzEO_6IaWHCQUbZyH0NjHwEtixMCXs,50699
24
24
  anemoi/datasets/create/check.py,sha256=FrgyZP3Xyx4qXHl8_ZfM31fgNhcxMqxlE5oLweMDGU0,10003
25
25
  anemoi/datasets/create/chunks.py,sha256=kZV3dWoCuv3Bttc0wysJB7OPbXsD99exKyrrj4HGFwQ,4025
26
- anemoi/datasets/create/config.py,sha256=ZF7tEPT6U4ILYVekryFd612tQeMDQK6riaTYtSJrUcM,13295
26
+ anemoi/datasets/create/config.py,sha256=xrSlaY2p5zssfLIt8A1CP9WwJReSXVWBMQM7bT1aFbU,13448
27
27
  anemoi/datasets/create/filter.py,sha256=Hu4o3Z2omIdcu5ycJqmBkY_ZSKTG5JkjbIuxXM8ADfs,1254
28
28
  anemoi/datasets/create/patch.py,sha256=u4CeIuo3Ncrbhu9CTyaUbcmaJfBfMrrFVpgEikM9pE4,5398
29
29
  anemoi/datasets/create/persistent.py,sha256=XkEBjymXrR-y9KPVLtz9xdd0IB14wSEhcANUhUUzGVw,7832
@@ -49,18 +49,18 @@ anemoi/datasets/create/filters/single_level_relative_humidity_to_specific_humidi
49
49
  anemoi/datasets/create/filters/single_level_specific_humidity_to_relative_humidity.py,sha256=bXgm5nKgBZaP1E4tcjSLqJsEl6BlJaNLr3MsR8V9sJ4,14682
50
50
  anemoi/datasets/create/filters/speeddir_to_uv.py,sha256=8NXsus1LaYOzAAr7XCHKCh8HAz8BI0A1ZZz_RNDB0-w,2762
51
51
  anemoi/datasets/create/filters/sum.py,sha256=aGT6JkdHJ3i2SKzklqiyJ4ZFV3bVMYhHOSoxkdYuzp8,2151
52
- anemoi/datasets/create/filters/transform.py,sha256=C8tizuYtO1Bp28dTB9mEHeADAO8zHlDFXh8XR1IO1Os,1506
52
+ anemoi/datasets/create/filters/transform.py,sha256=gIDLvaJlnn3Nc6P29aPOvNYM6yBWcIGrR2e_1bM6_Nw,1418
53
53
  anemoi/datasets/create/filters/unrotate_winds.py,sha256=3AJf0crnVVySLlXLIdfEUxRRlQeKgheUuD-UCrSrgo8,2798
54
54
  anemoi/datasets/create/filters/uv_to_speeddir.py,sha256=Zdc34AG5Bsz-Z7JGuznyRJr6F-BnWKXPiI3mjmOpbek,2883
55
55
  anemoi/datasets/create/filters/wz_to_w.py,sha256=42AQvTHk-ISyHlwvXfU3yiTGiDkfrs5kiKNkyqqtQpg,2725
56
56
  anemoi/datasets/create/input/__init__.py,sha256=XeURpmbReQvpELltGFKzg3oZFXWRdUxW9SK3K662SBQ,3364
57
- anemoi/datasets/create/input/action.py,sha256=0P1aSutrzdDDtUU78YMLfdsUEOeJcLvLiH2KDR5kOxM,7565
57
+ anemoi/datasets/create/input/action.py,sha256=xXLqVsoygxyaROiXc7TW9DCEOzVh1YgPDAqUpcOb9fs,7619
58
58
  anemoi/datasets/create/input/concat.py,sha256=bU8SWfBVfK8bRAmmN4UO9zpIGxwQvRUk9_vwrKPOTE4,5355
59
59
  anemoi/datasets/create/input/context.py,sha256=qrLccxMe9UkyQxsNuR6JSK7oLzZq21dt38AxZ9kYzsY,2714
60
60
  anemoi/datasets/create/input/data_sources.py,sha256=4xUUShM0pCXIZVPJW_cSNMUwCO_wLx996MLFpTLChm0,4385
61
61
  anemoi/datasets/create/input/empty.py,sha256=tOxe3LykoGkEAFuf4yggMpAcvFzMw3E6hCz5pyeQ8Q0,1534
62
62
  anemoi/datasets/create/input/filter.py,sha256=R19IUwTdWBueeTKAMxyYKiP-JXOFJQu2vUoEiPYK0rA,3313
63
- anemoi/datasets/create/input/function.py,sha256=FJ2W5DJBLmpkQ6QFo0-yfUE9iIZyimBn_cZ1b2nRu-Q,6874
63
+ anemoi/datasets/create/input/function.py,sha256=Q15IVNJqHm_9Pf0pWnDedyJcRoz0fxbKt8d1f2IMqQA,6916
64
64
  anemoi/datasets/create/input/join.py,sha256=RAdgE4lVcC71_J47dNa1weJuWdTXSQIvo06G2J6dfdg,4016
65
65
  anemoi/datasets/create/input/misc.py,sha256=FVaH_ym52RZI_fnLSMM_dKTQmWTrInucP780E3gGqvw,3357
66
66
  anemoi/datasets/create/input/pipe.py,sha256=-tCz161IwXoI8pl1hilA9T_j5eHSr-sgbijFLp9HHNc,2083
@@ -70,14 +70,15 @@ anemoi/datasets/create/input/step.py,sha256=WcR9NgRvUKF60Fo5veLvRCAQMrOd55x1gOEA
70
70
  anemoi/datasets/create/input/template.py,sha256=Iycw9VmfA0WEIDP_Of8bp-8HsV0EUfwbnm0WjxiO4GA,4092
71
71
  anemoi/datasets/create/input/trace.py,sha256=dakPYMmwKq6s17Scww1CN-xYBD3btJTGeDknOhAcnEM,3320
72
72
  anemoi/datasets/create/sources/__init__.py,sha256=XNiiGaC6NbxnGfl6glPw-gTJASi3vsGKwVlfkMqYGk4,950
73
- anemoi/datasets/create/sources/accumulations.py,sha256=Fh4LJi7XptsOZ9CBv4Nxw8CPJpp_-ugRAWg3mtNcmKU,19855
73
+ anemoi/datasets/create/sources/accumulations.py,sha256=ZA8F8RJPMHok5RpIHH4x-txwiSll8zuWwqJ3rn95JHk,20295
74
+ anemoi/datasets/create/sources/accumulations2.py,sha256=iBORRrH0N7r3gMWm3mCkJ6XmB-dO_lEckHPwvmk9fu0,20673
74
75
  anemoi/datasets/create/sources/constants.py,sha256=5O6d9tEuAmVjl5vNkNfmkaAjKXFlw1UjeueTsF1GZCI,1528
75
76
  anemoi/datasets/create/sources/eccc_fstd.py,sha256=8HK38f444HcWMvBhooP0XqTfMXYoCbN_8G9RI_Ne5rc,659
76
77
  anemoi/datasets/create/sources/empty.py,sha256=5mVIVRUwnBfE3zp-bvNA_imXCSpyds-4mewcI8HXAiY,1020
77
78
  anemoi/datasets/create/sources/forcings.py,sha256=877OZoXUoJncQ2_AAGSijwWqM-4kJJdxdIa6SFvZBUw,1216
78
79
  anemoi/datasets/create/sources/grib.py,sha256=zFBFWNFDVPCMSDRheNuaLZ7EaInjDt9OTJwVOPj9j-U,8371
79
80
  anemoi/datasets/create/sources/hindcasts.py,sha256=_4880rgd4AsRxlDXVi6dkh8mlKXrz2i27btVlmlMFjY,2611
80
- anemoi/datasets/create/sources/legacy.py,sha256=O6sTbI4QBlUiuGwaUwO2kpmfJYCAs_gTid0YOnkm37I,2536
81
+ anemoi/datasets/create/sources/legacy.py,sha256=RJce-9TwmUUCFbgC8A3Dp61nSBfB8_lWti8WNoOMIcU,2652
81
82
  anemoi/datasets/create/sources/mars.py,sha256=tesQz7Ne6SLBChE_cNJU6Sxr6e0LXFlUKQ8gCdRiCMw,13155
82
83
  anemoi/datasets/create/sources/netcdf.py,sha256=UnehMwEMJquqaOeU33zNyFUYfzqQx4Rg-GRmUcgMcbE,1222
83
84
  anemoi/datasets/create/sources/opendap.py,sha256=sTm0wXE_BHk9q8vaNNE_Y6BhTOmhxPweS8RTjP4HYjU,1254
@@ -95,7 +96,7 @@ anemoi/datasets/create/sources/xarray_support/coordinates.py,sha256=rPEuijS77mQ9
95
96
  anemoi/datasets/create/sources/xarray_support/field.py,sha256=YRxx6kh1qO2qQ6I_VyR51h3dwNiiFM7CNwQNfpp-p-E,6375
96
97
  anemoi/datasets/create/sources/xarray_support/fieldlist.py,sha256=CG8ecTXCr37pNiykoxR6Sb21Xxsz6AS5K5-KE4qMEmo,8243
97
98
  anemoi/datasets/create/sources/xarray_support/flavour.py,sha256=GYodfpKfTBBWiyXytRrin6NK07ltlyz0UF7x4gQ3Fok,31836
98
- anemoi/datasets/create/sources/xarray_support/grid.py,sha256=P-NPDYU0eZg_mWcEbeNL9ZhtoJHGNw0eWaC1jxYfK5o,5690
99
+ anemoi/datasets/create/sources/xarray_support/grid.py,sha256=lsE8bQwBH9pflzvsJ89Z6ExYPdHJd54xorMNzL2gTd0,6181
99
100
  anemoi/datasets/create/sources/xarray_support/metadata.py,sha256=WRO86l-ZB7iJ7pG5Vz9kVv5h1MokfF0fuy0bNSNBRIc,10687
100
101
  anemoi/datasets/create/sources/xarray_support/patch.py,sha256=Snk8bz7gp0HrG0MrY5hrXu7VC0tKgtoiWXByi2sBYJc,2037
101
102
  anemoi/datasets/create/sources/xarray_support/time.py,sha256=Y_lZTUOXWJH4jcSgyL4WTDwrtPXi7MUiumaXfRoqqAY,12486
@@ -105,7 +106,7 @@ anemoi/datasets/create/statistics/summary.py,sha256=JdtChTmsr1Y958_nka36HltTbeZk
105
106
  anemoi/datasets/data/__init__.py,sha256=dLzKYFX0eCi7urHA9t530SqZ_GYxTUyQeEcXYV8lZho,2521
106
107
  anemoi/datasets/data/complement.py,sha256=C55ZyWO8uM-bGbZkpuh80z95XtQjIr_NBnsxiKDWWtE,9643
107
108
  anemoi/datasets/data/concat.py,sha256=eY5rujcdal00BJCv00mKSlxp0FKVvPQd7uqrBnL9fj4,8996
108
- anemoi/datasets/data/dataset.py,sha256=Z1P1bkscPChGNcjjkxonbw9XylixJoM0UIUjqDDvxl8,30494
109
+ anemoi/datasets/data/dataset.py,sha256=Dz74L_RihBzHJyHqlCKcXHBa0J_PkW3YYFofhv-Rh-4,30694
109
110
  anemoi/datasets/data/debug.css,sha256=z2X_ZDSnZ9C3pyZPWnQiEyAxuMxUaxJxET4oaCImTAQ,211
110
111
  anemoi/datasets/data/debug.py,sha256=hVa1jAQ-TK7CoKJNyyUC0eZPobFG-FpkVXEaO_3B-MA,10796
111
112
  anemoi/datasets/data/ensemble.py,sha256=-36kMjuT2y5jUeSnjCRTCyE4um6DLAADBVSKSTkHZZg,5352
@@ -129,9 +130,9 @@ anemoi/datasets/data/xy.py,sha256=-jWzYismrK3eI3YCKIBpU1BCmraRncmVn0_2IUY--lk,75
129
130
  anemoi/datasets/dates/__init__.py,sha256=pEArHDQ7w5E0WC8Vvf9ypyKSdm6gnhoN9TmooITB7C4,13617
130
131
  anemoi/datasets/dates/groups.py,sha256=IOveL6IyTXZwEdXZEnRAnpu9pINY95VN7LzcpLfJ09E,10105
131
132
  anemoi/datasets/utils/__init__.py,sha256=hCW0QcLHJmE-C1r38P27_ZOvCLNewex5iQEtZqx2ckI,393
132
- anemoi_datasets-0.5.17.dist-info/licenses/LICENSE,sha256=8HznKF1Vi2IvfLsKNE5A2iVyiri3pRjRPvPC9kxs6qk,11354
133
- anemoi_datasets-0.5.17.dist-info/METADATA,sha256=jIFWwwr0VWMKPo-lTpC82vLpvagt5n2HD99Otf7CKW4,15727
134
- anemoi_datasets-0.5.17.dist-info/WHEEL,sha256=L0N565qmK-3nM2eBoMNFszYJ_MTx03_tQ0CQu1bHLYo,91
135
- anemoi_datasets-0.5.17.dist-info/entry_points.txt,sha256=yR-o-4uiPEA_GLBL81SkMYnUoxq3CAV3hHulQiRtGG0,66
136
- anemoi_datasets-0.5.17.dist-info/top_level.txt,sha256=DYn8VPs-fNwr7fNH9XIBqeXIwiYYd2E2k5-dUFFqUz0,7
137
- anemoi_datasets-0.5.17.dist-info/RECORD,,
133
+ anemoi_datasets-0.5.19.dist-info/licenses/LICENSE,sha256=8HznKF1Vi2IvfLsKNE5A2iVyiri3pRjRPvPC9kxs6qk,11354
134
+ anemoi_datasets-0.5.19.dist-info/METADATA,sha256=-iNWmeuYT_FuUpsKNhTMb79nXFJhh7sERHF2fW1XJGM,15727
135
+ anemoi_datasets-0.5.19.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
136
+ anemoi_datasets-0.5.19.dist-info/entry_points.txt,sha256=yR-o-4uiPEA_GLBL81SkMYnUoxq3CAV3hHulQiRtGG0,66
137
+ anemoi_datasets-0.5.19.dist-info/top_level.txt,sha256=DYn8VPs-fNwr7fNH9XIBqeXIwiYYd2E2k5-dUFFqUz0,7
138
+ anemoi_datasets-0.5.19.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.0.1)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5