anemoi-datasets 0.4.4__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/cleanup.py +44 -0
  3. anemoi/datasets/commands/create.py +50 -20
  4. anemoi/datasets/commands/finalise-additions.py +45 -0
  5. anemoi/datasets/commands/finalise.py +39 -0
  6. anemoi/datasets/commands/init-additions.py +45 -0
  7. anemoi/datasets/commands/init.py +67 -0
  8. anemoi/datasets/commands/inspect.py +1 -1
  9. anemoi/datasets/commands/load-additions.py +47 -0
  10. anemoi/datasets/commands/load.py +47 -0
  11. anemoi/datasets/commands/patch.py +39 -0
  12. anemoi/datasets/create/__init__.py +961 -146
  13. anemoi/datasets/create/check.py +5 -3
  14. anemoi/datasets/create/config.py +53 -2
  15. anemoi/datasets/create/functions/sources/xarray/__init__.py +12 -2
  16. anemoi/datasets/create/functions/sources/xarray/coordinates.py +7 -0
  17. anemoi/datasets/create/functions/sources/xarray/field.py +1 -1
  18. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +0 -2
  19. anemoi/datasets/create/functions/sources/xarray/flavour.py +21 -1
  20. anemoi/datasets/create/functions/sources/xarray/metadata.py +27 -29
  21. anemoi/datasets/create/functions/sources/xarray/time.py +63 -30
  22. anemoi/datasets/create/functions/sources/xarray/variable.py +15 -38
  23. anemoi/datasets/create/input.py +23 -22
  24. anemoi/datasets/create/statistics/__init__.py +39 -23
  25. anemoi/datasets/create/utils.py +3 -2
  26. anemoi/datasets/data/__init__.py +1 -0
  27. anemoi/datasets/data/concat.py +46 -2
  28. anemoi/datasets/data/dataset.py +109 -34
  29. anemoi/datasets/data/forwards.py +17 -8
  30. anemoi/datasets/data/grids.py +17 -3
  31. anemoi/datasets/data/interpolate.py +133 -0
  32. anemoi/datasets/data/misc.py +56 -66
  33. anemoi/datasets/data/missing.py +240 -0
  34. anemoi/datasets/data/select.py +7 -1
  35. anemoi/datasets/data/stores.py +3 -3
  36. anemoi/datasets/data/subset.py +47 -5
  37. anemoi/datasets/data/unchecked.py +20 -22
  38. anemoi/datasets/data/xy.py +125 -0
  39. anemoi/datasets/dates/__init__.py +13 -66
  40. anemoi/datasets/dates/groups.py +2 -2
  41. anemoi/datasets/grids.py +66 -48
  42. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/METADATA +5 -5
  43. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/RECORD +47 -37
  44. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/WHEEL +1 -1
  45. anemoi/datasets/create/loaders.py +0 -936
  46. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/LICENSE +0 -0
  47. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/entry_points.txt +0 -0
  48. {anemoi_datasets-0.4.4.dist-info → anemoi_datasets-0.4.5.dist-info}/top_level.txt +0 -0
@@ -8,7 +8,6 @@
8
8
  import calendar
9
9
  import datetime
10
10
  import logging
11
- import re
12
11
  from pathlib import PurePath
13
12
 
14
13
  import numpy as np
@@ -39,26 +38,21 @@ def add_dataset_path(path):
39
38
  config["datasets"]["path"].append(path)
40
39
 
41
40
 
42
- def _frequency_to_hours(frequency):
43
- if isinstance(frequency, int):
44
- return frequency
45
-
46
- if isinstance(frequency, float):
47
- assert int(frequency) == frequency
48
- return int(frequency)
49
-
50
- m = re.match(r"(\d+)([dh])?", frequency)
51
- if m is None:
52
- raise ValueError("Invalid frequency: " + frequency)
53
-
54
- frequency = int(m.group(1))
55
- if m.group(2) == "h":
56
- return frequency
57
-
58
- if m.group(2) == "d":
59
- return frequency * 24
41
+ def round_datetime(d, dates, up):
42
+ """Round up (or down) a datetime to the nearest date in a list of dates"""
43
+ if dates is None or len(dates) == 0:
44
+ return d
60
45
 
61
- raise NotImplementedError()
46
+ for i, date in enumerate(dates):
47
+ if date == d:
48
+ return date
49
+ if date > d:
50
+ if up:
51
+ return date
52
+ if i > 0:
53
+ return dates[i - 1]
54
+ return date
55
+ return dates[-1]
62
56
 
63
57
 
64
58
  def _as_date(d, dates, last):
@@ -67,7 +61,8 @@ def _as_date(d, dates, last):
67
61
  # so we need to check for datetime.datetime first
68
62
 
69
63
  if isinstance(d, (np.datetime64, datetime.datetime)):
70
- return d
64
+ d = round_datetime(d, dates, up=not last)
65
+ return np.datetime64(d)
71
66
 
72
67
  if isinstance(d, datetime.date):
73
68
  d = d.year * 10_000 + d.month * 100 + d.day
@@ -81,27 +76,27 @@ def _as_date(d, dates, last):
81
76
  if len(str(d)) == 4:
82
77
  year = d
83
78
  if last:
84
- return np.datetime64(f"{year:04}-12-31T23:59:59")
79
+ return _as_date(np.datetime64(f"{year:04}-12-31T23:59:59"), dates, last)
85
80
  else:
86
- return np.datetime64(f"{year:04}-01-01T00:00:00")
81
+ return _as_date(np.datetime64(f"{year:04}-01-01T00:00:00"), dates, last)
87
82
 
88
83
  if len(str(d)) == 6:
89
84
  year = d // 100
90
85
  month = d % 100
91
86
  if last:
92
87
  _, last_day = calendar.monthrange(year, month)
93
- return np.datetime64(f"{year:04}-{month:02}-{last_day:02}T23:59:59")
88
+ return _as_date(np.datetime64(f"{year:04}-{month:02}-{last_day:02}T23:59:59"), dates, last)
94
89
  else:
95
- return np.datetime64(f"{year:04}-{month:02}-01T00:00:00")
90
+ return _as_date(np.datetime64(f"{year:04}-{month:02}-01T00:00:00"), dates, last)
96
91
 
97
92
  if len(str(d)) == 8:
98
93
  year = d // 10000
99
94
  month = (d % 10000) // 100
100
95
  day = d % 100
101
96
  if last:
102
- return np.datetime64(f"{year:04}-{month:02}-{day:02}T23:59:59")
97
+ return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T23:59:59"), dates, last)
103
98
  else:
104
- return np.datetime64(f"{year:04}-{month:02}-{day:02}T00:00:00")
99
+ return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T00:00:00"), dates, last)
105
100
 
106
101
  if isinstance(d, str):
107
102
 
@@ -109,7 +104,11 @@ def _as_date(d, dates, last):
109
104
  date, time = d.replace(" ", "T").split("T")
110
105
  year, month, day = [int(_) for _ in date.split("-")]
111
106
  hour, minute, second = [int(_) for _ in time.split(":")]
112
- return np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}")
107
+ return _as_date(
108
+ np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}"),
109
+ dates,
110
+ last,
111
+ )
113
112
 
114
113
  if "-" in d:
115
114
  assert ":" not in d
@@ -121,11 +120,8 @@ def _as_date(d, dates, last):
121
120
  return _as_date(int(bits[0]) * 100 + int(bits[1]), dates, last)
122
121
 
123
122
  if len(bits) == 3:
124
- return _as_date(
125
- int(bits[0]) * 10000 + int(bits[1]) * 100 + int(bits[2]),
126
- dates,
127
- last,
128
- )
123
+ return _as_date(int(bits[0]) * 10000 + int(bits[1]) * 100 + int(bits[2]), dates, last)
124
+
129
125
  if ":" in d:
130
126
  assert len(d) == 5
131
127
  hour, minute = d.split(":")
@@ -136,7 +132,7 @@ def _as_date(d, dates, last):
136
132
  month = first.month
137
133
  day = first.day
138
134
 
139
- return np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour}:00:00")
135
+ return _as_date(np.datetime64(f"{year:04}-{month:02}-{day:02}T{hour}:00:00"), dates, last)
140
136
 
141
137
  raise NotImplementedError(f"Unsupported date: {d} ({type(d)})")
142
138
 
@@ -163,28 +159,10 @@ def _concat_or_join(datasets, kwargs):
163
159
 
164
160
  return Join(datasets)._overlay(), kwargs
165
161
 
166
- # Make sure the dates are disjoint
167
- for i in range(len(ranges)):
168
- r = ranges[i]
169
- for j in range(i + 1, len(ranges)):
170
- s = ranges[j]
171
- if r[0] <= s[0] <= r[1] or r[0] <= s[1] <= r[1]:
172
- raise ValueError(f"Overlapping dates: {r} and {s} ({datasets[i]} {datasets[j]})")
173
-
174
- # For now we should have the datasets in order with no gaps
175
-
176
- frequency = _frequency_to_hours(datasets[0].frequency)
177
-
178
- for i in range(len(ranges) - 1):
179
- r = ranges[i]
180
- s = ranges[i + 1]
181
- if r[1] + datetime.timedelta(hours=frequency) != s[0]:
182
- raise ValueError(
183
- "Datasets must be sorted by dates, with no gaps: " f"{r} and {s} ({datasets[i]} {datasets[i+1]})"
184
- )
185
-
186
162
  from .concat import Concat
187
163
 
164
+ Concat.check_dataset_compatibility(datasets)
165
+
188
166
  return Concat(datasets), kwargs
189
167
 
190
168
 
@@ -193,7 +171,7 @@ def _open(a):
193
171
  from .stores import zarr_lookup
194
172
 
195
173
  if isinstance(a, Dataset):
196
- return a
174
+ return a.mutate()
197
175
 
198
176
  if isinstance(a, zarr.hierarchy.Group):
199
177
  return Zarr(a).mutate()
@@ -202,13 +180,13 @@ def _open(a):
202
180
  return Zarr(zarr_lookup(a)).mutate()
203
181
 
204
182
  if isinstance(a, PurePath):
205
- return _open(str(a))
183
+ return _open(str(a)).mutate()
206
184
 
207
185
  if isinstance(a, dict):
208
- return _open_dataset(**a)
186
+ return _open_dataset(**a).mutate()
209
187
 
210
188
  if isinstance(a, (list, tuple)):
211
- return _open_dataset(*a)
189
+ return _open_dataset(*a).mutate()
212
190
 
213
191
  raise NotImplementedError(f"Unsupported argument: {type(a)}")
214
192
 
@@ -288,47 +266,59 @@ def _open_dataset(*args, **kwargs):
288
266
  for a in args:
289
267
  sets.append(_open(a))
290
268
 
269
+ if "xy" in kwargs:
270
+ from .xy import xy_factory
271
+
272
+ assert not sets, sets
273
+ return xy_factory(args, kwargs).mutate()
274
+
275
+ if "x" in kwargs and "y" in kwargs:
276
+ from .xy import xy_factory
277
+
278
+ assert not sets, sets
279
+ return xy_factory(args, kwargs).mutate()
280
+
291
281
  if "zip" in kwargs:
292
- from .unchecked import zip_factory
282
+ from .xy import zip_factory
293
283
 
294
284
  assert not sets, sets
295
- return zip_factory(args, kwargs)
285
+ return zip_factory(args, kwargs).mutate()
296
286
 
297
287
  if "chain" in kwargs:
298
288
  from .unchecked import chain_factory
299
289
 
300
290
  assert not sets, sets
301
- return chain_factory(args, kwargs)
291
+ return chain_factory(args, kwargs).mutate()
302
292
 
303
293
  if "join" in kwargs:
304
294
  from .join import join_factory
305
295
 
306
296
  assert not sets, sets
307
- return join_factory(args, kwargs)
297
+ return join_factory(args, kwargs).mutate()
308
298
 
309
299
  if "concat" in kwargs:
310
300
  from .concat import concat_factory
311
301
 
312
302
  assert not sets, sets
313
- return concat_factory(args, kwargs)
303
+ return concat_factory(args, kwargs).mutate()
314
304
 
315
305
  if "ensemble" in kwargs:
316
306
  from .ensemble import ensemble_factory
317
307
 
318
308
  assert not sets, sets
319
- return ensemble_factory(args, kwargs)
309
+ return ensemble_factory(args, kwargs).mutate()
320
310
 
321
311
  if "grids" in kwargs:
322
312
  from .grids import grids_factory
323
313
 
324
314
  assert not sets, sets
325
- return grids_factory(args, kwargs)
315
+ return grids_factory(args, kwargs).mutate()
326
316
 
327
317
  if "cutout" in kwargs:
328
318
  from .grids import cutout_factory
329
319
 
330
320
  assert not sets, sets
331
- return cutout_factory(args, kwargs)
321
+ return cutout_factory(args, kwargs).mutate()
332
322
 
333
323
  for name in ("datasets", "dataset"):
334
324
  if name in kwargs:
@@ -0,0 +1,240 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+ import logging
9
+ from functools import cached_property
10
+
11
+ import numpy as np
12
+
13
+ from anemoi.datasets.create.utils import to_datetime
14
+ from anemoi.datasets.data import MissingDateError
15
+
16
+ from .debug import Node
17
+ from .debug import debug_indexing
18
+ from .forwards import Forwards
19
+ from .indexing import expand_list_indexing
20
+ from .indexing import update_tuple
21
+
22
+ LOG = logging.getLogger(__name__)
23
+
24
+
25
+ class MissingDates(Forwards):
26
+ # TODO: Use that class instead of ZarrMissing
27
+
28
+ def __init__(self, dataset, missing_dates):
29
+ super().__init__(dataset)
30
+ self.missing_dates = []
31
+
32
+ self._missing = set()
33
+
34
+ other = []
35
+ for date in missing_dates:
36
+ if isinstance(date, int):
37
+ self._missing.add(date)
38
+ self.missing_dates.append(dataset.dates[date])
39
+ else:
40
+ date = to_datetime(date)
41
+ other.append(date)
42
+
43
+ if other:
44
+ for i, date in enumerate(dataset.dates):
45
+ if date in other:
46
+ self._missing.add(i)
47
+ self.missing_dates.append(date)
48
+
49
+ n = self.forward._len
50
+ self._missing = set(i for i in self._missing if 0 <= i < n)
51
+ self.missing_dates = sorted(to_datetime(x) for x in self.missing_dates)
52
+
53
+ assert len(self._missing), "No dates to force missing"
54
+
55
+ @cached_property
56
+ def missing(self):
57
+ return self._missing.union(self.forward.missing)
58
+
59
+ @debug_indexing
60
+ @expand_list_indexing
61
+ def __getitem__(self, n):
62
+ if isinstance(n, int):
63
+ if n in self.missing:
64
+ self._report_missing(n)
65
+ return self.forward[n]
66
+
67
+ if isinstance(n, slice):
68
+ common = set(range(*n.indices(len(self)))) & self.missing
69
+ if common:
70
+ self._report_missing(list(common)[0])
71
+ return self.forward[n]
72
+
73
+ if isinstance(n, tuple):
74
+ first = n[0]
75
+ if isinstance(first, int):
76
+ if first in self.missing:
77
+ self._report_missing(first)
78
+ return self.forward[n]
79
+
80
+ if isinstance(first, slice):
81
+ common = set(range(*first.indices(len(self)))) & self.missing
82
+ if common:
83
+ self._report_missing(list(common)[0])
84
+ return self.forward[n]
85
+
86
+ if isinstance(first, (list, tuple)):
87
+ common = set(first) & self.missing
88
+ if common:
89
+ self._report_missing(list(common)[0])
90
+ return self.forward[n]
91
+
92
+ raise TypeError(f"Unsupported index {n} {type(n)}")
93
+
94
+ def _report_missing(self, n):
95
+ raise MissingDateError(f"Date {self.forward.dates[n]} is missing (index={n})")
96
+
97
+ @property
98
+ def reason(self):
99
+ return {"missing_dates": self.missing_dates}
100
+
101
+ def tree(self):
102
+ return Node(self, [self.forward.tree()], **self.reason)
103
+
104
+ def subclass_metadata_specific(self):
105
+ return {"missing_dates": self.missing_dates}
106
+
107
+
108
+ class SkipMissingDates(Forwards):
109
+
110
+ def __init__(self, dataset, expected_access):
111
+ super().__init__(dataset)
112
+
113
+ # if isinstance(expected_access, (tuple, list)):
114
+ # expected_access = slice(*expected_access)
115
+
116
+ if isinstance(expected_access, int):
117
+ expected_access = slice(0, expected_access)
118
+
119
+ assert isinstance(expected_access, slice), f"Expected access must be a slice, got {expected_access}"
120
+
121
+ expected_access = slice(*expected_access.indices(dataset._len))
122
+ missing = dataset.missing.copy()
123
+
124
+ size = (expected_access.stop - expected_access.start) // expected_access.step
125
+ indices = []
126
+
127
+ for i in range(dataset._len):
128
+ s = slice(expected_access.start + i, expected_access.stop + i, expected_access.step)
129
+ p = set(range(*s.indices(dataset._len)))
130
+ if p.intersection(missing):
131
+ continue
132
+
133
+ if len(p) != size:
134
+ continue
135
+
136
+ indices.append(tuple(sorted(p)))
137
+
138
+ self.expected_access = expected_access
139
+ self.indices = indices
140
+
141
+ def __len__(self):
142
+ return len(self.indices)
143
+
144
+ @property
145
+ def start_date(self):
146
+ return self.forward.start_date
147
+
148
+ @property
149
+ def end_date(self):
150
+ return self.forward.end_date
151
+
152
+ @property
153
+ def dates(self):
154
+ raise NotImplementedError("SkipMissingDates.dates")
155
+
156
+ @debug_indexing
157
+ @expand_list_indexing
158
+ def _get_tuple(self, index):
159
+
160
+ def _get_one(n):
161
+ result = []
162
+ for i in self.indices[n]:
163
+ s, _ = update_tuple(index, 0, i)
164
+ result.append(self.forward[s])
165
+
166
+ return tuple(result)
167
+
168
+ first = index[0]
169
+ if isinstance(first, int):
170
+ return _get_one(first)
171
+
172
+ assert isinstance(first, slice), f"SkipMissingDates._get_tuple {index}"
173
+
174
+ values = [_get_one(i) for i in range(*first.indices(self._len))]
175
+
176
+ result = [_ for _ in zip(*values)]
177
+ return tuple(np.stack(_) for _ in result)
178
+
179
+ @debug_indexing
180
+ def _get_slice(self, s):
181
+ values = [self[i] for i in range(*s.indices(self._len))]
182
+ result = [_ for _ in zip(*values)]
183
+ return tuple(np.stack(_) for _ in result)
184
+
185
+ @debug_indexing
186
+ def __getitem__(self, n):
187
+ if isinstance(n, tuple):
188
+ return self._get_tuple(n)
189
+
190
+ if isinstance(n, slice):
191
+ return self._get_slice(n)
192
+
193
+ return tuple(self.forward[i] for i in self.indices[n])
194
+
195
+ @property
196
+ def frequency(self):
197
+ return self.forward.frequency
198
+
199
+ def tree(self):
200
+ return Node(self, [self.forward.tree()], expected_access=self.expected_access)
201
+
202
+ def subclass_metadata_specific(self):
203
+ return {"expected_access": self.expected_access}
204
+
205
+
206
+ class MissingDataset(Forwards):
207
+
208
+ def __init__(self, dataset, start, end):
209
+ super().__init__(dataset)
210
+ self.start = start
211
+ self.end = end
212
+
213
+ dates = []
214
+ date = start
215
+ while date <= end:
216
+ dates.append(date)
217
+ date += dataset.frequency
218
+
219
+ self._dates = np.array(dates, dtype="datetime64")
220
+ self._missing = set(range(len(dates)))
221
+
222
+ def __len__(self):
223
+ return len(self._dates)
224
+
225
+ @property
226
+ def dates(self):
227
+ return self._dates
228
+
229
+ @property
230
+ def missing(self):
231
+ return self._missing
232
+
233
+ def __getitem__(self, n):
234
+ raise MissingDateError(f"Date {self.dates[n]} is missing (index={n})")
235
+
236
+ def tree(self):
237
+ return Node(self, [self.forward.tree()], start=self.start, end=self.end)
238
+
239
+ def subclass_metadata_specific(self):
240
+ return {"start": self.start, "end": self.end}
@@ -40,6 +40,12 @@ class Select(Forwards):
40
40
  # Forward other properties to the main dataset
41
41
  super().__init__(dataset)
42
42
 
43
+ def clone(self, dataset):
44
+ return self.__class__(dataset, self.indices, self.reason).mutate()
45
+
46
+ def mutate(self):
47
+ return self.forward.swap_with_parent(parent=self)
48
+
43
49
  @debug_indexing
44
50
  @expand_list_indexing
45
51
  def _get_tuple(self, index):
@@ -101,7 +107,7 @@ class Rename(Forwards):
101
107
  def __init__(self, dataset, rename):
102
108
  super().__init__(dataset)
103
109
  for n in rename:
104
- assert n in dataset.variables
110
+ assert n in dataset.variables, n
105
111
  self._variables = [rename.get(v, v) for v in dataset.variables]
106
112
  self.rename = rename
107
113
 
@@ -13,6 +13,7 @@ from urllib.parse import urlparse
13
13
 
14
14
  import numpy as np
15
15
  import zarr
16
+ from anemoi.utils.dates import frequency_to_timedelta
16
17
 
17
18
  from . import MissingDateError
18
19
  from .dataset import Dataset
@@ -268,12 +269,11 @@ class Zarr(Dataset):
268
269
  @property
269
270
  def frequency(self):
270
271
  try:
271
- return self.z.attrs["frequency"]
272
+ return frequency_to_timedelta(self.z.attrs["frequency"])
272
273
  except KeyError:
273
274
  LOG.warning("No 'frequency' in %r, computing from 'dates'", self)
274
275
  dates = self.dates
275
- delta = dates[1].astype(object) - dates[0].astype(object)
276
- return int(delta.total_seconds() / 3600)
276
+ return dates[1].astype(object) - dates[0].astype(object)
277
277
 
278
278
  @property
279
279
  def name_to_index(self):
@@ -9,6 +9,7 @@ import logging
9
9
  from functools import cached_property
10
10
 
11
11
  import numpy as np
12
+ from anemoi.utils.dates import frequency_to_timedelta
12
13
 
13
14
  from .debug import Node
14
15
  from .debug import Source
@@ -23,13 +24,51 @@ from .indexing import update_tuple
23
24
  LOG = logging.getLogger(__name__)
24
25
 
25
26
 
27
+ def _default(a, b, dates):
28
+ return [a, b]
29
+
30
+
31
+ def _start(a, b, dates):
32
+ from .misc import as_first_date
33
+
34
+ c = as_first_date(a, dates)
35
+ d = as_first_date(b, dates)
36
+ if c < d:
37
+ return b
38
+ else:
39
+ return a
40
+
41
+
42
+ def _end(a, b, dates):
43
+ from .misc import as_last_date
44
+
45
+ c = as_last_date(a, dates)
46
+ d = as_last_date(b, dates)
47
+ if c < d:
48
+ return a
49
+ else:
50
+ return b
51
+
52
+
53
+ def _combine_reasons(reason1, reason2, dates):
54
+
55
+ reason = reason1.copy()
56
+ for k, v in reason2.items():
57
+ if k not in reason:
58
+ reason[k] = v
59
+ else:
60
+ func = globals().get(f"_{k}", _default)
61
+ reason[k] = func(reason[k], v, dates)
62
+ return reason
63
+
64
+
26
65
  class Subset(Forwards):
27
66
  """Select a subset of the dates."""
28
67
 
29
68
  def __init__(self, dataset, indices, reason):
30
69
  while isinstance(dataset, Subset):
31
70
  indices = [dataset.indices[i] for i in indices]
32
- reason = {**reason, **dataset.reason}
71
+ reason = _combine_reasons(reason, dataset.reason, dataset.dates)
33
72
  dataset = dataset.dataset
34
73
 
35
74
  self.dataset = dataset
@@ -39,6 +78,12 @@ class Subset(Forwards):
39
78
  # Forward other properties to the super dataset
40
79
  super().__init__(dataset)
41
80
 
81
+ def clone(self, dataset):
82
+ return self.__class__(dataset, self.indices, self.reason).mutate()
83
+
84
+ def mutate(self):
85
+ return self.forward.swap_with_parent(parent=self)
86
+
42
87
  @debug_indexing
43
88
  def __getitem__(self, n):
44
89
  if isinstance(n, tuple):
@@ -66,10 +111,8 @@ class Subset(Forwards):
66
111
  @expand_list_indexing
67
112
  def _get_tuple(self, n):
68
113
  index, changes = index_to_slices(n, self.shape)
69
- # print('INDEX', index, changes)
70
114
  indices = [self.indices[i] for i in range(*index[0].indices(self._len))]
71
115
  indices = make_slice_or_index_from_list_or_tuple(indices)
72
- # print('INDICES', indices)
73
116
  index, _ = update_tuple(index, 0, indices)
74
117
  result = self.dataset[index]
75
118
  result = apply_index_to_slices_changes(result, changes)
@@ -89,8 +132,7 @@ class Subset(Forwards):
89
132
  @cached_property
90
133
  def frequency(self):
91
134
  dates = self.dates
92
- delta = dates[1].astype(object) - dates[0].astype(object)
93
- return int(delta.total_seconds() / 3600)
135
+ return frequency_to_timedelta(dates[1].astype(object) - dates[0].astype(object))
94
136
 
95
137
  def source(self, index):
96
138
  return Source(self, index, self.forward.source(index))
@@ -104,22 +104,29 @@ class Unchecked(Combined):
104
104
  def shape(self):
105
105
  raise NotImplementedError()
106
106
 
107
- @property
108
- def dtype(self):
109
- raise NotImplementedError()
107
+ # @property
108
+ # def field_shape(self):
109
+ # return tuple(d.shape for d in self.datasets)
110
110
 
111
- @property
112
- def grids(self):
113
- raise NotImplementedError()
111
+ # @property
112
+ # def latitudes(self):
113
+ # return tuple(d.latitudes for d in self.datasets)
114
114
 
115
+ # @property
116
+ # def longitudes(self):
117
+ # return tuple(d.longitudes for d in self.datasets)
115
118
 
116
- class Zip(Unchecked):
119
+ # @property
120
+ # def statistics(self):
121
+ # return tuple(d.statistics for d in self.datasets)
117
122
 
118
- def __len__(self):
119
- return min(len(d) for d in self.datasets)
123
+ # @property
124
+ # def resolution(self):
125
+ # return tuple(d.resolution for d in self.datasets)
120
126
 
121
- def __getitem__(self, n):
122
- return tuple(d[n] for d in self.datasets)
127
+ # @property
128
+ # def name_to_index(self):
129
+ # return tuple(d.name_to_index for d in self.datasets)
123
130
 
124
131
  @cached_property
125
132
  def missing(self):
@@ -142,17 +149,8 @@ class Chain(ConcatMixin, Unchecked):
142
149
  def dates(self):
143
150
  raise NotImplementedError()
144
151
 
145
-
146
- def zip_factory(args, kwargs):
147
-
148
- zip = kwargs.pop("zip")
149
- assert len(args) == 0
150
- assert isinstance(zip, (list, tuple))
151
-
152
- datasets = [_open(e) for e in zip]
153
- datasets, kwargs = _auto_adjust(datasets, kwargs)
154
-
155
- return Zip(datasets)._subset(**kwargs)
152
+ def dataset_metadata(self):
153
+ return {"multiple": [d.dataset_metadata() for d in self.datasets]}
156
154
 
157
155
 
158
156
  def chain_factory(args, kwargs):