anemoi-datasets 0.4.5__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- anemoi/datasets/_version.py +2 -2
- anemoi/datasets/commands/create.py +3 -2
- anemoi/datasets/create/__init__.py +30 -32
- anemoi/datasets/create/config.py +4 -3
- anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +57 -0
- anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +57 -0
- anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +54 -0
- anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +59 -0
- anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +115 -0
- anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +390 -0
- anemoi/datasets/create/functions/filters/speeddir_to_uv.py +77 -0
- anemoi/datasets/create/functions/filters/uv_to_speeddir.py +55 -0
- anemoi/datasets/create/functions/sources/grib.py +86 -1
- anemoi/datasets/create/functions/sources/hindcasts.py +14 -73
- anemoi/datasets/create/functions/sources/mars.py +9 -3
- anemoi/datasets/create/functions/sources/xarray/field.py +7 -1
- anemoi/datasets/create/functions/sources/xarray/metadata.py +13 -11
- anemoi/datasets/create/input.py +39 -17
- anemoi/datasets/create/persistent.py +1 -1
- anemoi/datasets/create/utils.py +3 -0
- anemoi/datasets/data/dataset.py +11 -1
- anemoi/datasets/data/debug.py +5 -1
- anemoi/datasets/data/masked.py +2 -2
- anemoi/datasets/data/rescale.py +147 -0
- anemoi/datasets/data/stores.py +20 -7
- anemoi/datasets/dates/__init__.py +112 -30
- anemoi/datasets/dates/groups.py +84 -19
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/METADATA +10 -19
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/RECORD +33 -24
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/WHEEL +1 -1
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/LICENSE +0 -0
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/entry_points.txt +0 -0
- {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.0.dist-info}/top_level.txt +0 -0
|
@@ -6,7 +6,6 @@
|
|
|
6
6
|
# granted to it by virtue of its status as an intergovernmental organisation
|
|
7
7
|
# nor does it submit to any jurisdiction.
|
|
8
8
|
#
|
|
9
|
-
import datetime
|
|
10
9
|
import logging
|
|
11
10
|
|
|
12
11
|
from earthkit.data.core.fieldlist import MultiFieldList
|
|
@@ -14,7 +13,6 @@ from earthkit.data.core.fieldlist import MultiFieldList
|
|
|
14
13
|
from anemoi.datasets.create.functions.sources.mars import mars
|
|
15
14
|
|
|
16
15
|
LOGGER = logging.getLogger(__name__)
|
|
17
|
-
DEBUG = True
|
|
18
16
|
|
|
19
17
|
|
|
20
18
|
def _to_list(x):
|
|
@@ -23,91 +21,34 @@ def _to_list(x):
|
|
|
23
21
|
return [x]
|
|
24
22
|
|
|
25
23
|
|
|
26
|
-
|
|
27
|
-
def __init__(self, base_times, available_steps, request):
|
|
28
|
-
self.base_times = base_times
|
|
29
|
-
self.available_steps = available_steps
|
|
30
|
-
self.request = request
|
|
31
|
-
|
|
32
|
-
def compute_hindcast(self, date):
|
|
33
|
-
result = []
|
|
34
|
-
for step in sorted(self.available_steps): # Use the shortest step
|
|
35
|
-
start_date = date - datetime.timedelta(hours=step)
|
|
36
|
-
hours = start_date.hour
|
|
37
|
-
if hours in self.base_times:
|
|
38
|
-
r = self.request.copy()
|
|
39
|
-
r["date"] = start_date
|
|
40
|
-
r["time"] = f"{start_date.hour:02d}00"
|
|
41
|
-
r["step"] = step
|
|
42
|
-
result.append(r)
|
|
43
|
-
|
|
44
|
-
if not result:
|
|
45
|
-
raise ValueError(
|
|
46
|
-
f"Cannot find data for {self.request} for {date} (base_times={self.base_times}, "
|
|
47
|
-
f"available_steps={self.available_steps})"
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
if len(result) > 1:
|
|
51
|
-
raise ValueError(
|
|
52
|
-
f"Multiple requests for {self.request} for {date} (base_times={self.base_times}, "
|
|
53
|
-
f"available_steps={self.available_steps})"
|
|
54
|
-
)
|
|
55
|
-
|
|
56
|
-
return result[0]
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def use_reference_year(reference_year, request):
|
|
60
|
-
request = request.copy()
|
|
61
|
-
hdate = request.pop("date")
|
|
62
|
-
|
|
63
|
-
if hdate.year >= reference_year:
|
|
64
|
-
return None, False
|
|
24
|
+
def hindcasts(context, dates, **request):
|
|
65
25
|
|
|
66
|
-
|
|
67
|
-
date = datetime.datetime(reference_year, hdate.month, hdate.day)
|
|
68
|
-
except ValueError:
|
|
69
|
-
if hdate.month == 2 and hdate.day == 29:
|
|
70
|
-
return None, False
|
|
71
|
-
raise
|
|
26
|
+
from anemoi.datasets.dates import HindcastsDates
|
|
72
27
|
|
|
73
|
-
|
|
74
|
-
|
|
28
|
+
provider = context.dates_provider
|
|
29
|
+
assert isinstance(provider, HindcastsDates)
|
|
75
30
|
|
|
31
|
+
context.trace("H️", f"hindcasts {len(dates)=}")
|
|
76
32
|
|
|
77
|
-
def hindcasts(context, dates, **request):
|
|
78
33
|
request["param"] = _to_list(request["param"])
|
|
79
|
-
request["step"] = _to_list(request
|
|
34
|
+
request["step"] = _to_list(request.get("step", 0))
|
|
80
35
|
request["step"] = [int(_) for _ in request["step"]]
|
|
81
36
|
|
|
82
|
-
|
|
83
|
-
request["base_times"] = [0]
|
|
84
|
-
|
|
85
|
-
available_steps = request.pop("step")
|
|
86
|
-
available_steps = _to_list(available_steps)
|
|
87
|
-
|
|
88
|
-
base_times = request.pop("base_times")
|
|
89
|
-
|
|
90
|
-
reference_year = request.pop("reference_year")
|
|
37
|
+
context.trace("H️", f"hindcast {request}")
|
|
91
38
|
|
|
92
|
-
context.trace("H️", f"hindcast {request} {base_times} {available_steps} {reference_year}")
|
|
93
|
-
|
|
94
|
-
c = HindcastCompute(base_times, available_steps, request)
|
|
95
39
|
requests = []
|
|
96
40
|
for d in dates:
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
41
|
+
r = request.copy()
|
|
42
|
+
hindcast = provider.mapping[d]
|
|
43
|
+
r["hdate"] = hindcast.hdate.strftime("%Y-%m-%d")
|
|
44
|
+
r["date"] = hindcast.refdate.strftime("%Y-%m-%d")
|
|
45
|
+
r["time"] = hindcast.refdate.strftime("%H")
|
|
46
|
+
r["step"] = hindcast.step
|
|
47
|
+
requests.append(r)
|
|
104
48
|
|
|
105
49
|
if len(requests) == 0:
|
|
106
|
-
# print("HINDCASTS no requests")
|
|
107
50
|
return MultiFieldList([])
|
|
108
51
|
|
|
109
|
-
# print("HINDCASTS requests", requests)
|
|
110
|
-
|
|
111
52
|
return mars(
|
|
112
53
|
context,
|
|
113
54
|
dates,
|
|
@@ -203,16 +203,22 @@ def mars(context, dates, *requests, request_already_using_valid_datetime=False,
|
|
|
203
203
|
request_already_using_valid_datetime=request_already_using_valid_datetime,
|
|
204
204
|
date_key=date_key,
|
|
205
205
|
)
|
|
206
|
+
|
|
207
|
+
requests = list(requests)
|
|
208
|
+
|
|
206
209
|
ds = from_source("empty")
|
|
210
|
+
context.trace("✅", f"{[str(d) for d in dates]}")
|
|
211
|
+
context.trace("✅", f"Will run {len(requests)} requests")
|
|
212
|
+
for r in requests:
|
|
213
|
+
r = {k: v for k, v in r.items() if v != ("-",)}
|
|
214
|
+
context.trace("✅", f"mars {r}")
|
|
215
|
+
|
|
207
216
|
for r in requests:
|
|
208
217
|
r = {k: v for k, v in r.items() if v != ("-",)}
|
|
209
218
|
|
|
210
219
|
if context.use_grib_paramid and "param" in r:
|
|
211
220
|
r = use_grib_paramid(r)
|
|
212
221
|
|
|
213
|
-
if DEBUG:
|
|
214
|
-
context.trace("✅", f"from_source(mars, {r}")
|
|
215
|
-
|
|
216
222
|
for k, v in r.items():
|
|
217
223
|
if k not in MARS_KEYS:
|
|
218
224
|
raise ValueError(
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
# nor does it submit to any jurisdiction.
|
|
8
8
|
#
|
|
9
9
|
|
|
10
|
+
import datetime
|
|
10
11
|
import logging
|
|
11
12
|
|
|
12
13
|
from earthkit.data.core.fieldlist import Field
|
|
@@ -103,7 +104,12 @@ class XArrayField(Field):
|
|
|
103
104
|
|
|
104
105
|
@property
|
|
105
106
|
def forecast_reference_time(self):
|
|
106
|
-
|
|
107
|
+
date, time = self.metadata("date", "time")
|
|
108
|
+
assert len(time) == 4, time
|
|
109
|
+
assert len(date) == 8, date
|
|
110
|
+
yyyymmdd = int(date)
|
|
111
|
+
time = int(time) // 100
|
|
112
|
+
return datetime.datetime(yyyymmdd // 10000, yyyymmdd // 100 % 100, yyyymmdd % 100, time)
|
|
107
113
|
|
|
108
114
|
def __repr__(self):
|
|
109
115
|
return repr(self._metadata)
|
|
@@ -70,15 +70,17 @@ class XArrayMetadata(RawMetadata):
|
|
|
70
70
|
return self._as_mars()
|
|
71
71
|
|
|
72
72
|
def _as_mars(self):
|
|
73
|
-
return
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
)
|
|
73
|
+
return {}
|
|
74
|
+
# p = dict(
|
|
75
|
+
# param=self.get("variable", self.get("param")),
|
|
76
|
+
# step=self.get("step"),
|
|
77
|
+
# levelist=self.get("levelist", self.get("level")),
|
|
78
|
+
# levtype=self.get("levtype"),
|
|
79
|
+
# number=self.get("number"),
|
|
80
|
+
# date=self.get("date"),
|
|
81
|
+
# time=self.get("time"),
|
|
82
|
+
# )
|
|
83
|
+
# return {k: v for k, v in p.items() if v is not None}
|
|
82
84
|
|
|
83
85
|
def _base_datetime(self):
|
|
84
86
|
return self._field.forecast_reference_time
|
|
@@ -135,12 +137,12 @@ class XArrayFieldGeography(Geography):
|
|
|
135
137
|
# TODO: implement resolution
|
|
136
138
|
return None
|
|
137
139
|
|
|
138
|
-
@property
|
|
140
|
+
# @property
|
|
139
141
|
def mars_grid(self):
|
|
140
142
|
# TODO: implement mars_grid
|
|
141
143
|
return None
|
|
142
144
|
|
|
143
|
-
@property
|
|
145
|
+
# @property
|
|
144
146
|
def mars_area(self):
|
|
145
147
|
# TODO: code me
|
|
146
148
|
# return [self.north, self.west, self.south, self.east]
|
anemoi/datasets/create/input.py
CHANGED
|
@@ -23,7 +23,7 @@ from earthkit.data.core.fieldlist import FieldList
|
|
|
23
23
|
from earthkit.data.core.fieldlist import MultiFieldList
|
|
24
24
|
from earthkit.data.core.order import build_remapping
|
|
25
25
|
|
|
26
|
-
from anemoi.datasets.dates import
|
|
26
|
+
from anemoi.datasets.dates import DatesProvider
|
|
27
27
|
|
|
28
28
|
from .functions import import_function
|
|
29
29
|
from .template import Context
|
|
@@ -75,7 +75,7 @@ def time_delta_to_string(delta):
|
|
|
75
75
|
|
|
76
76
|
|
|
77
77
|
def is_function(name, kind):
|
|
78
|
-
name,
|
|
78
|
+
name, _ = parse_function_name(name)
|
|
79
79
|
try:
|
|
80
80
|
import_function(name, kind)
|
|
81
81
|
return True
|
|
@@ -204,11 +204,15 @@ class Result:
|
|
|
204
204
|
_coords_already_built = False
|
|
205
205
|
|
|
206
206
|
def __init__(self, context, action_path, dates):
|
|
207
|
+
from anemoi.datasets.dates.groups import GroupOfDates
|
|
208
|
+
|
|
209
|
+
assert isinstance(dates, GroupOfDates), dates
|
|
210
|
+
|
|
207
211
|
assert isinstance(context, ActionContext), type(context)
|
|
208
212
|
assert isinstance(action_path, list), action_path
|
|
209
213
|
|
|
210
214
|
self.context = context
|
|
211
|
-
self.
|
|
215
|
+
self.group_of_dates = dates
|
|
212
216
|
self.action_path = action_path
|
|
213
217
|
|
|
214
218
|
@property
|
|
@@ -405,10 +409,10 @@ class Result:
|
|
|
405
409
|
more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])
|
|
406
410
|
|
|
407
411
|
dates = " no-dates"
|
|
408
|
-
if self.
|
|
409
|
-
dates = f" {len(self.
|
|
412
|
+
if self.group_of_dates is not None:
|
|
413
|
+
dates = f" {len(self.group_of_dates)} dates"
|
|
410
414
|
dates += " ("
|
|
411
|
-
dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.
|
|
415
|
+
dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.group_of_dates)
|
|
412
416
|
if len(dates) > 100:
|
|
413
417
|
dates = dates[:100] + "..."
|
|
414
418
|
dates += ")"
|
|
@@ -423,7 +427,7 @@ class Result:
|
|
|
423
427
|
raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
|
|
424
428
|
|
|
425
429
|
def _trace_datasource(self, *args, **kwargs):
|
|
426
|
-
return f"{self.__class__.__name__}({
|
|
430
|
+
return f"{self.__class__.__name__}({self.group_of_dates})"
|
|
427
431
|
|
|
428
432
|
def build_coords(self):
|
|
429
433
|
if self._coords_already_built:
|
|
@@ -513,7 +517,7 @@ class Result:
|
|
|
513
517
|
@cached_property
|
|
514
518
|
def shape(self):
|
|
515
519
|
return [
|
|
516
|
-
len(self.
|
|
520
|
+
len(self.group_of_dates),
|
|
517
521
|
len(self.variables),
|
|
518
522
|
len(self.ensembles),
|
|
519
523
|
len(self.grid_values),
|
|
@@ -522,7 +526,7 @@ class Result:
|
|
|
522
526
|
@cached_property
|
|
523
527
|
def coords(self):
|
|
524
528
|
return {
|
|
525
|
-
"dates": self.
|
|
529
|
+
"dates": list(self.group_of_dates),
|
|
526
530
|
"variables": self.variables,
|
|
527
531
|
"ensembles": self.ensembles,
|
|
528
532
|
"values": self.grid_values,
|
|
@@ -573,7 +577,7 @@ class FunctionResult(Result):
|
|
|
573
577
|
self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs))
|
|
574
578
|
|
|
575
579
|
def _trace_datasource(self, *args, **kwargs):
|
|
576
|
-
return f"{self.action.name}({
|
|
580
|
+
return f"{self.action.name}({self.group_of_dates})"
|
|
577
581
|
|
|
578
582
|
@cached_property
|
|
579
583
|
@assert_fieldlist
|
|
@@ -583,14 +587,21 @@ class FunctionResult(Result):
|
|
|
583
587
|
args, kwargs = resolve(self.context, (self.args, self.kwargs))
|
|
584
588
|
|
|
585
589
|
try:
|
|
586
|
-
return _tidy(
|
|
590
|
+
return _tidy(
|
|
591
|
+
self.action.function(
|
|
592
|
+
FunctionContext(self),
|
|
593
|
+
list(self.group_of_dates), # Will provide a list of datetime objects
|
|
594
|
+
*args,
|
|
595
|
+
**kwargs,
|
|
596
|
+
)
|
|
597
|
+
)
|
|
587
598
|
except Exception:
|
|
588
599
|
LOG.error(f"Error in {self.action.function.__name__}", exc_info=True)
|
|
589
600
|
raise
|
|
590
601
|
|
|
591
602
|
def __repr__(self):
|
|
592
603
|
try:
|
|
593
|
-
return f"{self.action.name}({
|
|
604
|
+
return f"{self.action.name}({self.group_of_dates})"
|
|
594
605
|
except Exception:
|
|
595
606
|
return f"{self.__class__.__name__}(unitialised)"
|
|
596
607
|
|
|
@@ -609,7 +620,7 @@ class JoinResult(Result):
|
|
|
609
620
|
@notify_result
|
|
610
621
|
@trace_datasource
|
|
611
622
|
def datasource(self):
|
|
612
|
-
ds = EmptyResult(self.context, self.action_path, self.
|
|
623
|
+
ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
|
|
613
624
|
for i in self.results:
|
|
614
625
|
ds += i.datasource
|
|
615
626
|
return _tidy(ds)
|
|
@@ -824,7 +835,7 @@ class ConcatResult(Result):
|
|
|
824
835
|
@notify_result
|
|
825
836
|
@trace_datasource
|
|
826
837
|
def datasource(self):
|
|
827
|
-
ds = EmptyResult(self.context, self.action_path, self.
|
|
838
|
+
ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
|
|
828
839
|
for i in self.results:
|
|
829
840
|
ds += i.datasource
|
|
830
841
|
return _tidy(ds)
|
|
@@ -904,7 +915,7 @@ class ConcatAction(Action):
|
|
|
904
915
|
cfg = deepcopy(cfg)
|
|
905
916
|
dates_cfg = cfg.pop("dates")
|
|
906
917
|
assert isinstance(dates_cfg, dict), dates_cfg
|
|
907
|
-
filtering_dates =
|
|
918
|
+
filtering_dates = DatesProvider.from_config(**dates_cfg)
|
|
908
919
|
action = action_factory(cfg, context, action_path + [str(i)])
|
|
909
920
|
parts.append((filtering_dates, action))
|
|
910
921
|
self.parts = parts
|
|
@@ -915,9 +926,11 @@ class ConcatAction(Action):
|
|
|
915
926
|
|
|
916
927
|
@trace_select
|
|
917
928
|
def select(self, dates):
|
|
929
|
+
from anemoi.datasets.dates.groups import GroupOfDates
|
|
930
|
+
|
|
918
931
|
results = []
|
|
919
932
|
for filtering_dates, action in self.parts:
|
|
920
|
-
newdates = sorted(set(dates) & set(filtering_dates))
|
|
933
|
+
newdates = GroupOfDates(sorted(set(dates) & set(filtering_dates)), dates.provider)
|
|
921
934
|
if newdates:
|
|
922
935
|
results.append(action.select(newdates))
|
|
923
936
|
if not results:
|
|
@@ -953,8 +966,10 @@ def action_factory(config, context, action_path):
|
|
|
953
966
|
|
|
954
967
|
if isinstance(config[key], list):
|
|
955
968
|
args, kwargs = config[key], {}
|
|
956
|
-
|
|
969
|
+
elif isinstance(config[key], dict):
|
|
957
970
|
args, kwargs = [], config[key]
|
|
971
|
+
else:
|
|
972
|
+
raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}")
|
|
958
973
|
|
|
959
974
|
cls = {
|
|
960
975
|
# "date_shift": DateShiftAction,
|
|
@@ -1021,6 +1036,13 @@ class FunctionContext:
|
|
|
1021
1036
|
def trace(self, emoji, *args):
|
|
1022
1037
|
trace(emoji, *args)
|
|
1023
1038
|
|
|
1039
|
+
def info(self, *args, **kwargs):
|
|
1040
|
+
LOG.info(*args, **kwargs)
|
|
1041
|
+
|
|
1042
|
+
@property
|
|
1043
|
+
def dates_provider(self):
|
|
1044
|
+
return self.owner.group_of_dates.provider
|
|
1045
|
+
|
|
1024
1046
|
|
|
1025
1047
|
class ActionContext(Context):
|
|
1026
1048
|
def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid):
|
|
@@ -68,7 +68,7 @@ class PersistentDict:
|
|
|
68
68
|
path = os.path.join(self.dirname, f"{h}.pickle")
|
|
69
69
|
|
|
70
70
|
if os.path.exists(path):
|
|
71
|
-
LOG.
|
|
71
|
+
LOG.warning(f"{path} already exists")
|
|
72
72
|
|
|
73
73
|
tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}"
|
|
74
74
|
with open(tmp_path, "wb") as f:
|
anemoi/datasets/create/utils.py
CHANGED
|
@@ -62,6 +62,9 @@ def make_list_int(value):
|
|
|
62
62
|
|
|
63
63
|
|
|
64
64
|
def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"):
|
|
65
|
+
|
|
66
|
+
dates = [d.hdate if hasattr(d, "hdate") else d for d in dates]
|
|
67
|
+
|
|
65
68
|
assert isinstance(frequency, datetime.timedelta), frequency
|
|
66
69
|
start = np.datetime64(start)
|
|
67
70
|
end = np.datetime64(end)
|
anemoi/datasets/data/dataset.py
CHANGED
|
@@ -23,7 +23,11 @@ LOG = logging.getLogger(__name__)
|
|
|
23
23
|
class Dataset:
|
|
24
24
|
arguments = {}
|
|
25
25
|
|
|
26
|
-
def mutate(self):
|
|
26
|
+
def mutate(self) -> "Dataset":
|
|
27
|
+
"""
|
|
28
|
+
Give an opportunity to a subclass to return a new Dataset
|
|
29
|
+
object of a different class, if needed.
|
|
30
|
+
"""
|
|
27
31
|
return self
|
|
28
32
|
|
|
29
33
|
def swap_with_parent(self, parent):
|
|
@@ -90,6 +94,12 @@ class Dataset:
|
|
|
90
94
|
rename = kwargs.pop("rename")
|
|
91
95
|
return Rename(self, rename)._subset(**kwargs).mutate()
|
|
92
96
|
|
|
97
|
+
if "rescale" in kwargs:
|
|
98
|
+
from .rescale import Rescale
|
|
99
|
+
|
|
100
|
+
rescale = kwargs.pop("rescale")
|
|
101
|
+
return Rescale(self, rescale)._subset(**kwargs).mutate()
|
|
102
|
+
|
|
93
103
|
if "statistics" in kwargs:
|
|
94
104
|
from ..data import open_dataset
|
|
95
105
|
from .statistics import Statistics
|
anemoi/datasets/data/debug.py
CHANGED
|
@@ -209,10 +209,14 @@ def _debug_indexing(method):
|
|
|
209
209
|
return wrapper
|
|
210
210
|
|
|
211
211
|
|
|
212
|
+
def _identity(x):
|
|
213
|
+
return x
|
|
214
|
+
|
|
215
|
+
|
|
212
216
|
if DEBUG_ZARR_INDEXING:
|
|
213
217
|
debug_indexing = _debug_indexing
|
|
214
218
|
else:
|
|
215
|
-
debug_indexing =
|
|
219
|
+
debug_indexing = _identity
|
|
216
220
|
|
|
217
221
|
|
|
218
222
|
def debug_zarr_loading(on_off):
|
anemoi/datasets/data/masked.py
CHANGED
|
@@ -112,5 +112,5 @@ class Cropping(Masked):
|
|
|
112
112
|
def tree(self):
|
|
113
113
|
return Node(self, [self.forward.tree()], area=self.area)
|
|
114
114
|
|
|
115
|
-
def
|
|
116
|
-
return
|
|
115
|
+
def subclass_metadata_specific(self):
|
|
116
|
+
return dict(area=self.area)
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
|
|
2
|
+
# This software is licensed under the terms of the Apache Licence Version 2.0
|
|
3
|
+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
|
|
4
|
+
# In applying this licence, ECMWF does not waive the privileges and immunities
|
|
5
|
+
# granted to it by virtue of its status as an intergovernmental organisation
|
|
6
|
+
# nor does it submit to any jurisdiction.
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from functools import cached_property
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from .debug import Node
|
|
14
|
+
from .debug import debug_indexing
|
|
15
|
+
from .forwards import Forwards
|
|
16
|
+
from .indexing import apply_index_to_slices_changes
|
|
17
|
+
from .indexing import expand_list_indexing
|
|
18
|
+
from .indexing import index_to_slices
|
|
19
|
+
from .indexing import update_tuple
|
|
20
|
+
|
|
21
|
+
LOG = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def make_rescale(variable, rescale):
|
|
25
|
+
|
|
26
|
+
if isinstance(rescale, (tuple, list)):
|
|
27
|
+
|
|
28
|
+
assert len(rescale) == 2, rescale
|
|
29
|
+
|
|
30
|
+
if isinstance(rescale[0], (int, float)):
|
|
31
|
+
return rescale
|
|
32
|
+
|
|
33
|
+
from cfunits import Units
|
|
34
|
+
|
|
35
|
+
u0 = Units(rescale[0])
|
|
36
|
+
u1 = Units(rescale[1])
|
|
37
|
+
|
|
38
|
+
x1, x2 = 0.0, 1.0
|
|
39
|
+
y1, y2 = Units.conform([x1, x2], u0, u1)
|
|
40
|
+
|
|
41
|
+
a = (y2 - y1) / (x2 - x1)
|
|
42
|
+
b = y1 - a * x1
|
|
43
|
+
|
|
44
|
+
return a, b
|
|
45
|
+
|
|
46
|
+
return rescale
|
|
47
|
+
|
|
48
|
+
if isinstance(rescale, dict):
|
|
49
|
+
assert "scale" in rescale, rescale
|
|
50
|
+
assert "offset" in rescale, rescale
|
|
51
|
+
return rescale["scale"], rescale["offset"]
|
|
52
|
+
|
|
53
|
+
assert False
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class Rescale(Forwards):
|
|
57
|
+
def __init__(self, dataset, rescale):
|
|
58
|
+
super().__init__(dataset)
|
|
59
|
+
for n in rescale:
|
|
60
|
+
assert n in dataset.variables, n
|
|
61
|
+
|
|
62
|
+
variables = dataset.variables
|
|
63
|
+
|
|
64
|
+
self._a = np.ones(len(variables))
|
|
65
|
+
self._b = np.zeros(len(variables))
|
|
66
|
+
|
|
67
|
+
self.rescale = {}
|
|
68
|
+
for i, v in enumerate(variables):
|
|
69
|
+
if v in rescale:
|
|
70
|
+
a, b = make_rescale(v, rescale[v])
|
|
71
|
+
self.rescale[v] = a, b
|
|
72
|
+
self._a[i], self._b[i] = a, b
|
|
73
|
+
|
|
74
|
+
self._a = self._a[np.newaxis, :, np.newaxis, np.newaxis]
|
|
75
|
+
self._b = self._b[np.newaxis, :, np.newaxis, np.newaxis]
|
|
76
|
+
|
|
77
|
+
self._a = self._a.astype(self.forward.dtype)
|
|
78
|
+
self._b = self._b.astype(self.forward.dtype)
|
|
79
|
+
|
|
80
|
+
def tree(self):
|
|
81
|
+
return Node(self, [self.forward.tree()], rescale=self.rescale)
|
|
82
|
+
|
|
83
|
+
def subclass_metadata_specific(self):
|
|
84
|
+
return dict(rescale=self.rescale)
|
|
85
|
+
|
|
86
|
+
@debug_indexing
|
|
87
|
+
@expand_list_indexing
|
|
88
|
+
def _get_tuple(self, index):
|
|
89
|
+
index, changes = index_to_slices(index, self.shape)
|
|
90
|
+
index, previous = update_tuple(index, 1, slice(None))
|
|
91
|
+
result = self.forward[index]
|
|
92
|
+
result = result * self._a + self._b
|
|
93
|
+
result = result[:, previous]
|
|
94
|
+
result = apply_index_to_slices_changes(result, changes)
|
|
95
|
+
return result
|
|
96
|
+
|
|
97
|
+
@debug_indexing
|
|
98
|
+
def __get_slice_(self, n):
|
|
99
|
+
data = self.forward[n]
|
|
100
|
+
return data * self._a + self._b
|
|
101
|
+
|
|
102
|
+
@debug_indexing
|
|
103
|
+
def __getitem__(self, n):
|
|
104
|
+
|
|
105
|
+
if isinstance(n, tuple):
|
|
106
|
+
return self._get_tuple(n)
|
|
107
|
+
|
|
108
|
+
if isinstance(n, slice):
|
|
109
|
+
return self.__get_slice_(n)
|
|
110
|
+
|
|
111
|
+
data = self.forward[n]
|
|
112
|
+
|
|
113
|
+
return data * self._a[0] + self._b[0]
|
|
114
|
+
|
|
115
|
+
@cached_property
|
|
116
|
+
def statistics(self):
|
|
117
|
+
result = {}
|
|
118
|
+
a = self._a.squeeze()
|
|
119
|
+
assert np.all(a >= 0)
|
|
120
|
+
|
|
121
|
+
b = self._b.squeeze()
|
|
122
|
+
for k, v in self.forward.statistics.items():
|
|
123
|
+
if k in ("maximum", "minimum", "mean"):
|
|
124
|
+
result[k] = v * a + b
|
|
125
|
+
continue
|
|
126
|
+
|
|
127
|
+
if k in ("stdev",):
|
|
128
|
+
result[k] = v * a
|
|
129
|
+
continue
|
|
130
|
+
|
|
131
|
+
raise NotImplementedError("rescale statistics", k)
|
|
132
|
+
|
|
133
|
+
return result
|
|
134
|
+
|
|
135
|
+
def statistics_tendencies(self, delta=None):
|
|
136
|
+
result = {}
|
|
137
|
+
a = self._a.squeeze()
|
|
138
|
+
assert np.all(a >= 0)
|
|
139
|
+
|
|
140
|
+
for k, v in self.forward.statistics_tendencies(delta).items():
|
|
141
|
+
if k in ("maximum", "minimum", "mean", "stdev"):
|
|
142
|
+
result[k] = v * a
|
|
143
|
+
continue
|
|
144
|
+
|
|
145
|
+
raise NotImplementedError("rescale tendencies statistics", k)
|
|
146
|
+
|
|
147
|
+
return result
|
anemoi/datasets/data/stores.py
CHANGED
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
# granted to it by virtue of its status as an intergovernmental organisation
|
|
6
6
|
# nor does it submit to any jurisdiction.
|
|
7
7
|
|
|
8
|
+
|
|
8
9
|
import logging
|
|
9
10
|
import os
|
|
10
11
|
import warnings
|
|
@@ -83,6 +84,8 @@ class S3Store(ReadOnlyStore):
|
|
|
83
84
|
|
|
84
85
|
|
|
85
86
|
class DebugStore(ReadOnlyStore):
|
|
87
|
+
"""A store to debug the zarr loading."""
|
|
88
|
+
|
|
86
89
|
def __init__(self, store):
|
|
87
90
|
assert not isinstance(store, DebugStore)
|
|
88
91
|
self.store = store
|
|
@@ -148,6 +151,8 @@ def open_zarr(path, dont_fail=False, cache=None):
|
|
|
148
151
|
|
|
149
152
|
|
|
150
153
|
class Zarr(Dataset):
|
|
154
|
+
"""A zarr dataset."""
|
|
155
|
+
|
|
151
156
|
def __init__(self, path):
|
|
152
157
|
if isinstance(path, zarr.hierarchy.Group):
|
|
153
158
|
self.was_zarr = True
|
|
@@ -244,14 +249,20 @@ class Zarr(Dataset):
|
|
|
244
249
|
delta = self.frequency
|
|
245
250
|
if isinstance(delta, int):
|
|
246
251
|
delta = f"{delta}h"
|
|
247
|
-
from anemoi.
|
|
252
|
+
from anemoi.utils.dates import frequency_to_string
|
|
253
|
+
from anemoi.utils.dates import frequency_to_timedelta
|
|
254
|
+
|
|
255
|
+
delta = frequency_to_timedelta(delta)
|
|
256
|
+
delta = frequency_to_string(delta)
|
|
257
|
+
|
|
258
|
+
def func(k):
|
|
259
|
+
return f"statistics_tendencies_{delta}_{k}"
|
|
248
260
|
|
|
249
|
-
func = TendenciesStatisticsAddition.final_storage_name_from_delta
|
|
250
261
|
return dict(
|
|
251
|
-
mean=self.z[func("mean"
|
|
252
|
-
stdev=self.z[func("stdev"
|
|
253
|
-
maximum=self.z[func("maximum"
|
|
254
|
-
minimum=self.z[func("minimum"
|
|
262
|
+
mean=self.z[func("mean")][:],
|
|
263
|
+
stdev=self.z[func("stdev")][:],
|
|
264
|
+
maximum=self.z[func("maximum")][:],
|
|
265
|
+
minimum=self.z[func("minimum")][:],
|
|
255
266
|
)
|
|
256
267
|
|
|
257
268
|
@property
|
|
@@ -322,11 +333,13 @@ class Zarr(Dataset):
|
|
|
322
333
|
|
|
323
334
|
|
|
324
335
|
class ZarrWithMissingDates(Zarr):
|
|
336
|
+
"""A zarr dataset with missing dates."""
|
|
337
|
+
|
|
325
338
|
def __init__(self, path):
|
|
326
339
|
super().__init__(path)
|
|
327
340
|
|
|
328
341
|
missing_dates = self.z.attrs.get("missing_dates", [])
|
|
329
|
-
missing_dates = [np.datetime64(x) for x in missing_dates]
|
|
342
|
+
missing_dates = set([np.datetime64(x) for x in missing_dates])
|
|
330
343
|
self.missing_to_dates = {i: d for i, d in enumerate(self.dates) if d in missing_dates}
|
|
331
344
|
self.missing = set(self.missing_to_dates)
|
|
332
345
|
|