anemoi-utils 0.4.29__py3-none-any.whl → 0.4.30__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic. See the registry's advisory page for this release for more details.

anemoi/utils/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.4.29'
21
- __version_tuple__ = version_tuple = (0, 4, 29)
20
+ __version__ = version = '0.4.30'
21
+ __version_tuple__ = version_tuple = (0, 4, 30)
anemoi/utils/caching.py CHANGED
@@ -12,10 +12,9 @@ import hashlib
12
12
  import json
13
13
  import os
14
14
  import time
15
+ from collections.abc import Callable
15
16
  from threading import Lock
16
17
  from typing import Any
17
- from typing import Callable
18
- from typing import Optional
19
18
 
20
19
  import numpy as np
21
20
 
@@ -61,7 +60,7 @@ class Cacher:
61
60
  Private class, do not use directly.
62
61
  """
63
62
 
64
- def __init__(self, collection: str, expires: Optional[int]):
63
+ def __init__(self, collection: str, expires: int | None):
65
64
  """Initialize the Cacher.
66
65
 
67
66
  Parameters
@@ -181,7 +180,7 @@ class JsonCacher(Cacher):
181
180
  dict
182
181
  The loaded data
183
182
  """
184
- with open(path, "r") as f:
183
+ with open(path) as f:
185
184
  return json.load(f)
186
185
 
187
186
 
@@ -226,7 +225,7 @@ class NpzCacher(Cacher):
226
225
 
227
226
 
228
227
  # This function is the main entry point for the caching mechanism for the other anemoi packages
229
- def cached(collection: str = "default", expires: Optional[int] = None, encoding: str = "json") -> Callable:
228
+ def cached(collection: str = "default", expires: int | None = None, encoding: str = "json") -> Callable:
230
229
  """Decorator to cache the result of a function.
231
230
 
232
231
  Default is to use a json file to store the cache, but you can also use npz files
@@ -17,8 +17,8 @@ import logging
17
17
  import os
18
18
  import time
19
19
  import zipfile
20
+ from collections.abc import Callable
20
21
  from tempfile import TemporaryDirectory
21
- from typing import Callable
22
22
 
23
23
  import tqdm
24
24
 
anemoi/utils/cli.py CHANGED
@@ -14,8 +14,7 @@ import logging
14
14
  import os
15
15
  import sys
16
16
  import traceback
17
- from typing import Callable
18
- from typing import Optional
17
+ from collections.abc import Callable
19
18
 
20
19
  try:
21
20
  import argcomplete
@@ -187,7 +186,7 @@ def register_commands(here: str, package: str, select: Callable, fail: Callable
187
186
 
188
187
 
189
188
  def cli_main(
190
- version: str, description: str, commands: dict[str, Command], test_arguments: Optional[list[str]] = None
189
+ version: str, description: str, commands: dict[str, Command], test_arguments: list[str] | None = None
191
190
  ) -> None:
192
191
  """Main entry point for the CLI.
193
192
 
@@ -17,7 +17,6 @@ from argparse import ArgumentParser
17
17
  from argparse import Namespace
18
18
  from tempfile import TemporaryDirectory
19
19
  from typing import Any
20
- from typing import Dict
21
20
 
22
21
  import yaml
23
22
 
@@ -213,7 +212,7 @@ class Metadata(Command):
213
212
  from anemoi.utils.checkpoints import load_metadata
214
213
  from anemoi.utils.checkpoints import replace_metadata
215
214
 
216
- kwargs: Dict[str, Any] = {}
215
+ kwargs: dict[str, Any] = {}
217
216
 
218
217
  if args.json:
219
218
  ext = "json"
@@ -10,15 +10,11 @@
10
10
  from __future__ import annotations
11
11
 
12
12
  import functools
13
+ from collections.abc import Callable
13
14
  from typing import Any
14
- from typing import Callable
15
- from typing import Optional
16
- from typing import Union
17
15
 
18
16
 
19
- def aliases(
20
- aliases: Optional[dict[str, Union[str, list[str]]]] = None, **kwargs: Any
21
- ) -> Callable[[Callable], Callable]:
17
+ def aliases(aliases: dict[str, str | list[str]] | None = None, **kwargs: Any) -> Callable[[Callable], Callable]:
22
18
  """Alias keyword arguments in a function call.
23
19
 
24
20
  Allows for dynamically renaming keyword arguments in a function call.
anemoi/utils/config.py CHANGED
@@ -16,8 +16,6 @@ import logging
16
16
  import os
17
17
  import threading
18
18
  from typing import Any
19
- from typing import Optional
20
- from typing import Union
21
19
 
22
20
  import yaml
23
21
 
@@ -62,14 +60,20 @@ class DotDict(dict):
62
60
  super().__init__(*args, **kwargs)
63
61
 
64
62
  for k, v in self.items():
65
- if isinstance(v, dict) or is_omegaconf_dict(v):
66
- self[k] = DotDict(v)
63
+ self[k] = self.convert_to_nested_dot_dict(v)
67
64
 
68
- if isinstance(v, list) or is_omegaconf_list(v):
69
- self[k] = [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in v]
65
+ @staticmethod
66
+ def convert_to_nested_dot_dict(value):
67
+ if isinstance(value, dict) or is_omegaconf_dict(value):
68
+ return DotDict(value)
70
69
 
71
- if isinstance(v, tuple):
72
- self[k] = [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in v]
70
+ if isinstance(value, list) or is_omegaconf_list(value):
71
+ return [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in value]
72
+
73
+ if isinstance(value, tuple):
74
+ return [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in value]
75
+
76
+ return value
73
77
 
74
78
  @classmethod
75
79
  def from_file(cls, path: str) -> DotDict:
@@ -109,7 +113,7 @@ class DotDict(dict):
109
113
  DotDict
110
114
  The created DotDict.
111
115
  """
112
- with open(path, "r") as file:
116
+ with open(path) as file:
113
117
  data = yaml.safe_load(file)
114
118
 
115
119
  return cls(data)
@@ -128,7 +132,7 @@ class DotDict(dict):
128
132
  DotDict
129
133
  The created DotDict.
130
134
  """
131
- with open(path, "r") as file:
135
+ with open(path) as file:
132
136
  data = json.load(file)
133
137
 
134
138
  return cls(data)
@@ -147,7 +151,7 @@ class DotDict(dict):
147
151
  DotDict
148
152
  The created DotDict.
149
153
  """
150
- with open(path, "r") as file:
154
+ with open(path) as file:
151
155
  data = tomllib.load(file)
152
156
  return cls(data)
153
157
 
@@ -179,9 +183,31 @@ class DotDict(dict):
179
183
  value : Any
180
184
  The attribute value.
181
185
  """
182
- if isinstance(value, dict):
183
- value = DotDict(value)
184
- self[attr] = value
186
+
187
+ self.warn_on_mutation(attr)
188
+ value = self.convert_to_nested_dot_dict(value)
189
+ super().__setitem__(attr, value)
190
+
191
+ def __setitem__(self, key: str, value: Any) -> None:
192
+ """Set an item in the dictionary.
193
+
194
+ Parameters
195
+ ----------
196
+ key : str
197
+ The key to set.
198
+ value : Any
199
+ The value to set.
200
+ """
201
+ self.warn_on_mutation(key)
202
+ value = self.convert_to_nested_dot_dict(value)
203
+ super().__setitem__(key, value)
204
+
205
+ @staticmethod
206
+ def warn_on_mutation(key):
207
+ LOG.warning(
208
+ f"Config key '{key}' was modified after instantiation. "
209
+ "This is bad practice — configs are intended to be immutable. "
210
+ )
185
211
 
186
212
  def __repr__(self) -> str:
187
213
  """Return a string representation of the DotDict.
@@ -243,7 +269,7 @@ QUIET = False
243
269
  CONFIG_PATCH = None
244
270
 
245
271
 
246
- def _find(config: Union[dict, list], what: str, result: list = None) -> list:
272
+ def _find(config: dict | list, what: str, result: list = None) -> list:
247
273
  """Find all occurrences of a key in a nested dictionary or list.
248
274
 
249
275
  Parameters
@@ -408,8 +434,8 @@ def load_any_dict_format(path: str) -> dict:
408
434
 
409
435
  def _load_config(
410
436
  name: str = "settings.toml",
411
- secrets: Optional[Union[str, list[str]]] = None,
412
- defaults: Optional[Union[str, dict]] = None,
437
+ secrets: str | list[str] | None = None,
438
+ defaults: str | dict | None = None,
413
439
  ) -> DotDict:
414
440
  """Load a configuration file.
415
441
 
@@ -531,8 +557,8 @@ def save_config(name: str, data: Any) -> None:
531
557
 
532
558
  def load_config(
533
559
  name: str = "settings.toml",
534
- secrets: Optional[Union[str, list[str]]] = None,
535
- defaults: Optional[Union[str, dict]] = None,
560
+ secrets: str | list[str] | None = None,
561
+ defaults: str | dict | None = None,
536
562
  ) -> DotDict | str:
537
563
  """Read a configuration file.
538
564
 
@@ -558,7 +584,7 @@ def load_config(
558
584
  return config
559
585
 
560
586
 
561
- def load_raw_config(name: str, default: Any = None) -> Union[DotDict, str]:
587
+ def load_raw_config(name: str, default: Any = None) -> DotDict | str:
562
588
  """Load a raw configuration file.
563
589
 
564
590
  Parameters
@@ -617,7 +643,7 @@ def check_config_mode(name: str = "settings.toml", secrets_name: str = None, sec
617
643
  CHECKED[name] = True
618
644
 
619
645
 
620
- def find(metadata: Union[dict, list], what: str, result: list = None, *, select: callable = None) -> list:
646
+ def find(metadata: dict | list, what: str, result: list = None, *, select: callable = None) -> list:
621
647
  """Find all occurrences of a key in a nested dictionary or list with an optional selector.
622
648
 
623
649
  Parameters
anemoi/utils/dates.py CHANGED
@@ -12,16 +12,11 @@ import calendar
12
12
  import datetime
13
13
  import re
14
14
  from typing import Any
15
- from typing import List
16
- from typing import Optional
17
- from typing import Set
18
- from typing import Tuple
19
- from typing import Union
20
15
 
21
16
  import aniso8601
22
17
 
23
18
 
24
- def normalise_frequency(frequency: Union[int, str]) -> int:
19
+ def normalise_frequency(frequency: int | str) -> int:
25
20
  """Normalise frequency to hours.
26
21
 
27
22
  Parameters
@@ -61,7 +56,7 @@ def _no_time_zone(date: datetime.datetime) -> datetime.datetime:
61
56
 
62
57
 
63
58
  # this function is use in anemoi-datasets
64
- def as_datetime(date: Union[datetime.date, datetime.datetime, str], keep_time_zone: bool = False) -> datetime.datetime:
59
+ def as_datetime(date: datetime.date | datetime.datetime | str, keep_time_zone: bool = False) -> datetime.datetime:
65
60
  """Convert a date to a datetime object, removing any time zone information.
66
61
 
67
62
  Parameters
@@ -91,9 +86,7 @@ def as_datetime(date: Union[datetime.date, datetime.datetime, str], keep_time_zo
91
86
  raise ValueError(f"Invalid date type: {type(date)}")
92
87
 
93
88
 
94
- def _as_datetime_list(
95
- date: Union[datetime.date, datetime.datetime, str], default_increment: datetime.timedelta
96
- ) -> iter:
89
+ def _as_datetime_list(date: datetime.date | datetime.datetime | str, default_increment: datetime.timedelta) -> iter:
97
90
  """Convert a date to a list of datetime objects.
98
91
 
99
92
  Parameters
@@ -137,7 +130,7 @@ def _as_datetime_list(
137
130
 
138
131
 
139
132
  def as_datetime_list(
140
- date: Union[datetime.date, datetime.datetime, str], default_increment: int = 1
133
+ date: datetime.date | datetime.datetime | str, default_increment: int = 1
141
134
  ) -> list[datetime.datetime]:
142
135
  """Convert a date to a list of datetime objects.
143
136
 
@@ -157,7 +150,7 @@ def as_datetime_list(
157
150
  return list(_as_datetime_list(date, default_increment))
158
151
 
159
152
 
160
- def as_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.timedelta:
153
+ def as_timedelta(frequency: int | str | datetime.timedelta) -> datetime.timedelta:
161
154
  """Convert anything to a timedelta object.
162
155
 
163
156
  Parameters
@@ -237,7 +230,7 @@ def as_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.tim
237
230
  raise ValueError(f"Cannot convert frequency {frequency} to timedelta")
238
231
 
239
232
 
240
- def frequency_to_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.timedelta:
233
+ def frequency_to_timedelta(frequency: int | str | datetime.timedelta) -> datetime.timedelta:
241
234
  """Convert a frequency to a timedelta object.
242
235
 
243
236
  Parameters
@@ -302,7 +295,7 @@ def frequency_to_string(frequency: datetime.timedelta) -> str:
302
295
  return str(frequency)
303
296
 
304
297
 
305
- def frequency_to_seconds(frequency: Union[int, str, datetime.timedelta]) -> int:
298
+ def frequency_to_seconds(frequency: int | str | datetime.timedelta) -> int:
306
299
  """Convert a frequency to seconds.
307
300
 
308
301
  Parameters
@@ -347,7 +340,7 @@ MONTH = {
347
340
  }
348
341
 
349
342
 
350
- def _make_day(day: Optional[Tuple[int, List[int]]]) -> Set[int]:
343
+ def _make_day(day: tuple[int, list[int]] | None) -> set[int]:
351
344
  """Create a set of days.
352
345
 
353
346
  Parameters
@@ -367,7 +360,7 @@ def _make_day(day: Optional[Tuple[int, List[int]]]) -> Set[int]:
367
360
  return {int(d) for d in day}
368
361
 
369
362
 
370
- def _make_week(week: Optional[Tuple[str, List[str]]]) -> Set[int]:
363
+ def _make_week(week: tuple[str, list[str]] | None) -> set[int]:
371
364
  """Create a set of weekdays.
372
365
 
373
366
  Parameters
@@ -387,7 +380,7 @@ def _make_week(week: Optional[Tuple[str, List[str]]]) -> Set[int]:
387
380
  return {DOW[w.lower()] for w in week}
388
381
 
389
382
 
390
- def _make_months(months: Optional[Union[int, str, List[Union[int, str]]]]) -> Set[int]:
383
+ def _make_months(months: int | str | list[int | str] | None) -> set[int]:
391
384
  """Create a set of months.
392
385
 
393
386
  Parameters
@@ -414,13 +407,13 @@ class DateTimes:
414
407
 
415
408
  def __init__(
416
409
  self,
417
- start: Union[datetime.date, datetime.datetime, str],
418
- end: Union[datetime.date, datetime.datetime, str],
410
+ start: datetime.date | datetime.datetime | str,
411
+ end: datetime.date | datetime.datetime | str,
419
412
  increment: int = 24,
420
413
  *,
421
- day_of_month: Optional[Tuple[int, List[int]]] = None,
422
- day_of_week: Optional[Tuple[str, List[str]]] = None,
423
- calendar_months: Optional[Union[int, str, List[Union[int, str]]]] = None,
414
+ day_of_month: tuple[int, list[int]] | None = None,
415
+ day_of_week: tuple[str, list[str]] | None = None,
416
+ calendar_months: int | str | list[int | str] | None = None,
424
417
  ):
425
418
  """Initialize the DateTimes iterator.
426
419
 
@@ -571,7 +564,7 @@ class ConcatDateTimes:
571
564
  class EnumDateTimes:
572
565
  """EnumDateTimes is an iterator that generates datetime objects from a list of dates."""
573
566
 
574
- def __init__(self, dates: list[Union[datetime.date, datetime.datetime, str]]):
567
+ def __init__(self, dates: list[datetime.date | datetime.datetime | str]):
575
568
  """Initialize the EnumDateTimes iterator.
576
569
 
577
570
  Parameters
@@ -593,7 +586,7 @@ class EnumDateTimes:
593
586
  yield as_datetime(date)
594
587
 
595
588
 
596
- def datetimes_factory(*args: Any, **kwargs: Any) -> Union[DateTimes, ConcatDateTimes, EnumDateTimes]:
589
+ def datetimes_factory(*args: Any, **kwargs: Any) -> DateTimes | ConcatDateTimes | EnumDateTimes:
597
590
  """Create a DateTimes, ConcatDateTimes, or EnumDateTimes object.
598
591
 
599
592
  Parameters
anemoi/utils/grib.py CHANGED
@@ -15,8 +15,6 @@ See https://codes.ecmwf.int/grib/param-db/ for more information.
15
15
 
16
16
  import logging
17
17
  import re
18
- from typing import Dict
19
- from typing import Union
20
18
 
21
19
  import requests
22
20
 
@@ -26,7 +24,7 @@ LOG = logging.getLogger(__name__)
26
24
 
27
25
 
28
26
  @cached(collection="grib", expires=30 * 24 * 60 * 60)
29
- def _units() -> Dict[str, str]:
27
+ def _units() -> dict[str, str]:
30
28
  """Fetch and cache GRIB parameter units.
31
29
 
32
30
  Returns
@@ -41,7 +39,7 @@ def _units() -> Dict[str, str]:
41
39
 
42
40
 
43
41
  @cached(collection="grib", expires=30 * 24 * 60 * 60)
44
- def _search_param(name: str) -> Dict[str, Union[str, int]]:
42
+ def _search_param(name: str) -> dict[str, str | int]:
45
43
  """Search for a GRIB parameter by name.
46
44
 
47
45
  Parameters
@@ -116,7 +114,7 @@ def paramid_to_shortname(paramid: int) -> str:
116
114
  return _search_param(str(paramid))["shortname"]
117
115
 
118
116
 
119
- def units(param: Union[int, str]) -> str:
117
+ def units(param: int | str) -> str:
120
118
  """Return the units of a GRIB parameter given its name or id.
121
119
 
122
120
  Parameters
@@ -137,7 +135,7 @@ def units(param: Union[int, str]) -> str:
137
135
  return _units()[unit_id]
138
136
 
139
137
 
140
- def must_be_positive(param: Union[int, str]) -> bool:
138
+ def must_be_positive(param: int | str) -> bool:
141
139
  """Check if a parameter must be positive.
142
140
 
143
141
  Parameters
anemoi/utils/grids.py CHANGED
@@ -13,9 +13,6 @@
13
13
  import logging
14
14
  import os
15
15
  from io import BytesIO
16
- from typing import List
17
- from typing import Tuple
18
- from typing import Union
19
16
 
20
17
  import deprecation
21
18
  import numpy as np
@@ -145,7 +142,7 @@ def nearest_grid_points(
145
142
 
146
143
 
147
144
  @cached(collection="grids", encoding="npz")
148
- def _grids(name: Union[str, List[float], Tuple[float, ...]]) -> bytes:
145
+ def _grids(name: str | list[float] | tuple[float, ...]) -> bytes:
149
146
  """Get grid data by name.
150
147
 
151
148
  Parameters
@@ -196,7 +193,7 @@ def _grids(name: Union[str, List[float], Tuple[float, ...]]) -> bytes:
196
193
  current_version=__version__,
197
194
  details="Use anemoi.transform.grids.named.lookup instead.",
198
195
  )
199
- def grids(name: Union[str, List[float], Tuple[float, ...]]) -> dict:
196
+ def grids(name: str | list[float] | tuple[float, ...]) -> dict:
200
197
  """Load grid data by name.
201
198
 
202
199
  Parameters
anemoi/utils/hindcasts.py CHANGED
@@ -9,9 +9,7 @@
9
9
 
10
10
 
11
11
  import datetime
12
- from typing import Iterator
13
- from typing import List
14
- from typing import Tuple
12
+ from collections.abc import Iterator
15
13
 
16
14
 
17
15
  class HindcastDatesTimes:
@@ -25,7 +23,7 @@ class HindcastDatesTimes:
25
23
  Number of years to go back from each reference date.
26
24
  """
27
25
 
28
- def __init__(self, reference_dates: List[datetime.datetime], years: int = 20):
26
+ def __init__(self, reference_dates: list[datetime.datetime], years: int = 20):
29
27
  """Initialize the HindcastDatesTimes iterator.
30
28
 
31
29
  Parameters
@@ -41,7 +39,7 @@ class HindcastDatesTimes:
41
39
  assert years > 0, f"years must be greater than 0, got {years}"
42
40
  self.years = years
43
41
 
44
- def __iter__(self) -> Iterator[Tuple[datetime.datetime, datetime.datetime]]:
42
+ def __iter__(self) -> Iterator[tuple[datetime.datetime, datetime.datetime]]:
45
43
  """Generate tuples of past dates and their corresponding reference dates.
46
44
 
47
45
  Yields