anemoi-datasets 0.5.28__py3-none-any.whl → 0.5.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/create/__init__.py +4 -12
- anemoi/datasets/create/config.py +50 -53
- anemoi/datasets/create/input/result/field.py +1 -3
- anemoi/datasets/create/sources/accumulate.py +517 -0
- anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
- anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
- anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +149 -0
- anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
- anemoi/datasets/create/sources/grib_index.py +64 -20
- anemoi/datasets/create/sources/mars.py +56 -27
- anemoi/datasets/create/sources/xarray_support/__init__.py +1 -0
- anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
- anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
- anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
- anemoi/datasets/data/complement.py +26 -17
- anemoi/datasets/data/dataset.py +6 -0
- anemoi/datasets/data/masked.py +74 -13
- anemoi/datasets/data/missing.py +5 -0
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/METADATA +7 -7
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/RECORD +25 -23
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/WHEEL +1 -1
- anemoi/datasets/create/sources/accumulations.py +0 -1042
- anemoi/datasets/create/sources/accumulations2.py +0 -618
- anemoi/datasets/create/sources/tendencies.py +0 -171
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/licenses/LICENSE +0 -0
- {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.29.dist-info}/top_level.txt +0 -0
--- a/anemoi/datasets/create/sources/grib_index.py
+++ b/anemoi/datasets/create/sources/grib_index.py
@@ -7,9 +7,12 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+import hashlib
+import json
 import logging
 import os
 import sqlite3
+from collections import defaultdict
 from collections.abc import Iterator
 from typing import Any
 
@@ -46,8 +49,8 @@ class GribIndex:
     ----------
     database : str
         Path to the SQLite database file.
-    keys : Optional[
-
+    keys : Optional[list[str] | str], optional
+        list of keys or a string of keys to use for indexing, by default None.
     flavour : Optional[str], optional
         Flavour configuration for mapping fields, by default None.
     update : bool, optional
@@ -161,7 +164,7 @@ class GribIndex:
 
         Returns
         -------
-
+        list[str]
             A list of metadata keys stored in the database.
         """
         self.cursor.execute("SELECT key FROM metadata_keys")
@@ -229,7 +232,7 @@ class GribIndex:
 
         Returns
         -------
-
+        list[str]
             A list of column names.
         """
         if self._columns is not None:
@@ -245,8 +248,8 @@ class GribIndex:
 
         Parameters
         ----------
-        columns :
-
+        columns : list[str]
+            list of column names to ensure in the table.
         """
         assert self.update
 
@@ -364,7 +367,7 @@ class GribIndex:
 
         Returns
         -------
-
+        list[dict]
             A list of GRIB2 parameter information.
         """
         if ("grib2", paramId) in self.cache:
@@ -524,8 +527,8 @@ class GribIndex:
 
         Parameters
         ----------
-        dates :
-
+        dates : list[Any]
+            list of dates to retrieve data for.
         **kwargs : Any
             Additional filtering criteria.
 
@@ -545,6 +548,9 @@ class GribIndex:
         params = dates
 
         for k, v in kwargs.items():
+            if k not in self._columns:
+                LOG.warning(f"Warning : {k} not in database columns, key discarded")
+                continue
             if isinstance(v, list):
                 query += f" AND {k} IN ({', '.join('?' for _ in v)})"
                 params.extend([str(_) for _ in v])
@@ -552,11 +558,14 @@ class GribIndex:
                 query += f" AND {k} = ?"
                 params.append(str(v))
 
-        print("SELECT", query)
-        print("SELECT", params)
+        print("SELECT (query)", query)
+        print("SELECT (params)", params)
 
         self.cursor.execute(query, params)
-
+
+        fetch = self.cursor.fetchall()
+
+        for path_id, offset, length in fetch:
             if path_id in self.cache:
                 file = self.cache[path_id]
             else:
@@ -570,9 +579,8 @@ class GribIndex:
             yield data
 
 
-@source_registry.register("
+@source_registry.register("grib-index")
 class GribIndexSource(LegacySource):
-
     @staticmethod
     def _execute(
         context: Any,
@@ -602,15 +610,51 @@ class GribIndexSource(LegacySource):
             An array of retrieved GRIB fields.
         """
         index = GribIndex(indexdb)
-        result = []
 
         if flavour is not None:
             flavour = RuleBasedFlavour(flavour)
 
-
-
-
-
-
+        if hasattr(dates, "date_to_intervals"):
+            # When using accumulate source
+            full_requests = []
+            for d, interval in dates.intervals:
+                context.trace("🌧️", "interval:", interval)
+                valid_date, request, _ = dates._adjust_request_to_interval(interval, kwargs)
+                context.trace("🌧️", " request =", request)
+                full_requests.append(([valid_date], request))
+        else:
+            # Normal case, without accumulate source
+            full_requests = [(dates, kwargs)]
+
+        full_requests = factorise(full_requests)
+        context.trace("🌧️", f"number of (factorised) requests: {len(full_requests)}")
+        for valid_dates, request in full_requests:
+            context.trace("🌧️", f" dates: {valid_dates}, request: {request}")
+
+        result = []
+        for valid_dates, request in full_requests:
+            for grib in index.retrieve(valid_dates, **request):
+                field = ekd.from_source("memory", grib)[0]
+                if flavour:
+                    field = flavour.apply(field)
+                result.append(field)
 
         return FieldArray(result)
+
+
+def factorise(lst):
+    """Factorise a list of (dates, request) tuples by merging dates with identical requests."""
+    content = dict()
+
+    d = defaultdict(list)
+    for dates, request in lst:
+        assert isinstance(request, dict), type(request)
+        key = hashlib.md5(json.dumps(request, sort_keys=True).encode()).hexdigest()
+        content[key] = request
+        d[key] += dates
+
+    res = []
+    for key, dates in d.items():
+        dates = list(sorted(set(dates)))
+        res.append((dates, content[key]))
+    return res
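
The `grib-index` source now batches its retrievals: per-interval requests produced by the new `accumulate` source (detected through the `date_to_intervals` attribute) are merged by the module-level `factorise` helper added above, which groups `(dates, request)` pairs on an md5 hash of the request. A minimal sketch of that behaviour, assuming the 0.5.29 wheel and its dependencies are installed (the import path follows the file list above):

```python
# Illustration only: `factorise` is the helper added to grib_index.py in this
# release; identical requests are merged and their dates are deduplicated and sorted.
import datetime

from anemoi.datasets.create.sources.grib_index import factorise

d1 = datetime.datetime(2020, 1, 1, 0)
d2 = datetime.datetime(2020, 1, 1, 6)

pairs = [
    ([d1], {"param": "tp", "levtype": "sfc"}),
    ([d2], {"param": "tp", "levtype": "sfc"}),  # same request, later date
    ([d1], {"param": "2t", "levtype": "sfc"}),  # different request
]

for dates, request in factorise(pairs):
    print(dates, request)
# The two "tp" entries collapse into a single entry carrying both dates,
# while the "2t" entry stays separate.
```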
--- a/anemoi/datasets/create/sources/mars.py
+++ b/anemoi/datasets/create/sources/mars.py
@@ -17,6 +17,7 @@ from earthkit.data import from_source
 from earthkit.data.utils.availability import Availability
 
 from anemoi.datasets.create.sources import source_registry
+from anemoi.datasets.create.sources.accumulate import IntervalsDatesProvider
 
 from .legacy import LegacySource
 
@@ -145,7 +146,7 @@ def _expand_mars_request(
 
     Parameters
     ----------
-    request :
+    request : dict[str, Any]
        The input MARS request.
    date : datetime.datetime
        The date to be used in the request.
@@ -156,7 +157,7 @@ def _expand_mars_request(
 
    Returns
    -------
-    List[
+    List[dict[str, Any]]
        A list of expanded MARS requests.
    """
    requests = []
@@ -164,23 +165,26 @@ def _expand_mars_request(
    user_step = to_list(expand_to_by(request.get("step", [0])))
    user_time = None
    user_date = None
-
    if not request_already_using_valid_datetime:
-        user_time = request.get("
+        user_time = request.get("user_time")
        if user_time is not None:
            user_time = to_list(user_time)
            user_time = [_normalise_time(t) for t in user_time]
 
        user_date = request.get(date_key)
        if user_date is not None:
-
+            if isinstance(user_date, int):
+                user_date = str(user_date)
+            elif isinstance(user_date, datetime.datetime):
+                user_date = user_date.strftime("%Y%m%d")
+            else:
+                raise ValueError(f"Invalid type for {user_date}")
            user_date = re.compile("^{}$".format(user_date.replace("-", "").replace("?", ".")))
 
    for step in user_step:
        r = request.copy()
 
        if not request_already_using_valid_datetime:
-
            if isinstance(step, str) and "-" in step:
                assert step.count("-") == 1, step
 
@@ -190,30 +194,27 @@ def _expand_mars_request(
            base = date - datetime.timedelta(hours=hours)
            r.update(
                {
-
+                    "date": base.strftime("%Y%m%d"),
                    "time": base.strftime("%H%M"),
                    "step": step,
                }
            )
-
        for pproc in ("grid", "rotation", "frame", "area", "bitmap", "resol"):
            if pproc in r:
                if isinstance(r[pproc], (list, tuple)):
                    r[pproc] = "/".join(str(x) for x in r[pproc])
 
        if user_date is not None:
-            if not user_date.match(r[
+            if not user_date.match(r["date"]):
                continue
 
        if user_time is not None:
-            #
+            # If time is provided by the user, we only keep the requests that match the time
            if r["time"] not in user_time:
                continue
 
        requests.append(r)
 
-    # assert requests, requests
-
    return requests
 
 
@@ -222,6 +223,7 @@ def factorise_requests(
    *requests: dict[str, Any],
    request_already_using_valid_datetime: bool = False,
    date_key: str = "date",
+    no_date_here: bool = False,
 ) -> Generator[dict[str, Any], None, None]:
    """Factorizes the requests based on the given dates.
 
@@ -229,33 +231,42 @@ def factorise_requests(
    ----------
    dates : List[datetime.datetime]
        The list of dates to be used in the requests.
-    requests :
+    requests : List[dict[str, Any]]
        The input requests to be factorized.
    request_already_using_valid_datetime : bool, optional
        Flag indicating if the requests already use valid datetime.
    date_key : str, optional
        The key for the date in the requests.
+    no_date_here : bool, optional
+        Flag indicating if there is no date in the "dates" list.
 
    Returns
    -------
-    Generator[
+    Generator[dict[str, Any], None, None]
        Factorized requests.
    """
-
-
-    # req = normalise_request(req)
+    if isinstance(requests, tuple) and len(requests) == 1 and "requests" in requests[0]:
+        requests = requests[0]["requests"]
 
-
-
+    updates = []
+    for d in sorted(dates):
+        for req in requests:
+            if not no_date_here and (
+                ("date" in req)
+                and ("time" in req)
+                and d.strftime("%Y%m%d%H%M") != (str(req["date"]) + str(req["time"]).zfill(4))
+            ):
+                continue
+            new_req = _expand_mars_request(
                req,
                date=d,
                request_already_using_valid_datetime=request_already_using_valid_datetime,
-                date_key=
+                date_key="user_date",
            )
+            updates += new_req
 
    if not updates:
        return
-
    compressed = Availability(updates)
    for r in compressed.iterate():
        for k, v in r.items():
@@ -269,12 +280,12 @@ def use_grib_paramid(r: dict[str, Any]) -> dict[str, Any]:
 
    Parameters
    ----------
-    r :
+    r : dict[str, Any]
        The input request containing parameter short names.
 
    Returns
    -------
-
+    dict[str, Any]
        The request with parameter IDs.
    """
    from anemoi.utils.grib import shortname_to_paramid
@@ -379,7 +390,7 @@ class MarsSource(LegacySource):
            The context for the requests.
        dates : List[datetime.datetime]
            The list of dates to be used in the requests.
-        requests :
+        requests : dict[str, Any]
            The input requests to be executed.
        request_already_using_valid_datetime : bool, optional
            Flag indicating if the requests already use valid datetime.
@@ -395,7 +406,6 @@ class MarsSource(LegacySource):
        Any
            The resulting dataset.
        """
-
        if not requests:
            requests = [kwargs]
 
@@ -418,7 +428,26 @@ class MarsSource(LegacySource):
                "'param' cannot be 'True'. If you wrote 'param: on' in yaml, you may want to use quotes?"
            )
 
-        if
+        if isinstance(dates, IntervalsDatesProvider):
+            # When using accumulate source
+            requests_ = []
+            for request in requests:
+                for d, interval in dates.intervals:
+                    context.trace("🌧️", "interval:", interval)
+                    _, r, _ = dates._adjust_request_to_interval(interval, request)
+                    context.trace("🌧️", " adjusted request =", r)
+                    requests_.append(r)
+            requests = requests_
+            context.trace("🌧️", f"Total requests: {len(requests)}")
+            requests = factorise_requests(
+                ["no_date_here"],
+                *requests,
+                request_already_using_valid_datetime=True,
+                date_key=date_key,
+                no_date_here=True,
+            )
+
+        elif len(dates) == 0:  # When using `repeated_dates`
            assert len(requests) == 1, requests
            assert "date" in requests[0], requests[0]
            if isinstance(requests[0]["date"], datetime.date):
@@ -434,7 +463,7 @@ class MarsSource(LegacySource):
        requests = list(requests)
 
        ds = from_source("empty")
-        context.trace("✅", f"{[str(d) for d in dates]}")
+        context.trace("✅", f"{[str(d) for d in dates]}, {len(dates)}")
        context.trace("✅", f"Will run {len(requests)} requests")
        for r in requests:
            r = {k: v for k, v in r.items() if v != ("-",)}
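
`factorise_requests` also gains a `no_date_here` flag and, when dates are supplied, now skips any request whose own `date`/`time` fields do not correspond to the date being expanded. The comparison is plain string formatting; a standalone sketch of the matching rule used above (the request values are illustrative only):

```python
# Standalone illustration of the date/time filter added to factorise_requests:
# a request is only expanded for a date when str(date) + zero-padded str(time)
# equals the date rendered as YYYYMMDDHHMM.
import datetime

req = {"date": "20200101", "time": "600"}  # hypothetical MARS-style request fields
d = datetime.datetime(2020, 1, 1, 6, 0)

key = str(req["date"]) + str(req["time"]).zfill(4)  # "202001010600"
print(d.strftime("%Y%m%d%H%M") == key)  # True -> this request is kept for d
```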
--- a/anemoi/datasets/create/sources/xarray_support/__init__.py
+++ b/anemoi/datasets/create/sources/xarray_support/__init__.py
@@ -97,6 +97,7 @@ def load_one(
     if isinstance(dataset, xr.Dataset):
         data = dataset
     else:
+        print(f"Opening dataset {dataset} with options {options}")
         data = xr.open_dataset(dataset, **options)
 
     fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)
--- a/anemoi/datasets/create/sources/xarray_support/coordinates.py
+++ b/anemoi/datasets/create/sources/xarray_support/coordinates.py
@@ -223,13 +223,10 @@ class Coordinate:
         # Assume the array is sorted
 
         index = np.searchsorted(values, value)
-
-
-        if np.all(values[index] == value):
+        if np.all(index < len(values)) and np.all(values[index] == value):
             return index
 
         # If not found, we need to check if the value is in the array
-
         index = np.where(np.isin(values, value))[0]
 
         # We could also return incomplete matches
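
The change to `Coordinate` guards against `np.searchsorted` returning an out-of-range index, which happens whenever the value sorts after every element of the array:

```python
# Why the new bounds check is needed: searchsorted returns len(values) for a
# value larger than everything in the array, so values[index] would raise
# IndexError before the np.isin fallback was ever reached.
import numpy as np

values = np.array([10, 20, 30])
index = np.searchsorted(values, 40)
print(index)                # 3 == len(values)
print(index < len(values))  # False -> short-circuits the equality test above
```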
--- a/anemoi/datasets/create/sources/xarray_support/flavour.py
+++ b/anemoi/datasets/create/sources/xarray_support/flavour.py
@@ -557,10 +557,10 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
         super().__init__(ds)
 
     def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> PointCoordinate | None:
-        if attributes.standard_name in ["cell", "station", "poi", "point"]:
+        if attributes.standard_name in ["location", "cell", "id", "station", "poi", "point"]:
             return PointCoordinate(c)
 
-        if attributes.name in ["cell", "station", "poi", "point"]:  # WeatherBench
+        if attributes.name in ["location", "cell", "id", "station", "poi", "point"]:  # WeatherBench
             return PointCoordinate(c)
 
         return None
--- a/anemoi/datasets/create/sources/xarray_support/patch.py
+++ b/anemoi/datasets/create/sources/xarray_support/patch.py
@@ -10,13 +10,14 @@
 
 import logging
 from typing import Any
+from typing import Literal
 
 import xarray as xr
 
 LOG = logging.getLogger(__name__)
 
 
-def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) ->
+def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> xr.Dataset:
     """Patch the attributes of the dataset.
 
     Parameters
@@ -38,7 +39,7 @@ def patch_attributes(ds: xr.Dataset, attributes: dict[str, dict[str, Any]]) -> A
     return ds
 
 
-def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) ->
+def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> xr.Dataset:
     """Patch the coordinates of the dataset.
 
     Parameters
@@ -59,7 +60,7 @@ def patch_coordinates(ds: xr.Dataset, coordinates: list[str]) -> Any:
     return ds
 
 
-def patch_rename(ds: xr.Dataset, renames: dict[str, str]) ->
+def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> xr.Dataset:
     """Rename variables in the dataset.
 
     Parameters
@@ -77,7 +78,7 @@ def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any:
     return ds.rename(renames)
 
 
-def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) ->
+def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> xr.Dataset:
     """Sort the coordinates of the dataset.
 
     Parameters
@@ -98,11 +99,175 @@ def patch_sort_coordinate(ds: xr.Dataset, sort_coordinates: list[str]) -> Any:
     return ds
 
 
+def patch_subset_dataset(ds: xr.Dataset, selection: dict[str, Any]) -> xr.Dataset:
+    """Select a subset of the dataset using xarray's sel method.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    selection : dict[str, Any]
+        Dictionary mapping dimension names to selection criteria.
+        Keys must be existing dimension names in the dataset.
+        Values can be any type accepted by xarray's sel method, including:
+        - Single values (int, float, str, datetime)
+        - Lists or arrays of values
+        - Slices (using slice() objects)
+        - Boolean arrays
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset containing only the selected subset.
+
+    Examples
+    --------
+    >>> # Select specific time and pressure level
+    >>> patch_subset_dataset(ds, {
+    ...     'time': '2020-01-01',
+    ...     'pressure': 500
+    ... })
+
+    >>> # Select a range using slice
+    >>> patch_subset_dataset(ds, {
+    ...     'lat': slice(-90, 90),
+    ...     'lon': slice(0, 180)
+    ... })
+    """
+
+    ds = ds.sel(selection)
+
+    return ds
+
+
+def patch_analysis_lead_to_valid_time(
+    ds: xr.Dataset,
+    time_coord_names: dict[Literal["analysis_time_coordinate", "lead_time_coordinate", "valid_time_coordinate"], str],
+) -> xr.Dataset:
+    """Convert analysis time and lead time coordinates to valid time.
+
+    This function creates a new valid time coordinate by adding the analysis time
+    and lead time coordinates, then stacks and reorganizes the dataset to use
+    valid time as the primary time dimension.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    time_coord_names : dict[str, str]
+        Dictionary mapping required keys to coordinate names in the dataset:
+
+        - 'analysis_time_coordinate' : str
+            Name of the analysis/initialization time coordinate.
+        - 'lead_time_coordinate' : str
+            Name of the forecast lead time coordinate.
+        - 'valid_time_coordinate' : str
+            Name for the new valid time coordinate to create.
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset with valid time as the primary time coordinate.
+        The analysis and lead time coordinates are removed.
+
+    Examples
+    --------
+    >>> patch_analysis_lead_to_valid_time(ds, {
+    ...     'analysis_time_coordinate': 'forecast_reference_time',
+    ...     'lead_time_coordinate': 'step',
+    ...     'valid_time_coordinate': 'time'
+    ... })
+    """
+
+    assert time_coord_names.keys() == {
+        "analysis_time_coordinate",
+        "lead_time_coordinate",
+        "valid_time_coordinate",
+    }, "time_coord_names must contain exactly keys 'analysis_time_coordinate', 'lead_time_coordinate', and 'valid_time_coordinate'"
+
+    analysis_time_coordinate = time_coord_names["analysis_time_coordinate"]
+    lead_time_coordinate = time_coord_names["lead_time_coordinate"]
+    valid_time_coordinate = time_coord_names["valid_time_coordinate"]
+
+    valid_time = ds[analysis_time_coordinate] + ds[lead_time_coordinate]
+
+    ds = (
+        ds.assign_coords({valid_time_coordinate: valid_time})
+        .stack(time_index=[analysis_time_coordinate, lead_time_coordinate])
+        .set_index(time_index=valid_time_coordinate)
+        .rename(time_index=valid_time_coordinate)
+        .drop_vars([analysis_time_coordinate, lead_time_coordinate])
+    )
+
+    return ds
+
+
+def patch_rolling_operation(
+    ds: xr.Dataset, vars_operation_config: dict[Literal["dim", "steps", "vars", "operation"], str | int | list[str]]
+) -> xr.Dataset:
+    """Apply a rolling operation to specified variables in the dataset.
+
+    This function calculates a rolling operation over a specified dimension for selected
+    variables. The rolling window requires all periods to be present (min_periods=steps).
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to patch.
+    vars_operation_config: dict
+        Configuration for the rolling operation with the following keys:
+
+        - 'dim' : str
+            The dimension along which to apply the rolling operation (e.g., 'time').
+        - 'steps' : int
+            The number of steps in the rolling window.
+        - 'vars' : list[str]
+            List of variable names to apply the rolling operation to.
+        - 'operation' : str
+            The operation to apply ('sum', 'mean', 'min', 'max', 'std', etc.).
+
+    Returns
+    -------
+    xr.Dataset
+        The patched dataset with rolling operations applied to the specified variables.
+
+    Examples
+    --------
+    >>> patch_rolling_operation(ds, {
+    ...     'dim': 'time',
+    ...     'steps': 3,
+    ...     'vars': ['precipitation', 'radiation'],
+    ...     'operation': 'sum'
+    ... })
+    """
+
+    assert vars_operation_config.keys() == {
+        "dim",
+        "steps",
+        "vars",
+        "operation",
+    }, "vars_operation_config must contain exactly keys 'dim', 'steps', 'vars', and 'operation'"
+
+    dim = vars_operation_config["dim"]
+    steps = vars_operation_config["steps"]
+    vars = vars_operation_config["vars"]
+    operation = vars_operation_config["operation"]
+
+    for var in vars:
+        rolling = ds[var].rolling(dim={dim: steps}, min_periods=steps)
+        ds[var] = getattr(rolling, operation)()
+
+    return ds
+
+
 PATCHES = {
     "attributes": patch_attributes,
     "coordinates": patch_coordinates,
     "rename": patch_rename,
     "sort_coordinates": patch_sort_coordinate,
+    "analysis_lead_to_valid_time": patch_analysis_lead_to_valid_time,
+    "rolling_operation": patch_rolling_operation,
+    "subset_dataset": patch_subset_dataset,
 }
 
 
@@ -122,7 +287,15 @@ def patch_dataset(ds: xr.Dataset, patch: dict[str, dict[str, Any]]) -> Any:
         The patched dataset.
     """
 
-    ORDER = [
+    ORDER = [
+        "coordinates",
+        "attributes",
+        "rename",
+        "sort_coordinates",
+        "subset_dataset",
+        "analysis_lead_to_valid_time",
+        "rolling_operation",
+    ]
     for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0])):
         if what not in PATCHES:
             raise ValueError(f"Unknown patch type {what!r}")