anemoi-utils 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic. Click here for more details.

@@ -0,0 +1,9 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+
9
+ from ._version import __version__
@@ -0,0 +1,16 @@
1
+ # file generated by setuptools_scm
2
+ # don't change, don't track in version control
3
# Only static type checkers treat this as True; at runtime it stays False so
# that importing the version module never pulls in ``typing``.
TYPE_CHECKING = False
if not TYPE_CHECKING:
    VERSION_TUPLE = object
else:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# Version string and its tuple form, exposed under both plain and dunder names.
version = "0.1.6"
__version__ = version
version_tuple = (0, 1, 6)
__version_tuple__ = version_tuple
@@ -0,0 +1,76 @@
1
+ # (C) Copyright 2024 ECMWF.
2
+ #
3
+ # This software is licensed under the terms of the Apache Licence Version 2.0
4
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5
+ # In applying this licence, ECMWF does not waive the privileges and immunities
6
+ # granted to it by virtue of its status as an intergovernmental organisation
7
+ # nor does it submit to any jurisdiction.
8
+
9
+ """
10
+ Read and write extra metadata in PyTorch checkpoints files. These files
11
+ are zip archives containing the model weights.
12
+ """
13
+
14
+ import json
15
+ import logging
16
+ import os
17
+ import zipfile
18
+
19
LOG = logging.getLogger(__name__)

DEFAULT_NAME = "anemoi-metadata.json"


def load_metadata(path: str, name: str = DEFAULT_NAME):
    """Load metadata from a checkpoint file.

    The archive is searched for a member whose base name equals ``name``,
    whatever directory it is stored under. If several members match, a
    warning is logged and the last one listed in the archive is used.

    Parameters
    ----------
    path : str
        The path to the checkpoint file.
    name : str, optional
        The name of the metadata file in the zip archive.

    Returns
    -------
    JSON
        The decoded content of the metadata file.

    Raises
    ------
    ValueError
        If the metadata file is not found.
    """
    # Open the archive once and keep it open for both the lookup and the read.
    with zipfile.ZipFile(path, "r") as zipf:
        metadata = None
        for member in zipf.namelist():
            if os.path.basename(member) == name:
                if metadata is not None:
                    LOG.warning("Found two '%s' in %s", name, path)
                metadata = member

        if metadata is None:
            raise ValueError(f"Could not find {name} in {path}")

        # Close the member stream as well, not only the archive.
        with zipf.open(metadata, "r") as f:
            return json.load(f)
57
+
58
+
59
def save_metadata(path, metadata, name=DEFAULT_NAME):
    """Write metadata into a checkpoint file.

    The metadata is stored as a JSON document at ``<stem>/<name>`` inside the
    zip archive, where ``<stem>`` is the file name of ``path`` without its
    extension. The archive is opened in append mode: existing members are
    kept, and an already present metadata entry is not replaced (a second
    member with the same name is added instead).

    Parameters
    ----------
    path : str
        Location of the checkpoint file.
    metadata : JSON
        Any object that ``json.dumps`` can serialise.
    name : str, optional
        File name of the metadata entry inside the archive.
    """
    with zipfile.ZipFile(path, "a") as archive:
        stem = os.path.splitext(os.path.basename(path))[0]
        entry = f"{stem}/{name}"
        archive.writestr(entry, json.dumps(metadata))
anemoi/utils/config.py ADDED
@@ -0,0 +1,94 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+
9
+ import logging
10
+ import os
11
+
12
+ try:
13
+ import tomllib # Only available since 3.11
14
+ except ImportError:
15
+ import tomli as tomllib
16
+
17
+
18
+ LOG = logging.getLogger(__name__)
19
+
20
+
21
class DotDict(dict):
    """A dictionary that allows access to its keys as attributes.

    >>> d = DotDict({"a": 1, "b": {"c": 2}})
    >>> d.a
    1
    >>> d.b.c
    2
    >>> d.b = 3
    >>> d.b
    3

    The class is recursive, so nested dictionaries are also DotDicts. This
    holds whether they are passed to the constructor, assigned as an
    attribute (``d.x = {...}``) or assigned as an item (``d["x"] = {...}``).
    Note that dict methods such as ``update`` and ``setdefault`` bypass
    ``__setitem__`` and store nested dicts unchanged.

    Attributes can be deleted with ``del d.a``, which removes the key.

    The DotDict class has the same constructor as the dict class.

    >>> d = DotDict(a=1, b=2)

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dict.__init__ does not go through __setitem__, so re-assign the
        # values here to have nested dicts wrapped. Iterate over a snapshot so
        # the loop does not depend on assigning while iterating.
        for k, v in list(self.items()):
            if isinstance(v, dict):
                self[k] = v

    def __setitem__(self, key, value):
        # Wrap nested dicts so that attribute access also works on them.
        if isinstance(value, dict):
            value = DotDict(value)
        super().__setitem__(key, value)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails; raising
        # AttributeError (not KeyError) keeps hasattr/getattr/copy/pickle working.
        try:
            return self[attr]
        except KeyError:
            raise AttributeError(attr) from None

    def __setattr__(self, attr, value):
        self[attr] = value

    def __delattr__(self, attr):
        try:
            del self[attr]
        except KeyError:
            raise AttributeError(attr) from None

    def __repr__(self) -> str:
        return f"DotDict({super().__repr__()})"
60
+
61
+
62
CONFIG = None


def load_config():
    """Load the configuration from `~/.anemoi.toml`.

    The file is read only on the first call. The resulting DotDict is cached
    in the module-level ``CONFIG``, and that same object is returned on every
    call. Changes made to it are therefore seen by later callers and by
    ``save_config``, which writes ``CONFIG`` back to disk.

    If the file does not exist, an empty configuration is returned.

    Returns
    -------
    DotDict
        The configuration
    """
    global CONFIG
    if CONFIG is None:
        conf = os.path.expanduser("~/.anemoi.toml")
        if os.path.exists(conf):
            with open(conf, "rb") as f:
                data = tomllib.load(f)
        else:
            data = {}
        # Cache the DotDict itself (not the raw dict) so that every call
        # returns the same type and the same object.
        CONFIG = DotDict(data)

    return CONFIG
87
+
88
+
89
def _toml_string(text):
    """Return ``text`` as a TOML basic string, escaping the characters TOML forbids."""
    escapes = {
        '"': '\\"',
        "\\": "\\\\",
        "\b": "\\b",
        "\t": "\\t",
        "\n": "\\n",
        "\f": "\\f",
        "\r": "\\r",
    }
    parts = []
    for char in text:
        if char in escapes:
            parts.append(escapes[char])
        elif ord(char) < 0x20 or ord(char) == 0x7F:
            # Other control characters must use the \uXXXX form.
            parts.append(f"\\u{ord(char):04X}")
        else:
            parts.append(char)
    return '"' + "".join(parts) + '"'


def _toml_key(key):
    """Return ``key`` as a TOML key, quoting it unless it is a valid bare key."""
    key = str(key)
    if key and all(c.isascii() and (c.isalnum() or c in "-_") for c in key):
        return key
    return _toml_string(key)


def _toml_value(value):
    """Return the inline TOML representation of a scalar, list or dict."""
    import datetime

    # bool must be tested before int, as bool is a subclass of int.
    if isinstance(value, bool):
        return "true" if value else "false"
    if isinstance(value, int):
        return str(int(value))
    if isinstance(value, float):
        # repr gives 'inf', '-inf', 'nan' and exponent forms TOML accepts.
        return repr(float(value))
    if isinstance(value, str):
        return _toml_string(value)
    if isinstance(value, (datetime.datetime, datetime.date, datetime.time)):
        return value.isoformat()
    if isinstance(value, (list, tuple)):
        return "[" + ", ".join(_toml_value(v) for v in value) + "]"
    if isinstance(value, dict):
        # Dicts nested in arrays are written as inline tables.
        items = ", ".join(f"{_toml_key(k)} = {_toml_value(v)}" for k, v in value.items())
        return "{" + items + "}"
    raise TypeError(f"Cannot represent {type(value).__name__} in TOML")


def _toml_table(table, path, lines):
    """Append the TOML lines for ``table`` (located at key ``path``) to ``lines``."""
    scalars = [(k, v) for k, v in table.items() if not isinstance(v, dict)]
    subtables = [(k, v) for k, v in table.items() if isinstance(v, dict)]
    # A header is needed for a table holding key/value pairs, and for an empty
    # table so that it survives a round trip. A table containing only
    # sub-tables is created implicitly by their headers.
    if path and (scalars or not subtables):
        if lines:
            lines.append("")
        lines.append("[" + ".".join(_toml_key(p) for p in path) + "]")
    for k, v in scalars:
        lines.append(f"{_toml_key(k)} = {_toml_value(v)}")
    for k, v in subtables:
        _toml_table(v, path + [k], lines)


def _toml_dumps(table):
    """Serialise a (nested) dict to a TOML document string."""
    lines = []
    _toml_table(table, [], lines)
    return ("\n".join(lines) + "\n") if lines else ""


def save_config():
    """Save the configuration to `~/.anemoi.toml`.

    Writes the module-level ``CONFIG`` as TOML, replacing the content of the
    file. Nothing is written when ``CONFIG`` is None, i.e. when no
    configuration has been loaded or set.

    Raises
    ------
    TypeError
        If the configuration contains a value that cannot be represented in
        TOML. The file is left untouched in that case.
    """
    if CONFIG is None:
        # Nothing to save; do not overwrite an existing file with nothing.
        return

    # tomllib/tomli can only read TOML, so serialise with the local writer.
    # Do it before opening the file: mode "w" truncates it, so any
    # serialisation error must occur while the old content is still intact.
    text = _toml_dumps(CONFIG)

    conf = os.path.expanduser("~/.anemoi.toml")
    with open(conf, "w", encoding="utf-8") as f:
        f.write(text)
anemoi/utils/dates.py ADDED
@@ -0,0 +1,248 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+
9
+ import calendar
10
+ import datetime
11
+
12
+
13
def no_time_zone(date):
    """Return a copy of ``date`` with its time zone information stripped.

    The wall-clock fields are kept as they are; no conversion to UTC (or to
    any other zone) takes place.

    Parameters
    ----------
    date : datetime.datetime
        A datetime object, naive or time zone aware.

    Returns
    -------
    datetime.datetime
        A naive datetime object with the same date and time fields.
    """
    naive = date.replace(tzinfo=None)
    return naive
28
+
29
+
30
def as_datetime(date):
    """Convert a date to a datetime object, removing any time zone information.

    Time zone information is discarded without conversion: the wall-clock
    time is kept as is.

    Parameters
    ----------
    date : datetime.date or datetime.datetime or str
        The date to convert. Strings must be in ISO 8601 format as accepted
        by ``datetime.datetime.fromisoformat``. A trailing ``Z`` (UTC
        designator) is accepted on all Python versions.

    Returns
    -------
    datetime.datetime
        The naive datetime object. A ``datetime.date`` becomes midnight of
        that day.

    Raises
    ------
    ValueError
        If ``date`` has an unsupported type, or is a string that cannot be
        parsed.
    """
    # datetime.datetime is a subclass of datetime.date, so test it first.
    if isinstance(date, datetime.datetime):
        return no_time_zone(date)

    if isinstance(date, datetime.date):
        return no_time_zone(datetime.datetime(date.year, date.month, date.day))

    if isinstance(date, str):
        # fromisoformat only understands the "Z" suffix from Python 3.11 on.
        if date.endswith("Z"):
            date = date[:-1] + "+00:00"
        return no_time_zone(datetime.datetime.fromisoformat(date))

    raise ValueError(f"Invalid date type: {type(date)}")
54
+
55
+
56
DOW = {
    "monday": 0,
    "tuesday": 1,
    "wednesday": 2,
    "thursday": 3,
    "friday": 4,
    "saturday": 5,
    "sunday": 6,
}


MONTH = {
    "january": 1,
    "february": 2,
    "march": 3,
    "april": 4,
    "may": 5,
    "june": 6,
    "july": 7,
    "august": 8,
    "september": 9,
    "october": 10,
    "november": 11,
    "december": 12,
}


def _make_day(day):
    """Return the set of days of the month selected by ``day``.

    ``None`` selects every possible day (1-31). Otherwise ``day`` is a single
    value or a list of values, each converted with ``int``.
    """
    if day is None:
        return set(range(1, 32))
    if not isinstance(day, list):
        day = [day]
    return {int(d) for d in day}


def _make_week(week):
    """Return the set of weekday numbers (Monday is 0) selected by ``week``.

    ``None`` selects every day of the week. Otherwise ``week`` is a single
    value or a list of values. Strings are day names, case-insensitive and
    looked up in ``DOW``; other values are weekday numbers converted with
    ``int``.
    """
    if week is None:
        return set(range(7))
    if not isinstance(week, list):
        week = [week]
    return {DOW[w.lower()] if isinstance(w, str) else int(w) for w in week}


def _make_months(months):
    """Return the set of month numbers (1-12) selected by ``months``.

    ``None`` selects every month. Otherwise ``months`` is a single value or a
    list of values. Month names are matched case-insensitively against
    ``MONTH``; anything else (e.g. ``3`` or ``"3"``) is converted with ``int``.
    """
    if months is None:
        return set(range(1, 13))

    if not isinstance(months, list):
        months = [months]

    # Lower-case names, as _make_week does for day names.
    return {int(MONTH.get(m.lower(), m)) if isinstance(m, str) else int(m) for m in months}
107
+
108
+
109
class DateTimes:
    """An iterable that generates datetime objects within a given range.

    Dates run from ``start`` to ``end`` (both inclusive) in steps of
    ``increment``. Only the dates that pass all of the day-of-month,
    day-of-week and month filters are yielded. Each call to ``iter()``
    starts again from ``start``.
    """

    def __init__(self, start, end, increment=24, *, day_of_month=None, day_of_week=None, calendar_months=None):
        """Create the date range.

        Parameters
        ----------
        start : datetime.date or datetime.datetime or str
            First date of the range (inclusive), converted with ``as_datetime``.
        end : datetime.date or datetime.datetime or str
            Last date of the range (inclusive), converted with ``as_datetime``.
        increment : int or float or datetime.timedelta, optional
            Step between consecutive dates. A number is interpreted as hours.
            By default 24.
        day_of_month : int or list of int, optional
            Days of the month to keep (1-31). By default None, meaning all days.
        day_of_week : str or list of str, optional
            Names of the days of the week to keep (e.g. "monday"). By default
            None, meaning all days.
        calendar_months : int or str or list, optional
            Months to keep, as numbers (1-12) or month names (e.g. "january").
            By default None, meaning all months.

        Raises
        ------
        ValueError
            If ``increment`` is not strictly positive; iteration would then
            never reach ``end``.
        """
        self.start = as_datetime(start)
        self.end = as_datetime(end)
        if isinstance(increment, datetime.timedelta):
            self.increment = increment
        else:
            self.increment = datetime.timedelta(hours=increment)
        if self.increment <= datetime.timedelta(0):
            raise ValueError(f"increment must be positive, got {increment!r}")
        self.day_of_month = _make_day(day_of_month)
        self.day_of_week = _make_week(day_of_week)
        self.calendar_months = _make_months(calendar_months)

    def __iter__(self):
        date = self.start
        while date <= self.end:
            if (
                (date.weekday() in self.day_of_week)
                and (date.day in self.day_of_month)
                and (date.month in self.calendar_months)
            ):
                yield date
            date += self.increment
148
+
149
+
150
class HindcastDatesTimes:
    """An iterable of ``(hindcast_date, reference_date)`` pairs.

    For every reference date, one pair is yielded per year going back from 1
    to ``years`` years before it. The hindcast date is a datetime at midnight
    with the same month and day as the reference date. February 29 is always
    mapped to February 28, whether or not the hindcast year is a leap year.
    """

    def __init__(self, reference_dates, years=20):
        """Store the reference dates and how many years to go back.

        Parameters
        ----------
        reference_dates : iterable
            Objects with ``year``, ``month`` and ``day`` attributes, such as
            ``datetime.date`` or ``datetime.datetime``.
        years : int, optional
            Number of past years generated for each reference date, by
            default 20.
        """
        self.reference_dates = reference_dates
        # Bounds passed to range(): offsets 1 .. years inclusive.
        self.years = (1, years + 1)

    def __iter__(self):
        # NOTE(review): only year/month/day are used, so any time of day on a
        # reference datetime is not carried over; confirm this is intended.
        for ref in self.reference_dates:
            for offset in range(*self.years):
                # Past years may not have a 29th of February.
                leap_day = ref.month == 2 and ref.day == 29
                day = 28 if leap_day else ref.day
                yield datetime.datetime(ref.year - offset, ref.month, day), ref
175
+
176
+
177
class Year(DateTimes):
    """Year is defined as the months of January to December."""

    def __init__(self, year, **kwargs):
        """Dates of a calendar year.

        Parameters
        ----------
        year : int
            The calendar year.
        **kwargs
            Passed on to ``DateTimes`` (``increment`` and the filters).
        """
        # End at the last instant of December 31st so that sub-daily
        # increments also produce the times after midnight on that day.
        super().__init__(
            datetime.datetime(year, 1, 1),
            datetime.datetime.combine(datetime.date(year, 12, 31), datetime.time.max),
            **kwargs,
        )
189
+
190
+
191
class Winter(DateTimes):
    """Winter is defined as the months of December, January and February."""

    def __init__(self, year, **kwargs):
        """Dates of the winter starting in December of ``year``.

        Parameters
        ----------
        year : int
            The year of the December that starts the winter; January and
            February are taken from ``year + 1``.
        **kwargs
            Passed on to ``DateTimes`` (``increment`` and the filters).
        """
        # Number of days in February of the following year (28 or 29).
        last_day = calendar.monthrange(year + 1, 2)[1]
        # End at the last instant of that day so that sub-daily increments
        # also produce the times after midnight on it.
        super().__init__(
            datetime.datetime(year, 12, 1),
            datetime.datetime.combine(datetime.date(year + 1, 2, last_day), datetime.time.max),
            **kwargs,
        )
207
+
208
+
209
class Spring(DateTimes):
    """Spring is defined as the months of March, April and May."""

    def __init__(self, year, **kwargs):
        """Dates of the spring of ``year``.

        Parameters
        ----------
        year : int
            The calendar year.
        **kwargs
            Passed on to ``DateTimes`` (``increment`` and the filters).
        """
        # End at the last instant of May 31st so that sub-daily increments
        # also produce the times after midnight on that day.
        super().__init__(
            datetime.datetime(year, 3, 1),
            datetime.datetime.combine(datetime.date(year, 5, 31), datetime.time.max),
            **kwargs,
        )
221
+
222
+
223
class Summer(DateTimes):
    """Summer is defined as the months of June, July and August."""

    def __init__(self, year, **kwargs):
        """Dates of the summer of ``year``.

        Parameters
        ----------
        year : int
            The calendar year.
        **kwargs
            Passed on to ``DateTimes`` (``increment`` and the filters).
        """
        # End at the last instant of August 31st so that sub-daily increments
        # also produce the times after midnight on that day.
        super().__init__(
            datetime.datetime(year, 6, 1),
            datetime.datetime.combine(datetime.date(year, 8, 31), datetime.time.max),
            **kwargs,
        )
235
+
236
+
237
class Autumn(DateTimes):
    """Autumn is defined as the months of September, October and November."""

    def __init__(self, year, **kwargs):
        """Dates of the autumn of ``year``.

        Parameters
        ----------
        year : int
            The calendar year.
        **kwargs
            Passed on to ``DateTimes`` (``increment`` and the filters).
        """
        # End at the last instant of November 30th so that sub-daily
        # increments also produce the times after midnight on that day.
        super().__init__(
            datetime.datetime(year, 9, 1),
            datetime.datetime.combine(datetime.date(year, 11, 30), datetime.time.max),
            **kwargs,
        )
anemoi/utils/grib.py ADDED
@@ -0,0 +1,73 @@
1
+ # (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2
+ # This software is licensed under the terms of the Apache Licence Version 2.0
3
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4
+ # In applying this licence, ECMWF does not waive the privileges and immunities
5
+ # granted to it by virtue of its status as an intergovernmental organisation
6
+ # nor does it submit to any jurisdiction.
7
+
8
+ """Utilities for working with GRIB parameters.
9
+
10
+ See https://codes.ecmwf.int/grib/param-db/ for more information.
11
+
12
+ """
13
+
14
+ import re
15
+
16
+ import requests
17
+
18
+
19
def _search(name, timeout=60):
    """Look up exactly one entry in the ECMWF parameter database.

    ``name`` is regex-escaped and anchored (``^...$``), then sent as the
    ``search`` query parameter with ``regex=true``. Which fields the search
    matches against is defined by the remote API.

    Parameters
    ----------
    name : str
        The text to search for, e.g. a shortname or a parameter id.
    timeout : float, optional
        Seconds to wait for the server before giving up, by default 60.

    Returns
    -------
    dict
        The single matching database entry.

    Raises
    ------
    KeyError
        If nothing matches.
    ValueError
        If more than one entry matches.
    requests.HTTPError
        If the server returns an error status.
    """
    # Let requests URL-encode the query: characters such as "+", "&" or "#"
    # (and the backslashes added by re.escape) must not be sent raw.
    response = requests.get(
        "https://codes.ecmwf.int/parameter-database/api/v1/param/",
        params={"search": f"^{re.escape(name)}$", "regex": "true"},
        timeout=timeout,
    )
    response.raise_for_status()
    results = response.json()
    if len(results) == 0:
        raise KeyError(name)

    if len(results) > 1:
        names = [f'{entry.get("id")} ({entry.get("name")})' for entry in results]
        raise ValueError(f"{name} is ambiguous: {', '.join(names)}")

    return results[0]
32
+
33
+
34
def shortname_to_paramid(shortname: str) -> int:
    """Return the GRIB parameter id given its shortname.

    The lookup is performed online against the ECMWF parameter database.

    Parameters
    ----------
    shortname : str
        Parameter shortname.

    Returns
    -------
    int
        Parameter id.


    >>> shortname_to_paramid("2t")
    167

    """
    entry = _search(shortname)
    return entry["id"]
53
+
54
+
55
def paramid_to_shortname(paramid: int) -> str:
    """Return the shortname of a GRIB parameter given its id.

    The id is turned into a string and looked up online against the ECMWF
    parameter database.

    Parameters
    ----------
    paramid : int
        Parameter id.

    Returns
    -------
    str
        Parameter shortname.


    >>> paramid_to_shortname(167)
    '2t'

    """
    query = str(paramid)
    entry = _search(query)
    return entry["shortname"]