anemoi-datasets 0.5.28__py3-none-any.whl → 0.5.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/create/__init__.py +4 -12
  3. anemoi/datasets/create/config.py +50 -53
  4. anemoi/datasets/create/input/result/field.py +1 -3
  5. anemoi/datasets/create/sources/accumulate.py +517 -0
  6. anemoi/datasets/create/sources/accumulate_utils/__init__.py +8 -0
  7. anemoi/datasets/create/sources/accumulate_utils/covering_intervals.py +221 -0
  8. anemoi/datasets/create/sources/accumulate_utils/field_to_interval.py +153 -0
  9. anemoi/datasets/create/sources/accumulate_utils/interval_generators.py +321 -0
  10. anemoi/datasets/create/sources/grib_index.py +79 -51
  11. anemoi/datasets/create/sources/mars.py +56 -27
  12. anemoi/datasets/create/sources/xarray_support/__init__.py +1 -0
  13. anemoi/datasets/create/sources/xarray_support/coordinates.py +1 -4
  14. anemoi/datasets/create/sources/xarray_support/flavour.py +2 -2
  15. anemoi/datasets/create/sources/xarray_support/patch.py +178 -5
  16. anemoi/datasets/data/complement.py +26 -17
  17. anemoi/datasets/data/dataset.py +6 -0
  18. anemoi/datasets/data/masked.py +74 -13
  19. anemoi/datasets/data/missing.py +5 -0
  20. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/METADATA +8 -7
  21. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/RECORD +25 -23
  22. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/WHEEL +1 -1
  23. anemoi/datasets/create/sources/accumulations.py +0 -1042
  24. anemoi/datasets/create/sources/accumulations2.py +0 -618
  25. anemoi/datasets/create/sources/tendencies.py +0 -171
  26. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/entry_points.txt +0 -0
  27. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/licenses/LICENSE +0 -0
  28. {anemoi_datasets-0.5.28.dist-info → anemoi_datasets-0.5.30.dist-info}/top_level.txt +0 -0
@@ -7,9 +7,12 @@
7
7
  # granted to it by virtue of its status as an intergovernmental organisation
8
8
  # nor does it submit to any jurisdiction.
9
9
 
10
+ import hashlib
11
+ import json
10
12
  import logging
11
13
  import os
12
14
  import sqlite3
15
+ from collections import defaultdict
13
16
  from collections.abc import Iterator
14
17
  from typing import Any
15
18
 
@@ -46,8 +49,8 @@ class GribIndex:
46
49
  ----------
47
50
  database : str
48
51
  Path to the SQLite database file.
49
- keys : Optional[List[str] | str], optional
50
- List of keys or a string of keys to use for indexing, by default None.
52
+ keys : Optional[list[str] | str], optional
53
+ list of keys or a string of keys to use for indexing, by default None.
51
54
  flavour : Optional[str], optional
52
55
  Flavour configuration for mapping fields, by default None.
53
56
  update : bool, optional
@@ -103,21 +106,18 @@ class GribIndex:
103
106
  """Create the necessary tables in the database."""
104
107
  assert self.update
105
108
 
106
- self.cursor.execute(
107
- """
109
+ self.cursor.execute("""
108
110
  CREATE TABLE IF NOT EXISTS paths (
109
111
  id INTEGER PRIMARY KEY,
110
112
  path TEXT not null
111
113
  )
112
- """
113
- )
114
+ """)
114
115
 
115
116
  columns = ("valid_datetime",)
116
117
  # We don't use NULL as a default because NULL is considered a different value
117
118
  # in UNIQUE INDEX constraints (https://www.sqlite.org/lang_createindex.html)
118
119
 
119
- self.cursor.execute(
120
- f"""
120
+ self.cursor.execute(f"""
121
121
  CREATE TABLE IF NOT EXISTS grib_index (
122
122
  _id INTEGER PRIMARY KEY,
123
123
  _path_id INTEGER not null,
@@ -125,30 +125,23 @@ class GribIndex:
125
125
  _length INTEGER not null,
126
126
  {', '.join(f"{key} TEXT not null default ''" for key in columns)},
127
127
  FOREIGN KEY(_path_id) REFERENCES paths(id))
128
- """
129
- ) # ,
128
+ """) # ,
130
129
 
131
- self.cursor.execute(
132
- """
130
+ self.cursor.execute("""
133
131
  CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_path_offset
134
132
  ON grib_index (_path_id, _offset)
135
- """
136
- )
133
+ """)
137
134
 
138
- self.cursor.execute(
139
- f"""
135
+ self.cursor.execute(f"""
140
136
  CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_all_keys
141
137
  ON grib_index ({', '.join(columns)})
142
- """
143
- )
138
+ """)
144
139
 
145
140
  for key in columns:
146
- self.cursor.execute(
147
- f"""
141
+ self.cursor.execute(f"""
148
142
  CREATE INDEX IF NOT EXISTS idx_grib_index_{key}
149
143
  ON grib_index ({key})
150
- """
151
- )
144
+ """)
152
145
 
153
146
  self._commit()
154
147
 
@@ -161,7 +154,7 @@ class GribIndex:
161
154
 
162
155
  Returns
163
156
  -------
164
- List[str]
157
+ list[str]
165
158
  A list of metadata keys stored in the database.
166
159
  """
167
160
  self.cursor.execute("SELECT key FROM metadata_keys")
@@ -229,7 +222,7 @@ class GribIndex:
229
222
 
230
223
  Returns
231
224
  -------
232
- List[str]
225
+ list[str]
233
226
  A list of column names.
234
227
  """
235
228
  if self._columns is not None:
@@ -245,8 +238,8 @@ class GribIndex:
245
238
 
246
239
  Parameters
247
240
  ----------
248
- columns : List[str]
249
- List of column names to ensure in the table.
241
+ columns : list[str]
242
+ list of column names to ensure in the table.
250
243
  """
251
244
  assert self.update
252
245
 
@@ -264,20 +257,16 @@ class GribIndex:
264
257
  self.cursor.execute("""DROP INDEX IF EXISTS idx_grib_index_all_keys""")
265
258
  all_columns = self._all_columns()
266
259
 
267
- self.cursor.execute(
268
- f"""
260
+ self.cursor.execute(f"""
269
261
  CREATE UNIQUE INDEX IF NOT EXISTS idx_grib_index_all_keys
270
262
  ON grib_index ({', '.join(all_columns)})
271
- """
272
- )
263
+ """)
273
264
 
274
265
  for key in all_columns:
275
- self.cursor.execute(
276
- f"""
266
+ self.cursor.execute(f"""
277
267
  CREATE INDEX IF NOT EXISTS idx_grib_index_{key}
278
268
  ON grib_index ({key})
279
- """
280
- )
269
+ """)
281
270
 
282
271
  def add_grib_file(self, path: str) -> None:
283
272
  """Add a GRIB file to the database.
@@ -364,7 +353,7 @@ class GribIndex:
364
353
 
365
354
  Returns
366
355
  -------
367
- List[dict]
356
+ list[dict]
368
357
  A list of GRIB2 parameter information.
369
358
  """
370
359
  if ("grib2", paramId) in self.cache:
@@ -524,8 +513,8 @@ class GribIndex:
524
513
 
525
514
  Parameters
526
515
  ----------
527
- dates : List[Any]
528
- List of dates to retrieve data for.
516
+ dates : list[Any]
517
+ list of dates to retrieve data for.
529
518
  **kwargs : Any
530
519
  Additional filtering criteria.
531
520
 
@@ -539,12 +528,13 @@ class GribIndex:
539
528
  dates = [d.isoformat() for d in dates]
540
529
 
541
530
  query = """SELECT _path_id, _offset, _length
542
- FROM grib_index WHERE valid_datetime IN ({})""".format(
543
- ", ".join("?" for _ in dates)
544
- )
531
+ FROM grib_index WHERE valid_datetime IN ({})""".format(", ".join("?" for _ in dates))
545
532
  params = dates
546
533
 
547
534
  for k, v in kwargs.items():
535
+ if k not in self._columns:
536
+ LOG.warning(f"Warning : {k} not in database columns, key discarded")
537
+ continue
548
538
  if isinstance(v, list):
549
539
  query += f" AND {k} IN ({', '.join('?' for _ in v)})"
550
540
  params.extend([str(_) for _ in v])
@@ -552,11 +542,14 @@ class GribIndex:
552
542
  query += f" AND {k} = ?"
553
543
  params.append(str(v))
554
544
 
555
- print("SELECT", query)
556
- print("SELECT", params)
545
+ print("SELECT (query)", query)
546
+ print("SELECT (params)", params)
557
547
 
558
548
  self.cursor.execute(query, params)
559
- for path_id, offset, length in self.cursor.fetchall():
549
+
550
+ fetch = self.cursor.fetchall()
551
+
552
+ for path_id, offset, length in fetch:
560
553
  if path_id in self.cache:
561
554
  file = self.cache[path_id]
562
555
  else:
@@ -570,9 +563,8 @@ class GribIndex:
570
563
  yield data
571
564
 
572
565
 
573
- @source_registry.register("grib_index")
566
+ @source_registry.register("grib-index")
574
567
  class GribIndexSource(LegacySource):
575
-
576
568
  @staticmethod
577
569
  def _execute(
578
570
  context: Any,
@@ -602,15 +594,51 @@ class GribIndexSource(LegacySource):
602
594
  An array of retrieved GRIB fields.
603
595
  """
604
596
  index = GribIndex(indexdb)
605
- result = []
606
597
 
607
598
  if flavour is not None:
608
599
  flavour = RuleBasedFlavour(flavour)
609
600
 
610
- for grib in index.retrieve(dates, **kwargs):
611
- field = ekd.from_source("memory", grib)[0]
612
- if flavour:
613
- field = flavour.apply(field)
614
- result.append(field)
601
+ if hasattr(dates, "date_to_intervals"):
602
+ # When using accumulate source
603
+ full_requests = []
604
+ for d, interval in dates.intervals:
605
+ context.trace("🌧️", "interval:", interval)
606
+ valid_date, request, _ = dates._adjust_request_to_interval(interval, kwargs)
607
+ context.trace("🌧️", " request =", request)
608
+ full_requests.append(([valid_date], request))
609
+ else:
610
+ # Normal case, without accumulate source
611
+ full_requests = [(dates, kwargs)]
612
+
613
+ full_requests = factorise(full_requests)
614
+ context.trace("🌧️", f"number of (factorised) requests: {len(full_requests)}")
615
+ for valid_dates, request in full_requests:
616
+ context.trace("🌧️", f" dates: {valid_dates}, request: {request}")
617
+
618
+ result = []
619
+ for valid_dates, request in full_requests:
620
+ for grib in index.retrieve(valid_dates, **request):
621
+ field = ekd.from_source("memory", grib)[0]
622
+ if flavour:
623
+ field = flavour.apply(field)
624
+ result.append(field)
615
625
 
616
626
  return FieldArray(result)
627
+
628
+
629
+ def factorise(lst):
630
+ """Factorise a list of (dates, request) tuples by merging dates with identical requests."""
631
+ content = dict()
632
+
633
+ d = defaultdict(list)
634
+ for dates, request in lst:
635
+ assert isinstance(request, dict), type(request)
636
+ key = hashlib.md5(json.dumps(request, sort_keys=True).encode()).hexdigest()
637
+ content[key] = request
638
+ d[key] += dates
639
+
640
+ res = []
641
+ for key, dates in d.items():
642
+ dates = list(sorted(set(dates)))
643
+ res.append((dates, content[key]))
644
+ return res
@@ -17,6 +17,7 @@ from earthkit.data import from_source
17
17
  from earthkit.data.utils.availability import Availability
18
18
 
19
19
  from anemoi.datasets.create.sources import source_registry
20
+ from anemoi.datasets.create.sources.accumulate import IntervalsDatesProvider
20
21
 
21
22
  from .legacy import LegacySource
22
23
 
@@ -145,7 +146,7 @@ def _expand_mars_request(
145
146
 
146
147
  Parameters
147
148
  ----------
148
- request : Dict[str, Any]
149
+ request : dict[str, Any]
149
150
  The input MARS request.
150
151
  date : datetime.datetime
151
152
  The date to be used in the request.
@@ -156,7 +157,7 @@ def _expand_mars_request(
156
157
 
157
158
  Returns
158
159
  -------
159
- List[Dict[str, Any]]
160
+ List[dict[str, Any]]
160
161
  A list of expanded MARS requests.
161
162
  """
162
163
  requests = []
@@ -164,23 +165,26 @@ def _expand_mars_request(
164
165
  user_step = to_list(expand_to_by(request.get("step", [0])))
165
166
  user_time = None
166
167
  user_date = None
167
-
168
168
  if not request_already_using_valid_datetime:
169
- user_time = request.get("time")
169
+ user_time = request.get("user_time")
170
170
  if user_time is not None:
171
171
  user_time = to_list(user_time)
172
172
  user_time = [_normalise_time(t) for t in user_time]
173
173
 
174
174
  user_date = request.get(date_key)
175
175
  if user_date is not None:
176
- assert isinstance(user_date, str), user_date
176
+ if isinstance(user_date, int):
177
+ user_date = str(user_date)
178
+ elif isinstance(user_date, datetime.datetime):
179
+ user_date = user_date.strftime("%Y%m%d")
180
+ else:
181
+ raise ValueError(f"Invalid type for {user_date}")
177
182
  user_date = re.compile("^{}$".format(user_date.replace("-", "").replace("?", ".")))
178
183
 
179
184
  for step in user_step:
180
185
  r = request.copy()
181
186
 
182
187
  if not request_already_using_valid_datetime:
183
-
184
188
  if isinstance(step, str) and "-" in step:
185
189
  assert step.count("-") == 1, step
186
190
 
@@ -190,30 +194,27 @@ def _expand_mars_request(
190
194
  base = date - datetime.timedelta(hours=hours)
191
195
  r.update(
192
196
  {
193
- date_key: base.strftime("%Y%m%d"),
197
+ "date": base.strftime("%Y%m%d"),
194
198
  "time": base.strftime("%H%M"),
195
199
  "step": step,
196
200
  }
197
201
  )
198
-
199
202
  for pproc in ("grid", "rotation", "frame", "area", "bitmap", "resol"):
200
203
  if pproc in r:
201
204
  if isinstance(r[pproc], (list, tuple)):
202
205
  r[pproc] = "/".join(str(x) for x in r[pproc])
203
206
 
204
207
  if user_date is not None:
205
- if not user_date.match(r[date_key]):
208
+ if not user_date.match(r["date"]):
206
209
  continue
207
210
 
208
211
  if user_time is not None:
209
- # It time is provided by the user, we only keep the requests that match the time
212
+ # If time is provided by the user, we only keep the requests that match the time
210
213
  if r["time"] not in user_time:
211
214
  continue
212
215
 
213
216
  requests.append(r)
214
217
 
215
- # assert requests, requests
216
-
217
218
  return requests
218
219
 
219
220
 
@@ -222,6 +223,7 @@ def factorise_requests(
222
223
  *requests: dict[str, Any],
223
224
  request_already_using_valid_datetime: bool = False,
224
225
  date_key: str = "date",
226
+ no_date_here: bool = False,
225
227
  ) -> Generator[dict[str, Any], None, None]:
226
228
  """Factorizes the requests based on the given dates.
227
229
 
@@ -229,33 +231,42 @@ def factorise_requests(
229
231
  ----------
230
232
  dates : List[datetime.datetime]
231
233
  The list of dates to be used in the requests.
232
- requests : Dict[str, Any]
234
+ requests : List[dict[str, Any]]
233
235
  The input requests to be factorized.
234
236
  request_already_using_valid_datetime : bool, optional
235
237
  Flag indicating if the requests already use valid datetime.
236
238
  date_key : str, optional
237
239
  The key for the date in the requests.
240
+ no_date_here : bool, optional
241
+ Flag indicating if there is no date in the "dates" list.
238
242
 
239
243
  Returns
240
244
  -------
241
- Generator[Dict[str, Any], None, None]
245
+ Generator[dict[str, Any], None, None]
242
246
  Factorized requests.
243
247
  """
244
- updates = []
245
- for req in requests:
246
- # req = normalise_request(req)
248
+ if isinstance(requests, tuple) and len(requests) == 1 and "requests" in requests[0]:
249
+ requests = requests[0]["requests"]
247
250
 
248
- for d in dates:
249
- updates += _expand_mars_request(
251
+ updates = []
252
+ for d in sorted(dates):
253
+ for req in requests:
254
+ if not no_date_here and (
255
+ ("date" in req)
256
+ and ("time" in req)
257
+ and d.strftime("%Y%m%d%H%M") != (str(req["date"]) + str(req["time"]).zfill(4))
258
+ ):
259
+ continue
260
+ new_req = _expand_mars_request(
250
261
  req,
251
262
  date=d,
252
263
  request_already_using_valid_datetime=request_already_using_valid_datetime,
253
- date_key=date_key,
264
+ date_key="user_date",
254
265
  )
266
+ updates += new_req
255
267
 
256
268
  if not updates:
257
269
  return
258
-
259
270
  compressed = Availability(updates)
260
271
  for r in compressed.iterate():
261
272
  for k, v in r.items():
@@ -269,12 +280,12 @@ def use_grib_paramid(r: dict[str, Any]) -> dict[str, Any]:
269
280
 
270
281
  Parameters
271
282
  ----------
272
- r : Dict[str, Any]
283
+ r : dict[str, Any]
273
284
  The input request containing parameter short names.
274
285
 
275
286
  Returns
276
287
  -------
277
- Dict[str, Any]
288
+ dict[str, Any]
278
289
  The request with parameter IDs.
279
290
  """
280
291
  from anemoi.utils.grib import shortname_to_paramid
@@ -379,7 +390,7 @@ class MarsSource(LegacySource):
379
390
  The context for the requests.
380
391
  dates : List[datetime.datetime]
381
392
  The list of dates to be used in the requests.
382
- requests : Dict[str, Any]
393
+ requests : dict[str, Any]
383
394
  The input requests to be executed.
384
395
  request_already_using_valid_datetime : bool, optional
385
396
  Flag indicating if the requests already use valid datetime.
@@ -395,7 +406,6 @@ class MarsSource(LegacySource):
395
406
  Any
396
407
  The resulting dataset.
397
408
  """
398
-
399
409
  if not requests:
400
410
  requests = [kwargs]
401
411
 
@@ -418,7 +428,26 @@ class MarsSource(LegacySource):
418
428
  "'param' cannot be 'True'. If you wrote 'param: on' in yaml, you may want to use quotes?"
419
429
  )
420
430
 
421
- if len(dates) == 0: # When using `repeated_dates`
431
+ if isinstance(dates, IntervalsDatesProvider):
432
+ # When using accumulate source
433
+ requests_ = []
434
+ for request in requests:
435
+ for d, interval in dates.intervals:
436
+ context.trace("🌧️", "interval:", interval)
437
+ _, r, _ = dates._adjust_request_to_interval(interval, request)
438
+ context.trace("🌧️", " adjusted request =", r)
439
+ requests_.append(r)
440
+ requests = requests_
441
+ context.trace("🌧️", f"Total requests: {len(requests)}")
442
+ requests = factorise_requests(
443
+ ["no_date_here"],
444
+ *requests,
445
+ request_already_using_valid_datetime=True,
446
+ date_key=date_key,
447
+ no_date_here=True,
448
+ )
449
+
450
+ elif len(dates) == 0: # When using `repeated_dates`
422
451
  assert len(requests) == 1, requests
423
452
  assert "date" in requests[0], requests[0]
424
453
  if isinstance(requests[0]["date"], datetime.date):
@@ -434,7 +463,7 @@ class MarsSource(LegacySource):
434
463
  requests = list(requests)
435
464
 
436
465
  ds = from_source("empty")
437
- context.trace("✅", f"{[str(d) for d in dates]}")
466
+ context.trace("✅", f"{[str(d) for d in dates]}, {len(dates)}")
438
467
  context.trace("✅", f"Will run {len(requests)} requests")
439
468
  for r in requests:
440
469
  r = {k: v for k, v in r.items() if v != ("-",)}
@@ -97,6 +97,7 @@ def load_one(
97
97
  if isinstance(dataset, xr.Dataset):
98
98
  data = dataset
99
99
  else:
100
+ print(f"Opening dataset {dataset} with options {options}")
100
101
  data = xr.open_dataset(dataset, **options)
101
102
 
102
103
  fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)
@@ -223,13 +223,10 @@ class Coordinate:
223
223
  # Assume the array is sorted
224
224
 
225
225
  index = np.searchsorted(values, value)
226
- index = index[index < len(values)]
227
-
228
- if np.all(values[index] == value):
226
+ if np.all(index < len(values)) and np.all(values[index] == value):
229
227
  return index
230
228
 
231
229
  # If not found, we need to check if the value is in the array
232
-
233
230
  index = np.where(np.isin(values, value))[0]
234
231
 
235
232
  # We could also return incomplete matches
@@ -557,10 +557,10 @@ class DefaultCoordinateGuesser(CoordinateGuesser):
557
557
  super().__init__(ds)
558
558
 
559
559
  def _is_point(self, c: xr.DataArray, attributes: CoordinateAttributes) -> PointCoordinate | None:
560
- if attributes.standard_name in ["cell", "station", "poi", "point"]:
560
+ if attributes.standard_name in ["location", "cell", "id", "station", "poi", "point"]:
561
561
  return PointCoordinate(c)
562
562
 
563
- if attributes.name in ["cell", "station", "poi", "point"]: # WeatherBench
563
+ if attributes.name in ["location", "cell", "id", "station", "poi", "point"]: # WeatherBench
564
564
  return PointCoordinate(c)
565
565
 
566
566
  return None