anemoi-datasets 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (64)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/cleanup.py +44 -0
  3. anemoi/datasets/commands/create.py +52 -21
  4. anemoi/datasets/commands/finalise-additions.py +45 -0
  5. anemoi/datasets/commands/finalise.py +39 -0
  6. anemoi/datasets/commands/init-additions.py +45 -0
  7. anemoi/datasets/commands/init.py +67 -0
  8. anemoi/datasets/commands/inspect.py +1 -1
  9. anemoi/datasets/commands/load-additions.py +47 -0
  10. anemoi/datasets/commands/load.py +47 -0
  11. anemoi/datasets/commands/patch.py +39 -0
  12. anemoi/datasets/create/__init__.py +959 -146
  13. anemoi/datasets/create/check.py +5 -3
  14. anemoi/datasets/create/config.py +54 -2
  15. anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +57 -0
  16. anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +57 -0
  17. anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +54 -0
  18. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +59 -0
  19. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +115 -0
  20. anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +390 -0
  21. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +77 -0
  22. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +55 -0
  23. anemoi/datasets/create/functions/sources/grib.py +86 -1
  24. anemoi/datasets/create/functions/sources/hindcasts.py +14 -73
  25. anemoi/datasets/create/functions/sources/mars.py +9 -3
  26. anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
  27. anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
  28. anemoi/datasets/create/functions/sources/xarray/field.py +8 -2
  29. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
  30. anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
  31. anemoi/datasets/create/functions/sources/xarray/metadata.py +40 -40
  32. anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
  33. anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
  34. anemoi/datasets/create/input.py +62 -39
  35. anemoi/datasets/create/persistent.py +1 -1
  36. anemoi/datasets/create/statistics/__init__.py +39 -23
  37. anemoi/datasets/create/utils.py +6 -2
  38. anemoi/datasets/data/__init__.py +1 -0
  39. anemoi/datasets/data/concat.py +46 -2
  40. anemoi/datasets/data/dataset.py +119 -34
  41. anemoi/datasets/data/debug.py +5 -1
  42. anemoi/datasets/data/forwards.py +17 -8
  43. anemoi/datasets/data/grids.py +17 -3
  44. anemoi/datasets/data/interpolate.py +133 -0
  45. anemoi/datasets/data/masked.py +2 -2
  46. anemoi/datasets/data/misc.py +56 -66
  47. anemoi/datasets/data/missing.py +240 -0
  48. anemoi/datasets/data/rescale.py +147 -0
  49. anemoi/datasets/data/select.py +7 -1
  50. anemoi/datasets/data/stores.py +23 -10
  51. anemoi/datasets/data/subset.py +47 -5
  52. anemoi/datasets/data/unchecked.py +20 -22
  53. anemoi/datasets/data/xy.py +125 -0
  54. anemoi/datasets/dates/__init__.py +124 -95
  55. anemoi/datasets/dates/groups.py +85 -20
  56. anemoi/datasets/grids.py +66 -48
  57. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/METADATA +8 -17
  58. anemoi_datasets-0.5.0.dist-info/RECORD +105 -0
  59. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/WHEEL +1 -1
  60. anemoi/datasets/create/loaders.py +0 -936
  61. anemoi_datasets-0.4.4.dist-info/RECORD +0 -86
  62. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/LICENSE +0 -0
  63. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/entry_points.txt +0 -0
  64. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.5.0.dist-info}/top_level.txt +0 -0
anemoi/datasets/create/input.py
@@ -23,7 +23,7 @@ from earthkit.data.core.fieldlist import FieldList
 from earthkit.data.core.fieldlist import MultiFieldList
 from earthkit.data.core.order import build_remapping
 
-from anemoi.datasets.dates import Dates
+from anemoi.datasets.dates import DatesProvider
 
 from .functions import import_function
 from .template import Context
@@ -75,7 +75,7 @@ def time_delta_to_string(delta):
 
 
 def is_function(name, kind):
-    name, delta = parse_function_name(name)  # noqa
+    name, _ = parse_function_name(name)
     try:
         import_function(name, kind)
         return True
@@ -106,30 +106,32 @@ def _data_request(data):
     area = grid = None
 
     for field in data:
-        if not hasattr(field, "as_mars"):
-            continue
-
-        if date is None:
-            date = field.datetime()["valid_time"]
-
-        if field.datetime()["valid_time"] != date:
-            continue
+        try:
+            if date is None:
+                date = field.datetime()["valid_time"]
 
-        as_mars = field.metadata(namespace="mars")
-        step = as_mars.get("step")
-        levtype = as_mars.get("levtype", "sfc")
-        param = as_mars["param"]
-        levelist = as_mars.get("levelist", None)
-        area = field.mars_area
-        grid = field.mars_grid
+            if field.datetime()["valid_time"] != date:
+                continue
 
-        if levelist is None:
-            params_levels[levtype].add(param)
-        else:
-            params_levels[levtype].add((param, levelist))
+            as_mars = field.metadata(namespace="mars")
+            if not as_mars:
+                continue
+            step = as_mars.get("step")
+            levtype = as_mars.get("levtype", "sfc")
+            param = as_mars["param"]
+            levelist = as_mars.get("levelist", None)
+            area = field.mars_area
+            grid = field.mars_grid
+
+            if levelist is None:
+                params_levels[levtype].add(param)
+            else:
+                params_levels[levtype].add((param, levelist))
 
-        if step:
-            params_steps[levtype].add((param, step))
+            if step:
+                params_steps[levtype].add((param, step))
+        except Exception:
+            LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True)
 
     def sort(old_dic):
         new_dic = {}
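
The rewritten loop above trades the `hasattr(field, "as_mars")` guard for a broad `try`/`except`: a field whose metadata cannot be read is now logged and skipped instead of aborting the whole request summary. A minimal sketch of that pattern, using hypothetical stub fields rather than real earthkit objects:

    import logging

    LOG = logging.getLogger(__name__)

    class GoodField:
        def datetime(self):
            return {"valid_time": "2020-01-01T00:00:00"}

        def metadata(self, namespace):
            return {"param": "2t", "levtype": "sfc"}

    class BrokenField:
        def datetime(self):
            raise RuntimeError("corrupt header")

    params = set()
    for field in [GoodField(), BrokenField()]:
        try:
            field.datetime()["valid_time"]
            as_mars = field.metadata(namespace="mars")
            if not as_mars:
                continue  # field carries no MARS metadata: skip silently
            params.add(as_mars["param"])
        except Exception:
            LOG.error(f"Error in retrieving metadata for {field}", exc_info=True)

    print(params)  # {'2t'} -- the broken field is logged and skipped, not fatal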
@@ -202,11 +204,15 @@ class Result:
     _coords_already_built = False
 
     def __init__(self, context, action_path, dates):
+        from anemoi.datasets.dates.groups import GroupOfDates
+
+        assert isinstance(dates, GroupOfDates), dates
+
         assert isinstance(context, ActionContext), type(context)
         assert isinstance(action_path, list), action_path
 
         self.context = context
-        self.dates = dates
+        self.group_of_dates = dates
         self.action_path = action_path
 
     @property
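
`Result` now stores a `GroupOfDates` instead of a bare list of dates. The changes in this diff rely only on a small surface of that class: construction from a list of datetimes plus a provider, iteration yielding `datetime` objects, and `len()`. A sketch of that inferred interface, not the actual `anemoi.datasets.dates.groups` implementation:

    import datetime

    class GroupOfDates:
        """Inferred minimal interface; illustrative only."""

        def __init__(self, dates, provider):
            self.dates = list(dates)  # plain datetime objects
            self.provider = provider  # the DatesProvider that generated them

        def __iter__(self):
            return iter(self.dates)

        def __len__(self):
            return len(self.dates)

    group = GroupOfDates([datetime.datetime(2020, 1, 1, h) for h in (0, 6, 12)], provider=None)
    assert len(group) == 3
    assert all(isinstance(d, datetime.datetime) for d in group)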
@@ -288,7 +294,6 @@ class Result:
             names += list(a.keys())
 
         print(f"Building a {len(names)}D hypercube using", names)
-
         ds = ds.order_by(*args, remapping=remapping, patches=patches)
         user_coords = ds.unique_values(*names, remapping=remapping, patches=patches, progress_bar=False)
 
@@ -404,10 +409,10 @@
         more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])
 
         dates = " no-dates"
-        if self.dates is not None:
-            dates = f" {len(self.dates)} dates"
+        if self.group_of_dates is not None:
+            dates = f" {len(self.group_of_dates)} dates"
             dates += " ("
-            dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.dates)
+            dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.group_of_dates)
             if len(dates) > 100:
                 dates = dates[:100] + "..."
             dates += ")"
@@ -422,7 +427,7 @@
         raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
 
     def _trace_datasource(self, *args, **kwargs):
-        return f"{self.__class__.__name__}({shorten(self.dates)})"
+        return f"{self.__class__.__name__}({self.group_of_dates})"
 
     def build_coords(self):
         if self._coords_already_built:
@@ -512,7 +517,7 @@
     @cached_property
     def shape(self):
         return [
-            len(self.dates),
+            len(self.group_of_dates),
             len(self.variables),
             len(self.ensembles),
             len(self.grid_values),
@@ -521,7 +526,7 @@
     @cached_property
     def coords(self):
         return {
-            "dates": self.dates,
+            "dates": list(self.group_of_dates),
             "variables": self.variables,
             "ensembles": self.ensembles,
             "values": self.grid_values,
@@ -572,7 +577,7 @@ class FunctionResult(Result):
         self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs))
 
     def _trace_datasource(self, *args, **kwargs):
-        return f"{self.action.name}({shorten(self.dates)})"
+        return f"{self.action.name}({self.group_of_dates})"
 
     @cached_property
     @assert_fieldlist
@@ -582,14 +587,21 @@ class FunctionResult(Result):
         args, kwargs = resolve(self.context, (self.args, self.kwargs))
 
         try:
-            return _tidy(self.action.function(FunctionContext(self), self.dates, *args, **kwargs))
+            return _tidy(
+                self.action.function(
+                    FunctionContext(self),
+                    list(self.group_of_dates),  # Will provide a list of datetime objects
+                    *args,
+                    **kwargs,
+                )
+            )
         except Exception:
             LOG.error(f"Error in {self.action.function.__name__}", exc_info=True)
             raise
 
     def __repr__(self):
         try:
-            return f"{self.action.name}({shorten(self.dates)})"
+            return f"{self.action.name}({self.group_of_dates})"
         except Exception:
             return f"{self.__class__.__name__}(unitialised)"
 
@@ -608,7 +620,7 @@ class JoinResult(Result):
     @notify_result
     @trace_datasource
     def datasource(self):
-        ds = EmptyResult(self.context, self.action_path, self.dates).datasource
+        ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
         for i in self.results:
             ds += i.datasource
         return _tidy(ds)
@@ -823,7 +835,7 @@ class ConcatResult(Result):
     @notify_result
     @trace_datasource
     def datasource(self):
-        ds = EmptyResult(self.context, self.action_path, self.dates).datasource
+        ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
         for i in self.results:
             ds += i.datasource
         return _tidy(ds)
@@ -903,7 +915,7 @@ class ConcatAction(Action):
             cfg = deepcopy(cfg)
             dates_cfg = cfg.pop("dates")
             assert isinstance(dates_cfg, dict), dates_cfg
-            filtering_dates = Dates.from_config(**dates_cfg)
+            filtering_dates = DatesProvider.from_config(**dates_cfg)
             action = action_factory(cfg, context, action_path + [str(i)])
             parts.append((filtering_dates, action))
         self.parts = parts
@@ -914,9 +926,11 @@
 
     @trace_select
     def select(self, dates):
+        from anemoi.datasets.dates.groups import GroupOfDates
+
         results = []
         for filtering_dates, action in self.parts:
-            newdates = sorted(set(dates) & set(filtering_dates))
+            newdates = GroupOfDates(sorted(set(dates) & set(filtering_dates)), dates.provider)
             if newdates:
                 results.append(action.select(newdates))
         if not results:
@@ -952,8 +966,10 @@ def action_factory(config, context, action_path):
 
     if isinstance(config[key], list):
         args, kwargs = config[key], {}
-    if isinstance(config[key], dict):
+    elif isinstance(config[key], dict):
         args, kwargs = [], config[key]
+    else:
+        raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}")
 
     cls = {
         # "date_shift": DateShiftAction,
@@ -1020,6 +1036,13 @@ class FunctionContext:
     def trace(self, emoji, *args):
         trace(emoji, *args)
 
+    def info(self, *args, **kwargs):
+        LOG.info(*args, **kwargs)
+
+    @property
+    def dates_provider(self):
+        return self.owner.group_of_dates.provider
+
 
 class ActionContext(Context):
     def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid):
anemoi/datasets/create/persistent.py
@@ -68,7 +68,7 @@ class PersistentDict:
         path = os.path.join(self.dirname, f"{h}.pickle")
 
         if os.path.exists(path):
-            LOG.warn(f"{path} already exists")
+            LOG.warning(f"{path} already exists")
 
         tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}"
         with open(tmp_path, "wb") as f:
anemoi/datasets/create/statistics/__init__.py
@@ -79,6 +79,37 @@ def to_datetimes(dates):
     return [to_datetime(d) for d in dates]
 
 
+def fix_variance(x, name, count, sums, squares):
+    assert count.shape == sums.shape == squares.shape
+    assert isinstance(x, float)
+
+    mean = sums / count
+    assert mean.shape == count.shape
+
+    if x >= 0:
+        return x
+
+    LOG.warning(f"Negative variance for {name=}, variance={x}")
+    magnitude = np.sqrt((squares / count + mean * mean) / 2)
+    LOG.warning(f"square / count - mean * mean = {squares/count} - {mean*mean} = {squares/count - mean*mean}")
+    LOG.warning(f"Variable span order of magnitude is {magnitude}.")
+    LOG.warning(f"Count is {count}.")
+
+    variances = squares / count - mean * mean
+    assert variances.shape == squares.shape == mean.shape
+    if all(variances >= 0):
+        LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.")
+        return 0
+
+    # if abs(x) < magnitude * 1e-6 and abs(x) < range * 1e-6:
+    #     LOG.warning("Variance is negative but very small.")
+    #     variances = squares / count - mean * mean
+    #     return 0
+
+    LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).")
+    return x
+
+
 def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
     if (x >= 0).all():
         return
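
For context on why `fix_variance` exists: the statistics pipeline computes variance as E[x²] − E[x]², which can come out slightly negative through floating-point cancellation when the true variance is tiny relative to the mean. A self-contained NumPy demonstration:

    import numpy as np

    x = np.full(1_000_000, 1e8, dtype=np.float64)
    x[::2] += 1.0  # tiny spread around a huge mean; true variance is 0.25

    count = len(x)
    sums = x.sum()
    squares = (x * x).sum()

    mean = sums / count
    variance = squares / count - mean * mean
    # squares/count and mean*mean agree in their ~16 leading digits, so the
    # subtraction cancels catastrophically; the result may come out 0.0 or
    # even slightly negative instead of 0.25.
    print(variance)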
@@ -292,39 +323,24 @@ class StatAggregator:
     def aggregate(self):
         minimum = np.nanmin(self.minimum, axis=0)
         maximum = np.nanmax(self.maximum, axis=0)
+
         sums = np.nansum(self.sums, axis=0)
         squares = np.nansum(self.squares, axis=0)
         count = np.nansum(self.count, axis=0)
         has_nans = np.any(self.has_nans, axis=0)
-        mean = sums / count
+        assert sums.shape == count.shape == squares.shape == minimum.shape == maximum.shape
 
-        assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape
+        mean = sums / count
+        assert mean.shape == minimum.shape
 
         x = squares / count - mean * mean
-
-        # def fix_variance(x, name, minimum, maximum, mean, count, sums, squares):
-        #     assert x.shape == minimum.shape == maximum.shape == mean.shape == count.shape == sums.shape == squares.shape
-        #     assert x.shape == (1,)
-        #     x, minimum, maximum, mean, count, sums, squares = x[0], minimum[0], maximum[0], mean[0], count[0], sums[0], squares[0]
-        #     if x >= 0:
-        #         return x
-        #
-        #     order = np.sqrt((squares / count + mean * mean)/2)
-        #     range = maximum - minimum
-        #     LOG.warning(f"Negative variance for {name=}, variance={x}")
-        #     LOG.warning(f"square / count - mean * mean = {squares / count} - {mean * mean} = {squares / count - mean * mean}")
-        #     LOG.warning(f"Variable order of magnitude is {order}.")
-        #     LOG.warning(f"Range is {range} ({maximum=} - {minimum=}).")
-        #     LOG.warning(f"Count is {count}.")
-        #     if abs(x) < order * 1e-6 and abs(x) < range * 1e-6:
-        #         LOG.warning(f"Variance is negative but very small, setting to 0.")
-        #         return x*0
-        #     return x
+        assert x.shape == minimum.shape
 
         for i, name in enumerate(self.variables_names):
             # remove negative variance due to numerical errors
-            # Not needed for now, fix_variance is disabled
-            # x[i] = fix_variance(x[i:i+1], name, minimum[i:i+1], maximum[i:i+1], mean[i:i+1], count[i:i+1], sums[i:i+1], squares[i:i+1])
+            x[i] = fix_variance(x[i], name, self.count[i : i + 1], self.sums[i : i + 1], self.squares[i : i + 1])
+
+        for i, name in enumerate(self.variables_names):
             check_variance(
                 x[i : i + 1],
                 [name],
anemoi/datasets/create/utils.py
@@ -7,6 +7,7 @@
 # nor does it submit to any jurisdiction.
 #
 
+import datetime
 import os
 from contextlib import contextmanager
 
@@ -61,10 +62,13 @@ def make_list_int(value):
 
 
 def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"):
-    assert isinstance(frequency, int), frequency
+
+    dates = [d.hdate if hasattr(d, "hdate") else d for d in dates]
+
+    assert isinstance(frequency, datetime.timedelta), frequency
     start = np.datetime64(start)
     end = np.datetime64(end)
-    delta = np.timedelta64(frequency, "h")
+    delta = np.timedelta64(frequency)
 
     res = []
     while start <= end:
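
`frequency` is now a `datetime.timedelta` rather than an int count of hours, and `np.timedelta64` accepts a timedelta directly, which also allows sub-hourly frequencies. A minimal sketch of the resulting date stepping:

    import datetime

    import numpy as np

    start = np.datetime64("2020-01-01T00:00:00")
    end = np.datetime64("2020-01-01T18:00:00")
    delta = np.timedelta64(datetime.timedelta(hours=6))  # was np.timedelta64(6, "h")

    res = []
    while start <= end:
        res.append(start)
        start += delta

    print(res)  # 00:00, 06:00, 12:00 and 18:00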
anemoi/datasets/data/__init__.py
@@ -27,6 +27,7 @@ class MissingDateError(Exception):
 
 def open_dataset(*args, **kwargs):
     ds = _open_dataset(*args, **kwargs)
+    ds = ds.mutate()
     ds.arguments = {"args": args, "kwargs": kwargs}
     ds._check()
     return ds
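
The new `ds.mutate()` call lets a freshly opened dataset substitute another object for itself before the checks run. A sketch of that pattern under assumed semantics (the class names here are illustrative, not the library's):

    class Dataset:
        def mutate(self):
            return self  # default: no substitution

    class PlaceholderDataset(Dataset):
        """Illustrative: resolves to a concrete dataset once opened."""

        def __init__(self, resolved):
            self.resolved = resolved

        def mutate(self):
            return self.resolved

    concrete = Dataset()
    assert concrete.mutate() is concrete
    assert PlaceholderDataset(concrete).mutate() is concrete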
anemoi/datasets/data/concat.py
@@ -9,6 +9,7 @@ import logging
 from functools import cached_property
 
 import numpy as np
+from anemoi.utils.dates import frequency_to_timedelta
 
 from .debug import Node
 from .debug import debug_indexing
@@ -102,20 +103,63 @@ class Concat(ConcatMixin, Combined):
     def tree(self):
         return Node(self, [d.tree() for d in self.datasets])
 
+    @classmethod
+    def check_dataset_compatibility(cls, datasets, fill_missing_gaps=False):
+        # Study the dates
+        ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets]
 
-def concat_factory(args, kwargs, zarr_root):
+        # Make sure the dates are disjoint
+        for i in range(len(ranges)):
+            r = ranges[i]
+            for j in range(i + 1, len(ranges)):
+                s = ranges[j]
+                if r[0] <= s[0] <= r[1] or r[0] <= s[1] <= r[1]:
+                    raise ValueError(f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})")
+
+        # For now we should have the datasets in order with no gaps
+
+        frequency = frequency_to_timedelta(datasets[0].frequency)
+        result = []
+
+        for i in range(len(ranges) - 1):
+            result.append(datasets[i])
+            r = ranges[i]
+            s = ranges[i + 1]
+            if r[1] + frequency != s[0]:
+                if fill_missing_gaps:
+                    from .missing import MissingDataset
+
+                    result.append(MissingDataset(datasets[i], r[1] + frequency, s[0] - frequency))
+                else:
+                    r = [str(e) for e in r]
+                    s = [str(e) for e in s]
+                    raise ValueError(
+                        "Datasets must be sorted by dates, with no gaps: "
+                        f"{r} and {s} ({datasets[i]} {datasets[i+1]})"
+                    )
+
+        result.append(datasets[-1])
+        assert len(result) >= len(datasets), (len(result), len(datasets))
+
+        return result
+
+
+def concat_factory(args, kwargs):
 
     datasets = kwargs.pop("concat")
+    fill_missing_gaps = kwargs.pop("fill_missing_gaps", False)
     assert isinstance(datasets, (list, tuple))
     assert len(args) == 0
 
     assert isinstance(datasets, (list, tuple))
 
-    datasets = [_open(e, zarr_root) for e in datasets]
+    datasets = [_open(e) for e in datasets]
 
     if len(datasets) == 1:
         return datasets[0]._subset(**kwargs)
 
     datasets, kwargs = _auto_adjust(datasets, kwargs)
 
+    datasets = Concat.check_dataset_compatibility(datasets, fill_missing_gaps)
+
     return Concat(datasets)._subset(**kwargs)
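
The `fill_missing_gaps` option surfaces through `open_dataset(concat=...)`: without it, a hole between consecutive datasets raises a `ValueError`; with it, the hole is bridged by a `MissingDataset` (presumably raising `MissingDateError` when those dates are accessed, given the new `data/missing.py`). A hypothetical usage sketch, with placeholder dataset paths:

    from anemoi.datasets import open_dataset

    # Placeholder paths; any gap between the two date ranges is filled with
    # a MissingDataset instead of raising at open time.
    ds = open_dataset(
        concat=["dataset-2019-2020.zarr", "dataset-2022-2023.zarr"],
        fill_missing_gaps=True,
    )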