anemoi-datasets 0.4.3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff compares two versions of the package that were publicly released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in that registry.
Files changed (52)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/cleanup.py +44 -0
  3. anemoi/datasets/commands/create.py +50 -20
  4. anemoi/datasets/commands/finalise-additions.py +45 -0
  5. anemoi/datasets/commands/finalise.py +39 -0
  6. anemoi/datasets/commands/init-additions.py +45 -0
  7. anemoi/datasets/commands/init.py +67 -0
  8. anemoi/datasets/commands/inspect.py +1 -1
  9. anemoi/datasets/commands/load-additions.py +47 -0
  10. anemoi/datasets/commands/load.py +47 -0
  11. anemoi/datasets/commands/patch.py +39 -0
  12. anemoi/datasets/compute/recentre.py +1 -1
  13. anemoi/datasets/create/__init__.py +961 -146
  14. anemoi/datasets/create/check.py +5 -3
  15. anemoi/datasets/create/config.py +53 -2
  16. anemoi/datasets/create/functions/sources/accumulations.py +6 -22
  17. anemoi/datasets/create/functions/sources/hindcasts.py +27 -12
  18. anemoi/datasets/create/functions/sources/tendencies.py +1 -1
  19. anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
  20. anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
  21. anemoi/datasets/create/functions/sources/xarray/field.py +1 -1
  22. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
  23. anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
  24. anemoi/datasets/create/functions/sources/xarray/metadata.py +27 -29
  25. anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
  26. anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
  27. anemoi/datasets/create/input.py +62 -25
  28. anemoi/datasets/create/statistics/__init__.py +39 -23
  29. anemoi/datasets/create/utils.py +3 -2
  30. anemoi/datasets/data/__init__.py +1 -0
  31. anemoi/datasets/data/concat.py +46 -2
  32. anemoi/datasets/data/dataset.py +109 -34
  33. anemoi/datasets/data/forwards.py +17 -8
  34. anemoi/datasets/data/grids.py +17 -3
  35. anemoi/datasets/data/interpolate.py +133 -0
  36. anemoi/datasets/data/misc.py +56 -66
  37. anemoi/datasets/data/missing.py +240 -0
  38. anemoi/datasets/data/select.py +7 -1
  39. anemoi/datasets/data/stores.py +3 -3
  40. anemoi/datasets/data/subset.py +47 -5
  41. anemoi/datasets/data/unchecked.py +20 -22
  42. anemoi/datasets/data/xy.py +125 -0
  43. anemoi/datasets/dates/__init__.py +33 -20
  44. anemoi/datasets/dates/groups.py +2 -2
  45. anemoi/datasets/grids.py +66 -48
  46. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/METADATA +5 -5
  47. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/RECORD +51 -41
  48. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/WHEEL +1 -1
  49. anemoi/datasets/create/loaders.py +0 -924
  50. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/LICENSE +0 -0
  51. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/entry_points.txt +0 -0
  52. {anemoi_datasets-0.4.3.dist-info → anemoi_datasets-0.4.5.dist-info}/top_level.txt +0 -0
--- a/anemoi/datasets/create/functions/sources/xarray/variable.py
+++ b/anemoi/datasets/create/functions/sources/xarray/variable.py
@@ -14,34 +14,32 @@ from functools import cached_property
 import numpy as np
 from earthkit.data.utils.array import ensure_backend
 
-from anemoi.datasets.create.functions.sources.xarray.metadata import MDMapping
-
 from .field import XArrayField
 
 LOG = logging.getLogger(__name__)
 
 
 class Variable:
-    def __init__(self, *, ds, var, coordinates, grid, time, metadata, mapping=None, array_backend=None):
+    def __init__(
+        self,
+        *,
+        ds,
+        var,
+        coordinates,
+        grid,
+        time,
+        metadata,
+        array_backend=None,
+    ):
         self.ds = ds
         self.var = var
 
         self.grid = grid
         self.coordinates = coordinates
 
-        # print("Variable", var.name)
-        # for c in coordinates:
-        #     print(" ", c)
-
         self._metadata = metadata.copy()
-        # self._metadata.update(var.attrs)
         self._metadata.update({"variable": var.name})
 
-        # self._metadata.setdefault("level", None)
-        # self._metadata.setdefault("number", 0)
-        # self._metadata.setdefault("levtype", "sfc")
-        self._mapping = mapping
-
         self.time = time
 
         self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid)
@@ -51,23 +49,6 @@ class Variable:
         self.length = math.prod(self.shape)
         self.array_backend = ensure_backend(array_backend)
 
-    def update_metadata_mapping(self, kwargs):
-
-        result = {}
-
-        for k, v in kwargs.items():
-            if k == "param":
-                result[k] = "variable"
-                continue
-
-            for c in self.coordinates:
-                if k in c.mars_names:
-                    for v in c.mars_names:
-                        result[v] = c.variable.name
-                    break
-
-        self._mapping = MDMapping(result)
-
     @property
     def name(self):
         return self.var.name
@@ -111,17 +92,11 @@ class Variable:
         kwargs = {k: v for k, v in zip(self.names, coords)}
         return XArrayField(self, self.var.isel(kwargs))
 
-    @property
-    def mapping(self):
-        return self._mapping
-
     def sel(self, missing, **kwargs):
 
         if not kwargs:
             return self
 
-        kwargs = self._mapping.from_user(kwargs)
-
         k, v = kwargs.popitem()
 
         c = self.by_name.get(k)
@@ -147,13 +122,15 @@ class Variable:
             grid=self.grid,
             time=self.time,
             metadata=metadata,
-            mapping=self.mapping,
         )
 
         return variable.sel(missing, **kwargs)
 
     def match(self, **kwargs):
-        kwargs = self._mapping.from_user(kwargs)
+
+        if "param" in kwargs:
+            assert "variable" not in kwargs
+            kwargs["variable"] = kwargs.pop("param")
 
         if "variable" in kwargs:
             name = kwargs.pop("variable")
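
Note: the removed MDMapping indirection is replaced by a direct alias, where "param" is rewritten to "variable" before matching. A minimal standalone sketch of the new behaviour (names simplified, field values illustrative):

    def match(variable_name, **kwargs):
        # "param" is now accepted as an alias for "variable"
        if "param" in kwargs:
            assert "variable" not in kwargs
            kwargs["variable"] = kwargs.pop("param")
        if "variable" in kwargs:
            return kwargs.pop("variable") == variable_name
        return True

    assert match("2t", param="2t")          # same result as variable="2t"
    assert not match("2t", variable="msl")
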
--- a/anemoi/datasets/create/input.py
+++ b/anemoi/datasets/create/input.py
@@ -106,30 +106,32 @@ def _data_request(data):
     area = grid = None
 
     for field in data:
-        if not hasattr(field, "as_mars"):
-            continue
-
-        if date is None:
-            date = field.datetime()["valid_time"]
-
-        if field.datetime()["valid_time"] != date:
-            continue
+        try:
+            if date is None:
+                date = field.datetime()["valid_time"]
 
-        as_mars = field.metadata(namespace="mars")
-        step = as_mars.get("step")
-        levtype = as_mars.get("levtype", "sfc")
-        param = as_mars["param"]
-        levelist = as_mars.get("levelist", None)
-        area = field.mars_area
-        grid = field.mars_grid
+            if field.datetime()["valid_time"] != date:
+                continue
 
-        if levelist is None:
-            params_levels[levtype].add(param)
-        else:
-            params_levels[levtype].add((param, levelist))
+            as_mars = field.metadata(namespace="mars")
+            if not as_mars:
+                continue
+            step = as_mars.get("step")
+            levtype = as_mars.get("levtype", "sfc")
+            param = as_mars["param"]
+            levelist = as_mars.get("levelist", None)
+            area = field.mars_area
+            grid = field.mars_grid
+
+            if levelist is None:
+                params_levels[levtype].add(param)
+            else:
+                params_levels[levtype].add((param, levelist))
 
-        if step:
-            params_steps[levtype].add((param, step))
+            if step:
+                params_steps[levtype].add((param, step))
+        except Exception:
+            LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True)
 
     def sort(old_dic):
         new_dic = {}
@@ -277,6 +279,9 @@ class Result:
         if len(args) == 1 and isinstance(args[0], (list, tuple)):
             args = args[0]
 
+        # print("Executing", self.action_path)
+        # print("Dates:", compress_dates(self.dates))
+
         names = []
         for a in args:
             if isinstance(a, str):
@@ -285,14 +290,13 @@ class Result:
                 names += list(a.keys())
 
         print(f"Building a {len(names)}D hypercube using", names)
-
         ds = ds.order_by(*args, remapping=remapping, patches=patches)
-        user_coords = ds.unique_values(*names, remapping=remapping, patches=patches)
+        user_coords = ds.unique_values(*names, remapping=remapping, patches=patches, progress_bar=False)
 
         print()
         print("Number of unique values found for each coordinate:")
         for k, v in user_coords.items():
-            print(f"  {k:20}:", len(v))
+            print(f"  {k:20}:", len(v), shorten_list(v, max_length=10))
         print()
         user_shape = tuple(len(v) for k, v in user_coords.items())
         print("Shape of the hypercube :", user_shape)
@@ -305,13 +309,18 @@ class Result:
 
         remapping = build_remapping(remapping, patches)
         expected = set(itertools.product(*user_coords.values()))
+        extra = set()
 
         if math.prod(user_shape) > len(ds):
             print(f"This means that all the fields in the datasets do not exists for all combinations of {names}.")
 
             for f in ds:
                 metadata = remapping(f.metadata)
-                expected.remove(tuple(metadata(n) for n in names))
+                key = tuple(metadata(n, default=None) for n in names)
+                if key in expected:
+                    expected.remove(key)
+                else:
+                    extra.add(key)
 
             print("Missing fields:")
             print()
@@ -321,7 +330,35 @@ class Result:
                     print("...", len(expected) - i - 1, "more")
                     break
 
+            print("Extra fields:")
+            print()
+            for i, f in enumerate(sorted(extra)):
+                print(" ", f)
+                if i >= 9 and len(extra) > 10:
+                    print("...", len(extra) - i - 1, "more")
+                    break
+
             print()
+            print("Missing values:")
+            per_name = defaultdict(set)
+            for e in expected:
+                for n, v in zip(names, e):
+                    per_name[n].add(v)
+
+            for n, v in per_name.items():
+                print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
+            print()
+
+            print("Extra values:")
+            per_name = defaultdict(set)
+            for e in extra:
+                for n, v in zip(names, e):
+                    per_name[n].add(v)
+
+            for n, v in per_name.items():
+                print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
+            print()
+
             print("To solve this issue, you can:")
             print(
                 "  - Provide a better selection, like 'step: 0' or 'level: 1000' to "
--- a/anemoi/datasets/create/statistics/__init__.py
+++ b/anemoi/datasets/create/statistics/__init__.py
@@ -79,6 +79,37 @@ def to_datetimes(dates):
     return [to_datetime(d) for d in dates]
 
 
+def fix_variance(x, name, count, sums, squares):
+    assert count.shape == sums.shape == squares.shape
+    assert isinstance(x, float)
+
+    mean = sums / count
+    assert mean.shape == count.shape
+
+    if x >= 0:
+        return x
+
+    LOG.warning(f"Negative variance for {name=}, variance={x}")
+    magnitude = np.sqrt((squares / count + mean * mean) / 2)
+    LOG.warning(f"square / count - mean * mean = {squares/count} - {mean*mean} = {squares/count - mean*mean}")
+    LOG.warning(f"Variable span order of magnitude is {magnitude}.")
+    LOG.warning(f"Count is {count}.")
+
+    variances = squares / count - mean * mean
+    assert variances.shape == squares.shape == mean.shape
+    if all(variances >= 0):
+        LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.")
+        return 0
+
+    # if abs(x) < magnitude * 1e-6 and abs(x) < range * 1e-6:
+    #     LOG.warning("Variance is negative but very small.")
+    #     variances = squares / count - mean * mean
+    #     return 0
+
+    LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).")
+    return x
+
+
 def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares):
     if (x >= 0).all():
         return
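
Note: the negative variances that fix_variance guards against come from catastrophic cancellation in the E[x²] − E[x]² formula when the true variance is tiny relative to the mean. A minimal reproduction (values illustrative):

    import numpy as np

    # Constant field with a large mean: the true variance is exactly 0,
    # but the sum-of-squares formula accumulates rounding error.
    x = np.full(1_000_000, 101_325.0, dtype=np.float32)  # surface pressure, Pa
    count = x.size
    sums = x.sum()
    squares = (x * x).sum()

    mean = sums / count
    variance = squares / count - mean * mean
    print(variance)  # small, possibly negative, instead of exactly 0
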
@@ -292,39 +323,24 @@ class StatAggregator:
     def aggregate(self):
         minimum = np.nanmin(self.minimum, axis=0)
         maximum = np.nanmax(self.maximum, axis=0)
+
         sums = np.nansum(self.sums, axis=0)
         squares = np.nansum(self.squares, axis=0)
         count = np.nansum(self.count, axis=0)
         has_nans = np.any(self.has_nans, axis=0)
-        mean = sums / count
+        assert sums.shape == count.shape == squares.shape == minimum.shape == maximum.shape
 
-        assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape
+        mean = sums / count
+        assert mean.shape == minimum.shape
 
         x = squares / count - mean * mean
-
-        # def fix_variance(x, name, minimum, maximum, mean, count, sums, squares):
-        #     assert x.shape == minimum.shape == maximum.shape == mean.shape == count.shape == sums.shape == squares.shape
-        #     assert x.shape == (1,)
-        #     x, minimum, maximum, mean, count, sums, squares = x[0], minimum[0], maximum[0], mean[0], count[0], sums[0], squares[0]
-        #     if x >= 0:
-        #         return x
-        #
-        #     order = np.sqrt((squares / count + mean * mean)/2)
-        #     range = maximum - minimum
-        #     LOG.warning(f"Negative variance for {name=}, variance={x}")
-        #     LOG.warning(f"square / count - mean * mean = {squares / count} - {mean * mean} = {squares / count - mean * mean}")
-        #     LOG.warning(f"Variable order of magnitude is {order}.")
-        #     LOG.warning(f"Range is {range} ({maximum=} - {minimum=}).")
-        #     LOG.warning(f"Count is {count}.")
-        #     if abs(x) < order * 1e-6 and abs(x) < range * 1e-6:
-        #         LOG.warning(f"Variance is negative but very small, setting to 0.")
-        #         return x*0
-        #     return x
+        assert x.shape == minimum.shape
 
         for i, name in enumerate(self.variables_names):
             # remove negative variance due to numerical errors
-            # Not needed for now, fix_variance is disabled
-            # x[i] = fix_variance(x[i:i+1], name, minimum[i:i+1], maximum[i:i+1], mean[i:i+1], count[i:i+1], sums[i:i+1], squares[i:i+1])
+            x[i] = fix_variance(x[i], name, self.count[i : i + 1], self.sums[i : i + 1], self.squares[i : i + 1])
+
+        for i, name in enumerate(self.variables_names):
             check_variance(
                 x[i : i + 1],
                 [name],
--- a/anemoi/datasets/create/utils.py
+++ b/anemoi/datasets/create/utils.py
@@ -7,6 +7,7 @@
 # nor does it submit to any jurisdiction.
 #
 
+import datetime
 import os
 from contextlib import contextmanager
@@ -61,10 +62,10 @@ def make_list_int(value):
 
 
 def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"):
-    assert isinstance(frequency, int), frequency
+    assert isinstance(frequency, datetime.timedelta), frequency
     start = np.datetime64(start)
     end = np.datetime64(end)
-    delta = np.timedelta64(frequency, "h")
+    delta = np.timedelta64(frequency)
 
     res = []
     while start <= end:
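
Note: the frequency is now a datetime.timedelta rather than an integer number of hours, which np.timedelta64 accepts directly (and no longer restricts to whole hours). For example:

    import datetime
    import numpy as np

    frequency = datetime.timedelta(hours=6)
    delta = np.timedelta64(frequency)  # numpy.timedelta64(21600000000,'us')

    start = np.datetime64("2020-01-01T00:00:00")
    print(start + delta)  # 2020-01-01T06:00:00.000000
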
--- a/anemoi/datasets/data/__init__.py
+++ b/anemoi/datasets/data/__init__.py
@@ -27,6 +27,7 @@ class MissingDateError(Exception):
 
 def open_dataset(*args, **kwargs):
     ds = _open_dataset(*args, **kwargs)
+    ds = ds.mutate()
     ds.arguments = {"args": args, "kwargs": kwargs}
     ds._check()
     return ds
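
Note: mutate() is a hook that lets a freshly built dataset replace itself with another object before being returned to the caller; the base implementation (see the dataset.py hunk below) is the identity. A sketch of the pattern, with a made-up Placeholder class:

    class Dataset:
        def mutate(self):
            # Default: nothing to swap, return self
            return self

    class Placeholder(Dataset):
        def __init__(self, resolved):
            self.resolved = resolved

        def mutate(self):
            # A lazily built wrapper can swap itself for the real dataset
            return self.resolved

    def finish(ds):
        # Mirrors what open_dataset now does before returning
        return ds.mutate()

    real = Dataset()
    assert finish(Placeholder(real)) is real
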
--- a/anemoi/datasets/data/concat.py
+++ b/anemoi/datasets/data/concat.py
@@ -9,6 +9,7 @@ import logging
 from functools import cached_property
 
 import numpy as np
+from anemoi.utils.dates import frequency_to_timedelta
 
 from .debug import Node
 from .debug import debug_indexing
@@ -102,20 +103,63 @@ class Concat(ConcatMixin, Combined):
     def tree(self):
         return Node(self, [d.tree() for d in self.datasets])
 
+    @classmethod
+    def check_dataset_compatibility(cls, datasets, fill_missing_gaps=False):
+        # Study the dates
+        ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets]
+
+        # Make sure the dates are disjoint
+        for i in range(len(ranges)):
+            r = ranges[i]
+            for j in range(i + 1, len(ranges)):
+                s = ranges[j]
+                if r[0] <= s[0] <= r[1] or r[0] <= s[1] <= r[1]:
+                    raise ValueError(f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})")
+
+        # For now we should have the datasets in order with no gaps
+
+        frequency = frequency_to_timedelta(datasets[0].frequency)
+        result = []
+
+        for i in range(len(ranges) - 1):
+            result.append(datasets[i])
+            r = ranges[i]
+            s = ranges[i + 1]
+            if r[1] + frequency != s[0]:
+                if fill_missing_gaps:
+                    from .missing import MissingDataset
+
+                    result.append(MissingDataset(datasets[i], r[1] + frequency, s[0] - frequency))
+                else:
+                    r = [str(e) for e in r]
+                    s = [str(e) for e in s]
+                    raise ValueError(
+                        "Datasets must be sorted by dates, with no gaps: "
+                        f"{r} and {s} ({datasets[i]} {datasets[i+1]})"
+                    )
+
+        result.append(datasets[-1])
+        assert len(result) >= len(datasets), (len(result), len(datasets))
+
+        return result
+
 
-def concat_factory(args, kwargs, zarr_root):
+def concat_factory(args, kwargs):
 
     datasets = kwargs.pop("concat")
+    fill_missing_gaps = kwargs.pop("fill_missing_gaps", False)
     assert isinstance(datasets, (list, tuple))
     assert len(args) == 0
 
     assert isinstance(datasets, (list, tuple))
 
-    datasets = [_open(e, zarr_root) for e in datasets]
+    datasets = [_open(e) for e in datasets]
 
     if len(datasets) == 1:
         return datasets[0]._subset(**kwargs)
 
     datasets, kwargs = _auto_adjust(datasets, kwargs)
 
+    datasets = Concat.check_dataset_compatibility(datasets, fill_missing_gaps)
+
     return Concat(datasets)._subset(**kwargs)
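
Note: with the new validation, concatenated datasets must be date-sorted, non-overlapping, and gap-free unless fill_missing_gaps is set, in which case each gap is padded with a MissingDataset. A hypothetical call, with made-up paths:

    from anemoi.datasets import open_dataset

    # Raises ValueError if dataset-2021 does not start exactly one
    # frequency step after dataset-2020 ends, unless gaps are filled.
    ds = open_dataset(
        concat=["dataset-2020.zarr", "dataset-2021.zarr"],
        fill_missing_gaps=True,  # pad any gap with a MissingDataset
    )
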
--- a/anemoi/datasets/data/dataset.py
+++ b/anemoi/datasets/data/dataset.py
@@ -5,24 +5,37 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+import datetime
+import json
 import logging
 import os
+import pprint
 import warnings
 from functools import cached_property
 
+from anemoi.utils.dates import frequency_to_seconds
+from anemoi.utils.dates import frequency_to_string
+from anemoi.utils.dates import frequency_to_timedelta
+
 LOG = logging.getLogger(__name__)
 
 
 class Dataset:
     arguments = {}
 
+    def mutate(self):
+        return self
+
+    def swap_with_parent(self, parent):
+        return parent
+
     @cached_property
     def _len(self):
         return len(self)
 
     def _subset(self, **kwargs):
         if not kwargs:
-            return self
+            return self.mutate()
 
         if "start" in kwargs or "end" in kwargs:
             start = kwargs.pop("start", None)
@@ -30,37 +43,52 @@ class Dataset:
 
             from .subset import Subset
 
-            return Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs)
+            return (
+                Subset(self, self._dates_to_indices(start, end), dict(start=start, end=end))._subset(**kwargs).mutate()
+            )
 
         if "frequency" in kwargs:
             from .subset import Subset
 
+            if "interpolate_frequency" in kwargs:
+                raise ValueError("Cannot use both `frequency` and `interpolate_frequency`")
+
             frequency = kwargs.pop("frequency")
-            return Subset(self, self._frequency_to_indices(frequency), dict(frequency=frequency))._subset(**kwargs)
+            return (
+                Subset(self, self._frequency_to_indices(frequency), dict(frequency=frequency))
+                ._subset(**kwargs)
+                .mutate()
+            )
+
+        if "interpolate_frequency" in kwargs:
+            from .interpolate import InterpolateFrequency
+
+            interpolate_frequency = kwargs.pop("interpolate_frequency")
+            return InterpolateFrequency(self, interpolate_frequency)._subset(**kwargs).mutate()
 
         if "select" in kwargs:
             from .select import Select
 
             select = kwargs.pop("select")
-            return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs)
+            return Select(self, self._select_to_columns(select), {"select": select})._subset(**kwargs).mutate()
 
         if "drop" in kwargs:
            from .select import Select
 
             drop = kwargs.pop("drop")
-            return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs)
+            return Select(self, self._drop_to_columns(drop), {"drop": drop})._subset(**kwargs).mutate()
 
         if "reorder" in kwargs:
             from .select import Select
 
             reorder = kwargs.pop("reorder")
-            return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs)
+            return Select(self, self._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate()
 
         if "rename" in kwargs:
             from .select import Rename
 
             rename = kwargs.pop("rename")
-            return Rename(self, rename)._subset(**kwargs)
+            return Rename(self, rename)._subset(**kwargs).mutate()
 
         if "statistics" in kwargs:
             from ..data import open_dataset
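
Note: interpolate_frequency is a new, mutually exclusive alternative to frequency: instead of subsampling to a coarser time step it interpolates to a finer one (see interpolate.py in the file list). Hypothetical usage, with a made-up path:

    from anemoi.datasets import open_dataset

    ds_12h = open_dataset("dataset-6h.zarr", frequency="12h")             # subsample
    ds_1h = open_dataset("dataset-6h.zarr", interpolate_frequency="1h")   # interpolate

    # open_dataset("dataset-6h.zarr", frequency="12h", interpolate_frequency="1h")
    # raises: "Cannot use both `frequency` and `interpolate_frequency`"
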
@@ -68,20 +96,38 @@ class Dataset:
 
             statistics = kwargs.pop("statistics")
 
-            return Statistics(self, open_dataset(statistics))._subset(**kwargs)
+            return Statistics(self, open_dataset(statistics))._subset(**kwargs).mutate()
 
         if "thinning" in kwargs:
             from .masked import Thinning
 
             thinning = kwargs.pop("thinning")
             method = kwargs.pop("method", "every-nth")
-            return Thinning(self, thinning, method)._subset(**kwargs)
+            return Thinning(self, thinning, method)._subset(**kwargs).mutate()
 
         if "area" in kwargs:
             from .masked import Cropping
 
             bbox = kwargs.pop("area")
-            return Cropping(self, bbox)._subset(**kwargs)
+            return Cropping(self, bbox)._subset(**kwargs).mutate()
+
+        if "missing_dates" in kwargs:
+            from .missing import MissingDates
+
+            missing_dates = kwargs.pop("missing_dates")
+            return MissingDates(self, missing_dates)._subset(**kwargs).mutate()
+
+        if "skip_missing_dates" in kwargs:
+            from .missing import SkipMissingDates
+
+            if "expected_access" not in kwargs:
+                raise ValueError("`expected_access` is required with `skip_missing_dates`")
+
+            skip_missing_dates = kwargs.pop("skip_missing_dates")
+            expected_access = kwargs.pop("expected_access")
+
+            if skip_missing_dates:
+                return SkipMissingDates(self, expected_access)._subset(**kwargs).mutate()
 
         # Keep last
         if "shuffle" in kwargs:
@@ -90,15 +136,14 @@ class Dataset:
             shuffle = kwargs.pop("shuffle")
 
             if shuffle:
-                return Subset(self, self._shuffle_indices(), dict(shuffle=True))._subset(**kwargs)
+                return Subset(self, self._shuffle_indices(), dict(shuffle=True))._subset(**kwargs).mutate()
 
         raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs))
 
     def _frequency_to_indices(self, frequency):
-        from .misc import _frequency_to_hours
 
-        requested_frequency = _frequency_to_hours(frequency)
-        dataset_frequency = _frequency_to_hours(self.frequency)
+        requested_frequency = frequency_to_seconds(frequency)
+        dataset_frequency = frequency_to_seconds(self.frequency)
         assert requested_frequency % dataset_frequency == 0
         # Question: where do we start? first date, or first date that is a multiple of the frequency?
         step = requested_frequency // dataset_frequency
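
Note: switching from _frequency_to_hours to frequency_to_seconds keeps the same subsampling arithmetic while allowing sub-hourly frequencies; the step is just the ratio of the two frequencies. A sketch, assuming the anemoi.utils.dates helpers accept frequency strings such as "6h" and "30m":

    from anemoi.utils.dates import frequency_to_seconds

    requested = frequency_to_seconds("6h")   # 21600
    dataset = frequency_to_seconds("30m")    # 1800
    assert requested % dataset == 0
    step = requested // dataset              # 12: keep every 12th date
    indices = range(0, 100, step)
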
@@ -171,37 +216,71 @@ class Dataset:
         import anemoi
 
         def tidy(v):
-            if isinstance(v, (list, tuple)):
+            if isinstance(v, (list, tuple, set)):
                 return [tidy(i) for i in v]
             if isinstance(v, dict):
                 return {k: tidy(v) for k, v in v.items()}
             if isinstance(v, str) and v.startswith("/"):
                 return os.path.basename(v)
+            if isinstance(v, datetime.datetime):
+                return v.isoformat()
+            if isinstance(v, datetime.date):
+                return v.isoformat()
+            if isinstance(v, datetime.timedelta):
+                return frequency_to_string(v)
+
+            if isinstance(v, Dataset):
+                # That can happen in the `arguments`
+                # if a dataset is passed as an argument
+                return repr(v)
+
+            if isinstance(v, slice):
+                return (v.start, v.stop, v.step)
+
             return v
 
-        return tidy(
-            dict(
-                version=anemoi.datasets.__version__,
-                shape=self.shape,
-                arguments=self.arguments,
-                specific=self.metadata_specific(),
-                frequency=self.frequency,
-                variables=self.variables,
-                start_date=self.dates[0].astype(str),
-                end_date=self.dates[-1].astype(str),
-            )
+        md = dict(
+            version=anemoi.datasets.__version__,
+            arguments=self.arguments,
+            **self.dataset_metadata(),
+        )
+
+        try:
+            return json.loads(json.dumps(tidy(md)))
+        except Exception:
+            LOG.exception("Failed to serialize metadata")
+            pprint.pprint(md)
+
+            raise
+
+    @property
+    def start_date(self):
+        return self.dates[0]
+
+    @property
+    def end_date(self):
+        return self.dates[-1]
+
+    def dataset_metadata(self):
+        return dict(
+            specific=self.metadata_specific(),
+            frequency=self.frequency,
+            variables=self.variables,
+            shape=self.shape,
+            start_date=self.start_date.astype(str),
+            end_date=self.end_date.astype(str),
         )
 
     def metadata_specific(self, **kwargs):
         action = self.__class__.__name__.lower()
-        assert isinstance(self.frequency, int), (self.frequency, self, action)
+        # assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action)
         return dict(
             action=action,
             variables=self.variables,
             shape=self.shape,
-            frequency=self.frequency,
-            start_date=self.dates[0].astype(str),
-            end_date=self.dates[-1].astype(str),
+            frequency=frequency_to_string(frequency_to_timedelta(self.frequency)),
+            start_date=self.start_date.astype(str),
+            end_date=self.end_date.astype(str),
             **kwargs,
         )
@@ -220,10 +299,6 @@ class Dataset:
             if n.startswith("_") and not n.startswith("__"):
                 warnings.warn(f"Private method {n} is overriden in {ds.__class__.__name__}")
 
-        # for n in ('metadata_specific', 'tree', 'source'):
-        #     if n not in overriden:
-        #         warnings.warn(f"Method {n} is not overriden in {ds.__class__.__name__}")
-
     def _repr_html_(self):
         return self.tree().html()