anemoi-datasets 0.5.28__py3-none-any.whl → 0.5.29__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (28)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/create/__init__.py +4 -12
  3. anemoi/datasets/create/config.py +50 -53
  4. anemoi/datasets/create/input/result/field.py +1 -3
  5. anemoi/datasets/create/sources/accumulate.py +517 -0
  6. anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
  7. anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
  8. anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +149 -0
  9. anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
  10. anemoi/datasets/create/sources/grib_index.py +64 -20
  11. anemoi/datasets/create/sources/mars.py +56 -27
  12. anemoi/datasets/create/sources/xarray_support/__init__.py +1 -0
  13. anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
  14. anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
  15. anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
  16. anemoi/datasets/data/complement.py +26 -17
  17. anemoi/datasets/data/dataset.py +6 -0
  18. anemoi/datasets/data/masked.py +74 -13
  19. anemoi/datasets/data/missing.py +5 -0
  20. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/METADATA +7 -7
  21. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/RECORD +25 -23
  22. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/WHEEL +1 -1
  23. anemoi/datasets/create/sources/accumulations.py +0 -1042
  24. anemoi/datasets/create/sources/accumulations2.py +0 -618
  25. anemoi/datasets/create/sources/tendencies.py +0 -171
  26. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/entry_points.txt +0 -0
  27. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/licenses/LICENSE +0 -0
  28. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/top_level.txt +0 -0

anemoi/datasets/create/sources/grib_index.py
@@ -7,9 +7,12 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.

+import hashlib
+import json
 import logging
 import os
 import sqlite3
+from collections import defaultdict
 from collections.abc import Iterator
 from typing import Any

@@ -46,8 +49,8 @@ class GribIndex:
     ----------
     database : str
         Path to the SQLite database file.
-    keys : Optional[List[str] | str], optional
-        List of keys or a string of keys to use for indexing, by default None.
+    keys : Optional[list[str] | str], optional
+        list of keys or a string of keys to use for indexing, by default None.
     flavour : Optional[str], optional
         Flavour configuration for mapping fields, by default None.
     update : bool, optional
@@ -161,7 +164,7 @@ class GribIndex:

        Returns
        -------
-        List[str]
+        list[str]
            A list of metadata keys stored in the database.
        """
        self.cursor.execute("SELECT key FROM metadata_keys")
@@ -229,7 +232,7 @@ class GribIndex:

        Returns
        -------
-        List[str]
+        list[str]
            A list of column names.
        """
        if self._columns is not None:
@@ -245,8 +248,8 @@

        Parameters
        ----------
-        columns : List[str]
-            List of column names to ensure in the table.
+        columns : list[str]
+            list of column names to ensure in the table.
        """
        assert self.update

@@ -364,7 +367,7 @@

        Returns
        -------
-        List[dict]
+        list[dict]
            A list of GRIB2 parameter information.
        """
        if ("grib2", paramId) in self.cache:
@@ -524,8 +527,8 @@

        Parameters
        ----------
-        dates : List[Any]
-            List of dates to retrieve data for.
+        dates : list[Any]
+            list of dates to retrieve data for.
        **kwargs : Any
            Additional filtering criteria.

@@ -545,6 +548,9 @@
        params = dates

        for k, v in kwargs.items():
+            if k not in self._columns:
+                LOG.warning(f"Warning : {k} not in database columns, key discarded")
+                continue
            if isinstance(v, list):
                query += f" AND {k} IN ({', '.join('?' for _ in v)})"
                params.extend([str(_) for _ in v])
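Note: the filters above are built with bound `?` placeholders rather than interpolated values. A minimal self-contained sketch of the same pattern (the table and column names here are illustrative, not the ones used by grib_index.py):

    import sqlite3

    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE fields (shortName TEXT, level TEXT)")
    con.executemany("INSERT INTO fields VALUES (?, ?)", [("2t", "sfc"), ("z", "500")])

    query = "SELECT * FROM fields WHERE 1=1"
    params = []
    for k, v in {"shortName": ["2t", "z"], "level": "500"}.items():
        if isinstance(v, list):
            # One '?' placeholder per list element; values are bound separately
            query += f" AND {k} IN ({', '.join('?' for _ in v)})"
            params.extend(str(x) for x in v)
        else:
            query += f" AND {k} = ?"
            params.append(str(v))

    print(con.execute(query, params).fetchall())  # [('z', '500')]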
@@ -552,11 +558,14 @@
                query += f" AND {k} = ?"
                params.append(str(v))

-        print("SELECT", query)
-        print("SELECT", params)
+        print("SELECT (query)", query)
+        print("SELECT (params)", params)

        self.cursor.execute(query, params)
-        for path_id, offset, length in self.cursor.fetchall():
+
+        fetch = self.cursor.fetchall()
+
+        for path_id, offset, length in fetch:
            if path_id in self.cache:
                file = self.cache[path_id]
            else:
@@ -570,9 +579,8 @@
            yield data


-@source_registry.register("grib_index")
+@source_registry.register("grib-index")
 class GribIndexSource(LegacySource):
-
     @staticmethod
     def _execute(
         context: Any,
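Note: the registry name changes here from "grib_index" to "grib-index", so recipes referring to the old name must be updated. A hedged sketch of the affected configuration, in dict form (the exact recipe schema is assumed):

    # was: {"grib_index": {"indexdb": "index.db"}}
    source = {"grib-index": {"indexdb": "index.db"}}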
@@ -602,15 +610,51 @@ class GribIndexSource(LegacySource):
            An array of retrieved GRIB fields.
        """
        index = GribIndex(indexdb)
-        result = []

        if flavour is not None:
            flavour = RuleBasedFlavour(flavour)

-        for grib in index.retrieve(dates, **kwargs):
-            field = ekd.from_source("memory", grib)[0]
-            if flavour:
-                field = flavour.apply(field)
-            result.append(field)
+        if hasattr(dates, "date_to_intervals"):
+            # When using accumulate source
+            full_requests = []
+            for d, interval in dates.intervals:
+                context.trace("🌧️", "interval:", interval)
+                valid_date, request, _ = dates._adjust_request_to_interval(interval, kwargs)
+                context.trace("🌧️", " request =", request)
+                full_requests.append(([valid_date], request))
+        else:
+            # Normal case, without accumulate source
+            full_requests = [(dates, kwargs)]
+
+        full_requests = factorise(full_requests)
+        context.trace("🌧️", f"number of (factorised) requests: {len(full_requests)}")
+        for valid_dates, request in full_requests:
+            context.trace("🌧️", f" dates: {valid_dates}, request: {request}")
+
+        result = []
+        for valid_dates, request in full_requests:
+            for grib in index.retrieve(valid_dates, **request):
+                field = ekd.from_source("memory", grib)[0]
+                if flavour:
+                    field = flavour.apply(field)
+                result.append(field)

        return FieldArray(result)
+
+
+def factorise(lst):
+    """Factorise a list of (dates, request) tuples by merging dates with identical requests."""
+    content = dict()
+
+    d = defaultdict(list)
+    for dates, request in lst:
+        assert isinstance(request, dict), type(request)
+        key = hashlib.md5(json.dumps(request, sort_keys=True).encode()).hexdigest()
+        content[key] = request
+        d[key] += dates
+
+    res = []
+    for key, dates in d.items():
+        dates = list(sorted(set(dates)))
+        res.append((dates, content[key]))
+    return res
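Note: a quick illustration of what the new factorise() helper does, with the function body reproduced so the snippet is self-contained (dates are shown as strings for brevity; the caller passes datetimes):

    import hashlib
    import json
    from collections import defaultdict

    def factorise(lst):
        """Merge the dates of (dates, request) tuples whose requests are identical."""
        content = dict()
        d = defaultdict(list)
        for dates, request in lst:
            # Identical requests hash to the same key, regardless of key order
            key = hashlib.md5(json.dumps(request, sort_keys=True).encode()).hexdigest()
            content[key] = request
            d[key] += dates
        return [(sorted(set(dates)), content[key]) for key, dates in d.items()]

    merged = factorise(
        [
            (["2020-01-01"], {"param": "tp", "step": 6}),
            (["2020-01-02"], {"param": "tp", "step": 6}),   # same request, extra date
            (["2020-01-01"], {"param": "tp", "step": 12}),  # different request
        ]
    )
    print(merged)
    # [(['2020-01-01', '2020-01-02'], {'param': 'tp', 'step': 6}),
    #  (['2020-01-01'], {'param': 'tp', 'step': 12})]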

anemoi/datasets/create/sources/mars.py
@@ -17,6 +17,7 @@ from earthkit.data import from_source
 from earthkit.data.utils.availability import Availability

 from anemoi.datasets.create.sources import source_registry
+from anemoi.datasets.create.sources.accumulate import IntervalsDatesProvider

 from .legacy import LegacySource

@@ -145,7 +146,7 @@ def _expand_mars_request(

    Parameters
    ----------
-    request : Dict[str, Any]
+    request : dict[str, Any]
        The input MARS request.
    date : datetime.datetime
        The date to be used in the request.
@@ -156,7 +157,7 @@ def _expand_mars_request(

    Returns
    -------
-    List[Dict[str, Any]]
+    List[dict[str, Any]]
        A list of expanded MARS requests.
    """
    requests = []
@@ -164,23 +165,26 @@
    user_step = to_list(expand_to_by(request.get("step", [0])))
    user_time = None
    user_date = None
-
    if not request_already_using_valid_datetime:
-        user_time = request.get("time")
+        user_time = request.get("user_time")
        if user_time is not None:
            user_time = to_list(user_time)
            user_time = [_normalise_time(t) for t in user_time]

        user_date = request.get(date_key)
        if user_date is not None:
-            assert isinstance(user_date, str), user_date
+            if isinstance(user_date, int):
+                user_date = str(user_date)
+            elif isinstance(user_date, datetime.datetime):
+                user_date = user_date.strftime("%Y%m%d")
+            else:
+                raise ValueError(f"Invalid type for {user_date}")
            user_date = re.compile("^{}$".format(user_date.replace("-", "").replace("?", ".")))

    for step in user_step:
        r = request.copy()

        if not request_already_using_valid_datetime:
-
            if isinstance(step, str) and "-" in step:
                assert step.count("-") == 1, step

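Note: with this change a string date now raises ValueError on this path; only int and datetime.datetime values are normalised before being turned into a match pattern. A minimal sketch of the new branch (the helper name is illustrative):

    import datetime
    import re

    def normalise_user_date(user_date):
        # Mirrors the branch above: int -> str, datetime -> YYYYMMDD, else reject
        if isinstance(user_date, int):
            user_date = str(user_date)
        elif isinstance(user_date, datetime.datetime):
            user_date = user_date.strftime("%Y%m%d")
        else:
            raise ValueError(f"Invalid type for {user_date}")
        # '-' is dropped and '?' becomes a single-character wildcard
        return re.compile("^{}$".format(user_date.replace("-", "").replace("?", ".")))

    print(normalise_user_date(20200103).match("20200103") is not None)  # True
    print(normalise_user_date(datetime.datetime(2020, 1, 3)).match("20200104") is not None)  # False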
@@ -190,30 +194,27 @@ def _expand_mars_request(
                base = date - datetime.timedelta(hours=hours)
                r.update(
                    {
-                        date_key: base.strftime("%Y%m%d"),
+                        "date": base.strftime("%Y%m%d"),
                        "time": base.strftime("%H%M"),
                        "step": step,
                    }
                )
-
        for pproc in ("grid", "rotation", "frame", "area", "bitmap", "resol"):
            if pproc in r:
                if isinstance(r[pproc], (list, tuple)):
                    r[pproc] = "/".join(str(x) for x in r[pproc])

        if user_date is not None:
-            if not user_date.match(r[date_key]):
+            if not user_date.match(r["date"]):
                continue

        if user_time is not None:
-            # It time is provided by the user, we only keep the requests that match the time
+            # If time is provided by the user, we only keep the requests that match the time
            if r["time"] not in user_time:
                continue

        requests.append(r)

-    # assert requests, requests
-
    return requests


@@ -222,6 +223,7 @@ def factorise_requests(
    *requests: dict[str, Any],
    request_already_using_valid_datetime: bool = False,
    date_key: str = "date",
+    no_date_here: bool = False,
) -> Generator[dict[str, Any], None, None]:
    """Factorizes the requests based on the given dates.

@@ -229,33 +231,42 @@ def factorise_requests(
    ----------
    dates : List[datetime.datetime]
        The list of dates to be used in the requests.
-    requests : Dict[str, Any]
+    requests : List[dict[str, Any]]
        The input requests to be factorized.
    request_already_using_valid_datetime : bool, optional
        Flag indicating if the requests already use valid datetime.
    date_key : str, optional
        The key for the date in the requests.
+    no_date_here : bool, optional
+        Flag indicating if there is no date in the "dates" list.

    Returns
    -------
-    Generator[Dict[str, Any], None, None]
+    Generator[dict[str, Any], None, None]
        Factorized requests.
    """
-    updates = []
-    for req in requests:
-        # req = normalise_request(req)
+    if isinstance(requests, tuple) and len(requests) == 1 and "requests" in requests[0]:
+        requests = requests[0]["requests"]

-        for d in dates:
-            updates += _expand_mars_request(
+    updates = []
+    for d in sorted(dates):
+        for req in requests:
+            if not no_date_here and (
+                ("date" in req)
+                and ("time" in req)
+                and d.strftime("%Y%m%d%H%M") != (str(req["date"]) + str(req["time"]).zfill(4))
+            ):
+                continue
+            new_req = _expand_mars_request(
                req,
                date=d,
                request_already_using_valid_datetime=request_already_using_valid_datetime,
-                date_key=date_key,
+                date_key="user_date",
            )
+            updates += new_req

    if not updates:
        return
-
    compressed = Availability(updates)
    for r in compressed.iterate():
        for k, v in r.items():
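Note: the new pre-filter keeps a request only when its embedded date and time match the valid datetime currently being expanded. The comparison key, in isolation:

    import datetime

    d = datetime.datetime(2020, 1, 3, 6, 0)
    req = {"date": 20200103, "time": 600}  # times such as 600 are zero-padded to "0600"

    key_from_date = d.strftime("%Y%m%d%H%M")                         # "202001030600"
    key_from_request = str(req["date"]) + str(req["time"]).zfill(4)  # "202001030600"
    print(key_from_date == key_from_request)  # True: the request is kept for this date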
@@ -269,12 +280,12 @@ def use_grib_paramid(r: dict[str, Any]) -> dict[str, Any]:

    Parameters
    ----------
-    r : Dict[str, Any]
+    r : dict[str, Any]
        The input request containing parameter short names.

    Returns
    -------
-    Dict[str, Any]
+    dict[str, Any]
        The request with parameter IDs.
    """
    from anemoi.utils.grib import shortname_to_paramid
@@ -379,7 +390,7 @@ class MarsSource(LegacySource):
            The context for the requests.
        dates : List[datetime.datetime]
            The list of dates to be used in the requests.
-        requests : Dict[str, Any]
+        requests : dict[str, Any]
            The input requests to be executed.
        request_already_using_valid_datetime : bool, optional
            Flag indicating if the requests already use valid datetime.
@@ -395,7 +406,6 @@
        Any
            The resulting dataset.
        """
-
        if not requests:
            requests = [kwargs]

@@ -418,7 +428,26 @@
                "'param' cannot be 'True'. If you wrote 'param: on' in yaml, you may want to use quotes?"
            )

-        if len(dates) == 0:  # When using `repeated_dates`
+        if isinstance(dates, IntervalsDatesProvider):
+            # When using accumulate source
+            requests_ = []
+            for request in requests:
+                for d, interval in dates.intervals:
+                    context.trace("🌧️", "interval:", interval)
+                    _, r, _ = dates._adjust_request_to_interval(interval, request)
+                    context.trace("🌧️", " adjusted request =", r)
+                    requests_.append(r)
+            requests = requests_
+            context.trace("🌧️", f"Total requests: {len(requests)}")
+            requests = factorise_requests(
+                ["no_date_here"],
+                *requests,
+                request_already_using_valid_datetime=True,
+                date_key=date_key,
+                no_date_here=True,
+            )
+
+        elif len(dates) == 0:  # When using `repeated_dates`
            assert len(requests) == 1, requests
            assert "date" in requests[0], requests[0]
            if isinstance(requests[0]["date"], datetime.date):
@@ -434,7 +463,7 @@
        requests = list(requests)

        ds = from_source("empty")
-        context.trace("✅", f"{[str(d) for d in dates]}")
+        context.trace("✅", f"{[str(d) for d in dates]}, {len(dates)}")
        context.trace("✅", f"Will run {len(requests)} requests")
        for r in requests:
            r = {k: v for k, v in r.items() if v != ("-",)}

anemoi/datasets/create/sources/xarray_support/__init__.py
@@ -97,6 +97,7 @@ def load_one(
    if isinstance(dataset, xr.Dataset):
        data = dataset
    else:
+        print(f"Opening dataset {dataset} with options {options}")
        data = xr.open_dataset(dataset, **options)

    fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)

anemoi/datasets/create/sources/xarray_support/coordinates.py
@@ -223,13 +223,10 @@ class Coordinate:
        # Assume the array is sorted

        index = np.searchsorted(values, value)
-        index = index[index < len(values)]
-
-        if np.all(values[index] == value):
+        if np.all(index < len(values)) and np.all(values[index] == value):
            return index

        # If not found, we need to check if the value is in the array
-
        index = np.where(np.isin(values, value))[0]

        # We could also return incomplete matches
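Note: np.searchsorted returns len(values) for a value beyond the last element, so the bounds check has to happen before values[index] is evaluated; Python's short-circuiting `and` guarantees that. A short demonstration of the case the rewrite guards against:

    import numpy as np

    values = np.array([10, 20, 30])
    index = np.searchsorted(values, 40)
    print(index)  # 3, i.e. one past the end: values[index] would raise IndexError

    # The first operand is False, so the second (the unsafe indexing) never runs
    if np.all(index < len(values)) and np.all(values[index] == 40):
        print("exact match at", index)
    else:
        print("not found by searchsorted; fall back to np.isin")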

anemoi/datasets/create/sources/xarray_support/flavour.py
@@ -557,10 +557,10 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
        super().__init__(ds)

    def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> PointCoordinate | None:
-        if attributes.standard_name in ["cell", "station", "poi", "point"]:
+        if attributes.standard_name in ["location", "cell", "id", "station", "poi", "point"]:
            return PointCoordinate(c)

-        if attributes.name in ["cell", "station", "poi", "point"]:  # WeatherBench
+        if attributes.name in ["location", "cell", "id", "station", "poi", "point"]:  # WeatherBench
            return PointCoordinate(c)

        return None
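Note: with "location" and "id" added to the recognised names, observation-style datasets keyed by station are now guessed as point coordinates. A hypothetical xarray dataset of that shape:

    import numpy as np
    import xarray as xr

    # One value per station; the station dimension is named "location",
    # which the guesser now maps to a PointCoordinate.
    ds = xr.Dataset(
        {"t2m": (("time", "location"), np.random.rand(2, 3))},
        coords={
            "time": np.array(["2020-01-01", "2020-01-02"], dtype="datetime64[ns]"),
            "location": ["st1", "st2", "st3"],
            "latitude": ("location", [50.0, 51.0, 52.0]),
            "longitude": ("location", [4.0, 5.0, 6.0]),
        },
    )
    print(ds)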

anemoi/datasets/create/sources/xarray_support/patch.py
@@ -10,13 +10,14 @@

 import logging
 from typing import Any
+from typing import Literal

 import xarray as xr

 LOG = logging.getLogger(__name__)


-def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> Any:
+def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> xr.Dataset:
    """Patch the attributes of the dataset.

    Parameters
@@ -38,7 +39,7 @@ def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> Any:
    return ds


-def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> Any:
+def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> xr.Dataset:
    """Patch the coordinates of the dataset.

    Parameters
@@ -59,7 +60,7 @@ def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> Any:
    return ds


-def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
+def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> xr.Dataset:
    """Rename variables in the dataset.

    Parameters
@@ -77,7 +78,7 @@ def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
    return ds.rename(renames)


-def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> Any:
+def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> xr.Dataset:
    """Sort the coordinates of the dataset.

    Parameters
@@ -98,11 +99,175 @@ def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> Any:
    return ds


+def patch_subset_dataset(ds: xr.Dataset, selection: dict[str, Any]) -> xr.Dataset:
+    """Select a subset of the dataset using xarray's sel method.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    selection : dict[str, Any]
+        Dictionary mapping dimension names to selection criteria.
+        Keys must be existing dimension names in the dataset.
+        Values can be any type accepted by xarray's sel method, including:
+        - Single values (int, float, str, datetime)
+        - Lists or arrays of values
+        - Slices (using slice() objects)
+        - Boolean arrays
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset containing only the selected subset.
+
+    Examples
+    --------
+    >>> # Select specific time and pressure level
+    >>> patch_subset_dataset(ds, {
+    ...     'time': '2020-01-01',
+    ...     'pressure': 500
+    ... })
+
+    >>> # Select a range using slice
+    >>> patch_subset_dataset(ds, {
+    ...     'lat': slice(-90, 90),
+    ...     'lon': slice(0, 180)
+    ... })
+    """
+
+    ds = ds.sel(selection)
+
+    return ds
+
+
+def patch_analysis_lead_to_valid_time(
+    ds: xr.Dataset,
+    time_coord_names: dict[Literal["analysis_time_coordinate", "lead_time_coordinate", "valid_time_coordinate"], str],
+) -> xr.Dataset:
+    """Convert analysis time and lead time coordinates to valid time.
+
+    This function creates a new valid time coordinate by adding the analysis time
+    and lead time coordinates, then stacks and reorganizes the dataset to use
+    valid time as the primary time dimension.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    time_coord_names : dict[str, str]
+        Dictionary mapping required keys to coordinate names in the dataset:
+
+        - 'analysis_time_coordinate' : str
+            Name of the analysis/initialization time coordinate.
+        - 'lead_time_coordinate' : str
+            Name of the forecast lead time coordinate.
+        - 'valid_time_coordinate' : str
+            Name for the new valid time coordinate to create.
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset with valid time as the primary time coordinate.
+        The analysis and lead time coordinates are removed.
+
+    Examples
+    --------
+    >>> patch_analysis_lead_to_valid_time(ds, {
+    ...     'analysis_time_coordinate': 'forecast_reference_time',
+    ...     'lead_time_coordinate': 'step',
+    ...     'valid_time_coordinate': 'time'
+    ... })
+    """
+
+    assert time_coord_names.keys() == {
+        "analysis_time_coordinate",
+        "lead_time_coordinate",
+        "valid_time_coordinate",
+    }, "time_coord_names must contain exactly keys 'analysis_time_coordinate', 'lead_time_coordinate', and 'valid_time_coordinate'"
+
+    analysis_time_coordinate = time_coord_names["analysis_time_coordinate"]
+    lead_time_coordinate = time_coord_names["lead_time_coordinate"]
+    valid_time_coordinate = time_coord_names["valid_time_coordinate"]
+
+    valid_time = ds[analysis_time_coordinate] + ds[lead_time_coordinate]
+
+    ds = (
+        ds.assign_coords({valid_time_coordinate: valid_time})
+        .stack(time_index=[analysis_time_coordinate, lead_time_coordinate])
+        .set_index(time_index=valid_time_coordinate)
+        .rename(time_index=valid_time_coordinate)
+        .drop_vars([analysis_time_coordinate, lead_time_coordinate])
+    )
+
+    return ds
+
+
+def patch_rolling_operation(
+    ds: xr.Dataset, vars_operation_config: dict[Literal["dim", "steps", "vars", "operation"], str | int | list[str]]
+) -> xr.Dataset:
+    """Apply a rolling operation to specified variables in the dataset.
+
+    This function calculates a rolling operation over a specified dimension for selected
+    variables. The rolling window requires all periods to be present (min_periods=steps).
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    vars_operation_config: dict
+        Configuration for the rolling operation with the following keys:
+
+        - 'dim' : str
+            The dimension along which to apply the rolling operation (e.g., 'time').
+        - 'steps' : int
+            The number of steps in the rolling window.
+        - 'vars' : list[str]
+            List of variable names to apply the rolling operation to.
+        - 'operation' : str
+            The operation to apply ('sum', 'mean', 'min', 'max', 'std', etc.).
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset with rolling operations applied to the specified variables.
+
+    Examples
+    --------
+    >>> patch_rolling_operation(ds, {
+    ...     'dim': 'time',
+    ...     'steps': 3,
+    ...     'vars': ['precipitation', 'radiation'],
+    ...     'operation': 'sum'
+    ... })
+    """
+
+    assert vars_operation_config.keys() == {
+        "dim",
+        "steps",
+        "vars",
+        "operation",
+    }, "vars_operation_config must contain exactly keys 'dim', 'steps', 'vars', and 'operation'"
+
+    dim = vars_operation_config["dim"]
+    steps = vars_operation_config["steps"]
+    vars = vars_operation_config["vars"]
+    operation = vars_operation_config["operation"]
+
+    for var in vars:
+        rolling = ds[var].rolling(dim={dim: steps}, min_periods=steps)
+        ds[var] = getattr(rolling, operation)()
+
+    return ds
+
+
 PATCHES = {
     "attributes": patch_attributes,
     "coordinates": patch_coordinates,
     "rename": patch_rename,
     "sort_coordinates": patch_sort_coordinate,
+    "analysis_lead_to_valid_time": patch_analysis_lead_to_valid_time,
+    "rolling_operation": patch_rolling_operation,
+    "subset_dataset": patch_subset_dataset,
 }


@@ -122,7 +287,15 @@ def patch_dataset(ds: xr.Dataset, patch: dict[str, dict[str, Any]]) -> Any:
        The patched dataset.
    """

-    ORDER = ["coordinates", "attributes", "rename", "sort_coordinates"]
+    ORDER = [
+        "coordinates",
+        "attributes",
+        "rename",
+        "sort_coordinates",
+        "subset_dataset",
+        "analysis_lead_to_valid_time",
+        "rolling_operation",
+    ]
    for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])):
        if what not in PATCHES:
            raise ValueError(f"Unknown patch type {what!r}")
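Note: the extended ORDER list runs the new patches after the existing ones and in a fixed sequence: subset first, then the (analysis time, lead time) to valid time flattening, then the rolling operation. A hedged sketch of a patch mapping exercising all three (coordinate and variable names, and the selection values, are hypothetical):

    patch = {
        "subset_dataset": {"step": slice(0, 48)},  # hypothetical lead-time window
        "analysis_lead_to_valid_time": {
            "analysis_time_coordinate": "forecast_reference_time",
            "lead_time_coordinate": "step",
            "valid_time_coordinate": "time",
        },
        "rolling_operation": {
            "dim": "time",
            "steps": 3,
            "vars": ["precipitation"],
            "operation": "sum",
        },
    }
    # patched = patch_dataset(ds, patch)  # ds being an xr.Dataset with these coordinates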