anemoi-utils 0.4.28-py3-none-any.whl → 0.4.30-py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic; see the registry's advisory page for more details.

anemoi/utils/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.4.28'
21
- __version_tuple__ = version_tuple = (0, 4, 28)
20
+ __version__ = version = '0.4.30'
21
+ __version_tuple__ = version_tuple = (0, 4, 30)
anemoi/utils/caching.py CHANGED
@@ -12,10 +12,9 @@ import hashlib
12
12
  import json
13
13
  import os
14
14
  import time
15
+ from collections.abc import Callable
15
16
  from threading import Lock
16
17
  from typing import Any
17
- from typing import Callable
18
- from typing import Optional
19
18
 
20
19
  import numpy as np
21
20
 
@@ -61,7 +60,7 @@ class Cacher:
61
60
  Private class, do not use directly.
62
61
  """
63
62
 
64
- def __init__(self, collection: str, expires: Optional[int]):
63
+ def __init__(self, collection: str, expires: int | None):
65
64
  """Initialize the Cacher.
66
65
 
67
66
  Parameters
@@ -181,7 +180,7 @@ class JsonCacher(Cacher):
181
180
  dict
182
181
  The loaded data
183
182
  """
184
- with open(path, "r") as f:
183
+ with open(path) as f:
185
184
  return json.load(f)
186
185
 
187
186
 
@@ -226,7 +225,7 @@ class NpzCacher(Cacher):
226
225
 
227
226
 
228
227
  # This function is the main entry point for the caching mechanism for the other anemoi packages
229
- def cached(collection: str = "default", expires: Optional[int] = None, encoding: str = "json") -> Callable:
228
+ def cached(collection: str = "default", expires: int | None = None, encoding: str = "json") -> Callable:
230
229
  """Decorator to cache the result of a function.
231
230
 
232
231
  Default is to use a json file to store the cache, but you can also use npz files
@@ -17,8 +17,8 @@ import logging
17
17
  import os
18
18
  import time
19
19
  import zipfile
20
+ from collections.abc import Callable
20
21
  from tempfile import TemporaryDirectory
21
- from typing import Callable
22
22
 
23
23
  import tqdm
24
24
 
anemoi/utils/cli.py CHANGED
@@ -14,8 +14,7 @@ import logging
14
14
  import os
15
15
  import sys
16
16
  import traceback
17
- from typing import Callable
18
- from typing import Optional
17
+ from collections.abc import Callable
19
18
 
20
19
  try:
21
20
  import argcomplete
@@ -187,7 +186,7 @@ def register_commands(here: str, package: str, select: Callable, fail: Callable
187
186
 
188
187
 
189
188
  def cli_main(
190
- version: str, description: str, commands: dict[str, Command], test_arguments: Optional[list[str]] = None
189
+ version: str, description: str, commands: dict[str, Command], test_arguments: list[str] | None = None
191
190
  ) -> None:
192
191
  """Main entry point for the CLI.
193
192
 
@@ -17,7 +17,6 @@ from argparse import ArgumentParser
17
17
  from argparse import Namespace
18
18
  from tempfile import TemporaryDirectory
19
19
  from typing import Any
20
- from typing import Dict
21
20
 
22
21
  import yaml
23
22
 
@@ -213,7 +212,7 @@ class Metadata(Command):
213
212
  from anemoi.utils.checkpoints import load_metadata
214
213
  from anemoi.utils.checkpoints import replace_metadata
215
214
 
216
- kwargs: Dict[str, Any] = {}
215
+ kwargs: dict[str, Any] = {}
217
216
 
218
217
  if args.json:
219
218
  ext = "json"
@@ -10,15 +10,11 @@
10
10
  from __future__ import annotations
11
11
 
12
12
  import functools
13
+ from collections.abc import Callable
13
14
  from typing import Any
14
- from typing import Callable
15
- from typing import Optional
16
- from typing import Union
17
15
 
18
16
 
19
- def aliases(
20
- aliases: Optional[dict[str, Union[str, list[str]]]] = None, **kwargs: Any
21
- ) -> Callable[[Callable], Callable]:
17
+ def aliases(aliases: dict[str, str | list[str]] | None = None, **kwargs: Any) -> Callable[[Callable], Callable]:
22
18
  """Alias keyword arguments in a function call.
23
19
 
24
20
  Allows for dynamically renaming keyword arguments in a function call.
anemoi/utils/config.py CHANGED
@@ -16,8 +16,6 @@ import logging
16
16
  import os
17
17
  import threading
18
18
  from typing import Any
19
- from typing import Optional
20
- from typing import Union
21
19
 
22
20
  import yaml
23
21
 
@@ -62,14 +60,20 @@ class DotDict(dict):
62
60
  super().__init__(*args, **kwargs)
63
61
 
64
62
  for k, v in self.items():
65
- if isinstance(v, dict) or is_omegaconf_dict(v):
66
- self[k] = DotDict(v)
63
+ self[k] = self.convert_to_nested_dot_dict(v)
67
64
 
68
- if isinstance(v, list) or is_omegaconf_list(v):
69
- self[k] = [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in v]
65
+ @staticmethod
66
+ def convert_to_nested_dot_dict(value):
67
+ if isinstance(value, dict) or is_omegaconf_dict(value):
68
+ return DotDict(value)
70
69
 
71
- if isinstance(v, tuple):
72
- self[k] = [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in v]
70
+ if isinstance(value, list) or is_omegaconf_list(value):
71
+ return [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in value]
72
+
73
+ if isinstance(value, tuple):
74
+ return [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in value]
75
+
76
+ return value
73
77
 
74
78
  @classmethod
75
79
  def from_file(cls, path: str) -> DotDict:
@@ -109,7 +113,7 @@ class DotDict(dict):
109
113
  DotDict
110
114
  The created DotDict.
111
115
  """
112
- with open(path, "r") as file:
116
+ with open(path) as file:
113
117
  data = yaml.safe_load(file)
114
118
 
115
119
  return cls(data)
@@ -128,7 +132,7 @@ class DotDict(dict):
128
132
  DotDict
129
133
  The created DotDict.
130
134
  """
131
- with open(path, "r") as file:
135
+ with open(path) as file:
132
136
  data = json.load(file)
133
137
 
134
138
  return cls(data)
@@ -147,7 +151,7 @@ class DotDict(dict):
147
151
  DotDict
148
152
  The created DotDict.
149
153
  """
150
- with open(path, "r") as file:
154
+ with open(path) as file:
151
155
  data = tomllib.load(file)
152
156
  return cls(data)
153
157
 
@@ -179,9 +183,31 @@ class DotDict(dict):
179
183
  value : Any
180
184
  The attribute value.
181
185
  """
182
- if isinstance(value, dict):
183
- value = DotDict(value)
184
- self[attr] = value
186
+
187
+ self.warn_on_mutation(attr)
188
+ value = self.convert_to_nested_dot_dict(value)
189
+ super().__setitem__(attr, value)
190
+
191
+ def __setitem__(self, key: str, value: Any) -> None:
192
+ """Set an item in the dictionary.
193
+
194
+ Parameters
195
+ ----------
196
+ key : str
197
+ The key to set.
198
+ value : Any
199
+ The value to set.
200
+ """
201
+ self.warn_on_mutation(key)
202
+ value = self.convert_to_nested_dot_dict(value)
203
+ super().__setitem__(key, value)
204
+
205
+ @staticmethod
206
+ def warn_on_mutation(key):
207
+ LOG.warning(
208
+ f"Config key '{key}' was modified after instantiation. "
209
+ "This is bad practice — configs are intended to be immutable. "
210
+ )
185
211
 
186
212
  def __repr__(self) -> str:
187
213
  """Return a string representation of the DotDict.
@@ -243,7 +269,7 @@ QUIET = False
243
269
  CONFIG_PATCH = None
244
270
 
245
271
 
246
- def _find(config: Union[dict, list], what: str, result: list = None) -> list:
272
+ def _find(config: dict | list, what: str, result: list = None) -> list:
247
273
  """Find all occurrences of a key in a nested dictionary or list.
248
274
 
249
275
  Parameters
@@ -408,8 +434,8 @@ def load_any_dict_format(path: str) -> dict:
408
434
 
409
435
  def _load_config(
410
436
  name: str = "settings.toml",
411
- secrets: Optional[Union[str, list[str]]] = None,
412
- defaults: Optional[Union[str, dict]] = None,
437
+ secrets: str | list[str] | None = None,
438
+ defaults: str | dict | None = None,
413
439
  ) -> DotDict:
414
440
  """Load a configuration file.
415
441
 
@@ -531,8 +557,8 @@ def save_config(name: str, data: Any) -> None:
531
557
 
532
558
  def load_config(
533
559
  name: str = "settings.toml",
534
- secrets: Optional[Union[str, list[str]]] = None,
535
- defaults: Optional[Union[str, dict]] = None,
560
+ secrets: str | list[str] | None = None,
561
+ defaults: str | dict | None = None,
536
562
  ) -> DotDict | str:
537
563
  """Read a configuration file.
538
564
 
@@ -558,7 +584,7 @@ def load_config(
558
584
  return config
559
585
 
560
586
 
561
- def load_raw_config(name: str, default: Any = None) -> Union[DotDict, str]:
587
+ def load_raw_config(name: str, default: Any = None) -> DotDict | str:
562
588
  """Load a raw configuration file.
563
589
 
564
590
  Parameters
@@ -617,7 +643,7 @@ def check_config_mode(name: str = "settings.toml", secrets_name: str = None, sec
617
643
  CHECKED[name] = True
618
644
 
619
645
 
620
- def find(metadata: Union[dict, list], what: str, result: list = None, *, select: callable = None) -> list:
646
+ def find(metadata: dict | list, what: str, result: list = None, *, select: callable = None) -> list:
621
647
  """Find all occurrences of a key in a nested dictionary or list with an optional selector.
622
648
 
623
649
  Parameters
anemoi/utils/dates.py CHANGED
@@ -12,16 +12,11 @@ import calendar
12
12
  import datetime
13
13
  import re
14
14
  from typing import Any
15
- from typing import List
16
- from typing import Optional
17
- from typing import Set
18
- from typing import Tuple
19
- from typing import Union
20
15
 
21
16
  import aniso8601
22
17
 
23
18
 
24
- def normalise_frequency(frequency: Union[int, str]) -> int:
19
+ def normalise_frequency(frequency: int | str) -> int:
25
20
  """Normalise frequency to hours.
26
21
 
27
22
  Parameters
@@ -61,7 +56,7 @@ def _no_time_zone(date: datetime.datetime) -> datetime.datetime:
61
56
 
62
57
 
63
58
  # this function is use in anemoi-datasets
64
- def as_datetime(date: Union[datetime.date, datetime.datetime, str], keep_time_zone: bool = False) -> datetime.datetime:
59
+ def as_datetime(date: datetime.date | datetime.datetime | str, keep_time_zone: bool = False) -> datetime.datetime:
65
60
  """Convert a date to a datetime object, removing any time zone information.
66
61
 
67
62
  Parameters
@@ -91,9 +86,7 @@ def as_datetime(date: Union[datetime.date, datetime.datetime, str], keep_time_zo
91
86
  raise ValueError(f"Invalid date type: {type(date)}")
92
87
 
93
88
 
94
- def _as_datetime_list(
95
- date: Union[datetime.date, datetime.datetime, str], default_increment: datetime.timedelta
96
- ) -> iter:
89
+ def _as_datetime_list(date: datetime.date | datetime.datetime | str, default_increment: datetime.timedelta) -> iter:
97
90
  """Convert a date to a list of datetime objects.
98
91
 
99
92
  Parameters
@@ -137,7 +130,7 @@ def _as_datetime_list(
137
130
 
138
131
 
139
132
  def as_datetime_list(
140
- date: Union[datetime.date, datetime.datetime, str], default_increment: int = 1
133
+ date: datetime.date | datetime.datetime | str, default_increment: int = 1
141
134
  ) -> list[datetime.datetime]:
142
135
  """Convert a date to a list of datetime objects.
143
136
 
@@ -157,7 +150,7 @@ def as_datetime_list(
157
150
  return list(_as_datetime_list(date, default_increment))
158
151
 
159
152
 
160
- def as_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.timedelta:
153
+ def as_timedelta(frequency: int | str | datetime.timedelta) -> datetime.timedelta:
161
154
  """Convert anything to a timedelta object.
162
155
 
163
156
  Parameters
@@ -199,6 +192,15 @@ def as_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.tim
199
192
  except ValueError:
200
193
  pass
201
194
 
195
+ if frequency.startswith(" ") or frequency.startswith(" "):
196
+ frequency = frequency.strip()
197
+
198
+ if frequency.startswith("-"):
199
+ return -as_timedelta(frequency[1:])
200
+
201
+ if frequency.startswith("+"):
202
+ return as_timedelta(frequency[1:])
203
+
202
204
  if re.match(r"^\d+[hdms]$", frequency, re.IGNORECASE):
203
205
  unit = frequency[-1].lower()
204
206
  v = int(frequency[:-1])
@@ -228,7 +230,7 @@ def as_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.tim
228
230
  raise ValueError(f"Cannot convert frequency {frequency} to timedelta")
229
231
 
230
232
 
231
- def frequency_to_timedelta(frequency: Union[int, str, datetime.timedelta]) -> datetime.timedelta:
233
+ def frequency_to_timedelta(frequency: int | str | datetime.timedelta) -> datetime.timedelta:
232
234
  """Convert a frequency to a timedelta object.
233
235
 
234
236
  Parameters
@@ -261,6 +263,8 @@ def frequency_to_string(frequency: datetime.timedelta) -> str:
261
263
  frequency = frequency_to_timedelta(frequency)
262
264
 
263
265
  total_seconds = frequency.total_seconds()
266
+ if total_seconds < 0:
267
+ return f"-{frequency_to_string(-frequency)}"
264
268
  assert int(total_seconds) == total_seconds, total_seconds
265
269
  total_seconds = int(total_seconds)
266
270
 
@@ -291,7 +295,7 @@ def frequency_to_string(frequency: datetime.timedelta) -> str:
291
295
  return str(frequency)
292
296
 
293
297
 
294
- def frequency_to_seconds(frequency: Union[int, str, datetime.timedelta]) -> int:
298
+ def frequency_to_seconds(frequency: int | str | datetime.timedelta) -> int:
295
299
  """Convert a frequency to seconds.
296
300
 
297
301
  Parameters
@@ -336,7 +340,7 @@ MONTH = {
336
340
  }
337
341
 
338
342
 
339
- def _make_day(day: Optional[Tuple[int, List[int]]]) -> Set[int]:
343
+ def _make_day(day: tuple[int, list[int]] | None) -> set[int]:
340
344
  """Create a set of days.
341
345
 
342
346
  Parameters
@@ -356,7 +360,7 @@ def _make_day(day: Optional[Tuple[int, List[int]]]) -> Set[int]:
356
360
  return {int(d) for d in day}
357
361
 
358
362
 
359
- def _make_week(week: Optional[Tuple[str, List[str]]]) -> Set[int]:
363
+ def _make_week(week: tuple[str, list[str]] | None) -> set[int]:
360
364
  """Create a set of weekdays.
361
365
 
362
366
  Parameters
@@ -376,7 +380,7 @@ def _make_week(week: Optional[Tuple[str, List[str]]]) -> Set[int]:
376
380
  return {DOW[w.lower()] for w in week}
377
381
 
378
382
 
379
- def _make_months(months: Optional[Union[int, str, List[Union[int, str]]]]) -> Set[int]:
383
+ def _make_months(months: int | str | list[int | str] | None) -> set[int]:
380
384
  """Create a set of months.
381
385
 
382
386
  Parameters
@@ -403,13 +407,13 @@ class DateTimes:
403
407
 
404
408
  def __init__(
405
409
  self,
406
- start: Union[datetime.date, datetime.datetime, str],
407
- end: Union[datetime.date, datetime.datetime, str],
410
+ start: datetime.date | datetime.datetime | str,
411
+ end: datetime.date | datetime.datetime | str,
408
412
  increment: int = 24,
409
413
  *,
410
- day_of_month: Optional[Tuple[int, List[int]]] = None,
411
- day_of_week: Optional[Tuple[str, List[str]]] = None,
412
- calendar_months: Optional[Union[int, str, List[Union[int, str]]]] = None,
414
+ day_of_month: tuple[int, list[int]] | None = None,
415
+ day_of_week: tuple[str, list[str]] | None = None,
416
+ calendar_months: int | str | list[int | str] | None = None,
413
417
  ):
414
418
  """Initialize the DateTimes iterator.
415
419
 
@@ -560,7 +564,7 @@ class ConcatDateTimes:
560
564
  class EnumDateTimes:
561
565
  """EnumDateTimes is an iterator that generates datetime objects from a list of dates."""
562
566
 
563
- def __init__(self, dates: list[Union[datetime.date, datetime.datetime, str]]):
567
+ def __init__(self, dates: list[datetime.date | datetime.datetime | str]):
564
568
  """Initialize the EnumDateTimes iterator.
565
569
 
566
570
  Parameters
@@ -582,7 +586,7 @@ class EnumDateTimes:
582
586
  yield as_datetime(date)
583
587
 
584
588
 
585
- def datetimes_factory(*args: Any, **kwargs: Any) -> Union[DateTimes, ConcatDateTimes, EnumDateTimes]:
589
+ def datetimes_factory(*args: Any, **kwargs: Any) -> DateTimes | ConcatDateTimes | EnumDateTimes:
586
590
  """Create a DateTimes, ConcatDateTimes, or EnumDateTimes object.
587
591
 
588
592
  Parameters
anemoi/utils/grib.py CHANGED
@@ -15,8 +15,6 @@ See https://codes.ecmwf.int/grib/param-db/ for more information.
15
15
 
16
16
  import logging
17
17
  import re
18
- from typing import Dict
19
- from typing import Union
20
18
 
21
19
  import requests
22
20
 
@@ -26,7 +24,7 @@ LOG = logging.getLogger(__name__)
26
24
 
27
25
 
28
26
  @cached(collection="grib", expires=30 * 24 * 60 * 60)
29
- def _units() -> Dict[str, str]:
27
+ def _units() -> dict[str, str]:
30
28
  """Fetch and cache GRIB parameter units.
31
29
 
32
30
  Returns
@@ -41,7 +39,7 @@ def _units() -> Dict[str, str]:
41
39
 
42
40
 
43
41
  @cached(collection="grib", expires=30 * 24 * 60 * 60)
44
- def _search_param(name: str) -> Dict[str, Union[str, int]]:
42
+ def _search_param(name: str) -> dict[str, str | int]:
45
43
  """Search for a GRIB parameter by name.
46
44
 
47
45
  Parameters
@@ -116,7 +114,7 @@ def paramid_to_shortname(paramid: int) -> str:
116
114
  return _search_param(str(paramid))["shortname"]
117
115
 
118
116
 
119
- def units(param: Union[int, str]) -> str:
117
+ def units(param: int | str) -> str:
120
118
  """Return the units of a GRIB parameter given its name or id.
121
119
 
122
120
  Parameters
@@ -137,7 +135,7 @@ def units(param: Union[int, str]) -> str:
137
135
  return _units()[unit_id]
138
136
 
139
137
 
140
- def must_be_positive(param: Union[int, str]) -> bool:
138
+ def must_be_positive(param: int | str) -> bool:
141
139
  """Check if a parameter must be positive.
142
140
 
143
141
  Parameters
anemoi/utils/grids.py CHANGED
@@ -13,9 +13,6 @@
13
13
  import logging
14
14
  import os
15
15
  from io import BytesIO
16
- from typing import List
17
- from typing import Tuple
18
- from typing import Union
19
16
 
20
17
  import deprecation
21
18
  import numpy as np
@@ -145,7 +142,7 @@ def nearest_grid_points(
145
142
 
146
143
 
147
144
  @cached(collection="grids", encoding="npz")
148
- def _grids(name: Union[str, List[float], Tuple[float, ...]]) -> bytes:
145
+ def _grids(name: str | list[float] | tuple[float, ...]) -> bytes:
149
146
  """Get grid data by name.
150
147
 
151
148
  Parameters
@@ -196,7 +193,7 @@ def _grids(name: Union[str, List[float], Tuple[float, ...]]) -> bytes:
196
193
  current_version=__version__,
197
194
  details="Use anemoi.transform.grids.named.lookup instead.",
198
195
  )
199
- def grids(name: Union[str, List[float], Tuple[float, ...]]) -> dict:
196
+ def grids(name: str | list[float] | tuple[float, ...]) -> dict:
200
197
  """Load grid data by name.
201
198
 
202
199
  Parameters
anemoi/utils/hindcasts.py CHANGED
@@ -9,9 +9,7 @@
9
9
 
10
10
 
11
11
  import datetime
12
- from typing import Iterator
13
- from typing import List
14
- from typing import Tuple
12
+ from collections.abc import Iterator
15
13
 
16
14
 
17
15
  class HindcastDatesTimes:
@@ -25,7 +23,7 @@ class HindcastDatesTimes:
25
23
  Number of years to go back from each reference date.
26
24
  """
27
25
 
28
- def __init__(self, reference_dates: List[datetime.datetime], years: int = 20):
26
+ def __init__(self, reference_dates: list[datetime.datetime], years: int = 20):
29
27
  """Initialize the HindcastDatesTimes iterator.
30
28
 
31
29
  Parameters
@@ -41,7 +39,7 @@ class HindcastDatesTimes:
41
39
  assert years > 0, f"years must be greater than 0, got {years}"
42
40
  self.years = years
43
41
 
44
- def __iter__(self) -> Iterator[Tuple[datetime.datetime, datetime.datetime]]:
42
+ def __iter__(self) -> Iterator[tuple[datetime.datetime, datetime.datetime]]:
45
43
  """Generate tuples of past dates and their corresponding reference dates.
46
44
 
47
45
  Yields