anemoi-utils 0.4.11__py3-none-any.whl → 0.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic; see the registry's advisory page for this release for more details.

@@ -12,10 +12,12 @@ from __future__ import annotations
12
12
  import functools
13
13
  from typing import Any
14
14
  from typing import Callable
15
+ from typing import Optional
16
+ from typing import Union
15
17
 
16
18
 
17
19
  def aliases(
18
- aliases: dict[str, str | list[str]] | None = None, **kwargs: str | list[str]
20
+ aliases: Optional[dict[str, Union[str, list[str]]]] = None, **kwargs: Any
19
21
  ) -> Callable[[Callable], Callable]:
20
22
  """Alias keyword arguments in a function call.
21
23
 
@@ -23,10 +25,10 @@ def aliases(
23
25
 
24
26
  Parameters
25
27
  ----------
26
- aliases : dict[str, str | list[str]] | None, optional
28
+ aliases : dict[str, Union[str, list[str]]], optional
27
29
  Key, value pair of aliases, with keys being the true name, and value being a str or list of aliases,
28
30
  by default None
29
- **kwargs : str | list[str]
31
+ **kwargs : Any
30
32
  Kwargs form of aliases
31
33
 
32
34
  Returns
@@ -49,7 +51,6 @@ def aliases(
49
51
  func(a=1, c=2) # (1, 2)
50
52
  func(b=1, d=2) # (1, 2)
51
53
  ```
52
-
53
54
  """
54
55
 
55
56
  if aliases is None:
@@ -60,7 +61,7 @@ def aliases(
60
61
 
61
62
  def decorator(func: Callable) -> Callable:
62
63
  @functools.wraps(func)
63
- def wrapper(*args, **kwargs) -> Any:
64
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
64
65
  keys = kwargs.keys()
65
66
  for k in set(keys).intersection(set(aliases.keys())):
66
67
  if aliases[k] in keys:
anemoi/utils/config.py CHANGED
@@ -14,6 +14,9 @@ import json
14
14
  import logging
15
15
  import os
16
16
  import threading
17
+ from typing import Any
18
+ from typing import Optional
19
+ from typing import Union
17
20
 
18
21
  import yaml
19
22
 
@@ -43,10 +46,18 @@ class DotDict(dict):
43
46
  The DotDict class has the same constructor as the dict class.
44
47
 
45
48
  >>> d = DotDict(a=1, b=2)
46
-
47
49
  """
48
50
 
49
51
  def __init__(self, *args, **kwargs):
52
+ """Initialize a DotDict instance.
53
+
54
+ Parameters
55
+ ----------
56
+ *args : tuple
57
+ Positional arguments for the dict constructor.
58
+ **kwargs : dict
59
+ Keyword arguments for the dict constructor.
60
+ """
50
61
  super().__init__(*args, **kwargs)
51
62
 
52
63
  for k, v in self.items():
@@ -60,7 +71,19 @@ class DotDict(dict):
60
71
  self[k] = [DotDict(i) if isinstance(i, dict) or is_omegaconf_dict(i) else i for i in v]
61
72
 
62
73
  @classmethod
63
- def from_file(cls, path: str):
74
+ def from_file(cls, path: str) -> DotDict:
75
+ """Create a DotDict from a file.
76
+
77
+ Parameters
78
+ ----------
79
+ path : str
80
+ The path to the file.
81
+
82
+ Returns
83
+ -------
84
+ DotDict
85
+ The created DotDict.
86
+ """
64
87
  _, ext = os.path.splitext(path)
65
88
  if ext == ".yaml" or ext == ".yml":
66
89
  return cls.from_yaml_file(path)
@@ -72,41 +95,117 @@ class DotDict(dict):
72
95
  raise ValueError(f"Unknown file extension {ext}")
73
96
 
74
97
  @classmethod
75
- def from_yaml_file(cls, path: str):
98
+ def from_yaml_file(cls, path: str) -> DotDict:
99
+ """Create a DotDict from a YAML file.
100
+
101
+ Parameters
102
+ ----------
103
+ path : str
104
+ The path to the YAML file.
105
+
106
+ Returns
107
+ -------
108
+ DotDict
109
+ The created DotDict.
110
+ """
76
111
  with open(path, "r") as file:
77
112
  data = yaml.safe_load(file)
78
113
 
79
114
  return cls(data)
80
115
 
81
116
  @classmethod
82
- def from_json_file(cls, path: str):
117
+ def from_json_file(cls, path: str) -> DotDict:
118
+ """Create a DotDict from a JSON file.
119
+
120
+ Parameters
121
+ ----------
122
+ path : str
123
+ The path to the JSON file.
124
+
125
+ Returns
126
+ -------
127
+ DotDict
128
+ The created DotDict.
129
+ """
83
130
  with open(path, "r") as file:
84
131
  data = json.load(file)
85
132
 
86
133
  return cls(data)
87
134
 
88
135
  @classmethod
89
- def from_toml_file(cls, path: str):
136
+ def from_toml_file(cls, path: str) -> DotDict:
137
+ """Create a DotDict from a TOML file.
138
+
139
+ Parameters
140
+ ----------
141
+ path : str
142
+ The path to the TOML file.
143
+
144
+ Returns
145
+ -------
146
+ DotDict
147
+ The created DotDict.
148
+ """
90
149
  with open(path, "r") as file:
91
150
  data = tomllib.load(file)
92
151
  return cls(data)
93
152
 
94
- def __getattr__(self, attr):
153
+ def __getattr__(self, attr: str) -> Any:
154
+ """Get an attribute.
155
+
156
+ Parameters
157
+ ----------
158
+ attr : str
159
+ The attribute name.
160
+
161
+ Returns
162
+ -------
163
+ Any
164
+ The attribute value.
165
+ """
95
166
  try:
96
167
  return self[attr]
97
168
  except KeyError:
98
169
  raise AttributeError(attr)
99
170
 
100
- def __setattr__(self, attr, value):
171
+ def __setattr__(self, attr: str, value: Any) -> None:
172
+ """Set an attribute.
173
+
174
+ Parameters
175
+ ----------
176
+ attr : str
177
+ The attribute name.
178
+ value : Any
179
+ The attribute value.
180
+ """
101
181
  if isinstance(value, dict):
102
182
  value = DotDict(value)
103
183
  self[attr] = value
104
184
 
105
185
  def __repr__(self) -> str:
186
+ """Return a string representation of the DotDict.
187
+
188
+ Returns
189
+ -------
190
+ str
191
+ The string representation.
192
+ """
106
193
  return f"DotDict({super().__repr__()})"
107
194
 
108
195
 
109
- def is_omegaconf_dict(value) -> bool:
196
+ def is_omegaconf_dict(value: Any) -> bool:
197
+ """Check if a value is an OmegaConf DictConfig.
198
+
199
+ Parameters
200
+ ----------
201
+ value : Any
202
+ The value to check.
203
+
204
+ Returns
205
+ -------
206
+ bool
207
+ True if the value is a DictConfig, False otherwise.
208
+ """
110
209
  try:
111
210
  from omegaconf import DictConfig
112
211
 
@@ -115,7 +214,19 @@ def is_omegaconf_dict(value) -> bool:
115
214
  return False
116
215
 
117
216
 
118
- def is_omegaconf_list(value) -> bool:
217
+ def is_omegaconf_list(value: Any) -> bool:
218
+ """Check if a value is an OmegaConf ListConfig.
219
+
220
+ Parameters
221
+ ----------
222
+ value : Any
223
+ The value to check.
224
+
225
+ Returns
226
+ -------
227
+ bool
228
+ True if the value is a ListConfig, False otherwise.
229
+ """
119
230
  try:
120
231
  from omegaconf import ListConfig
121
232
 
@@ -130,7 +241,23 @@ CONFIG_LOCK = threading.RLock()
130
241
  QUIET = False
131
242
 
132
243
 
133
- def _find(config, what, result=None):
244
+ def _find(config: Union[dict, list], what: str, result: list = None) -> list:
245
+ """Find all occurrences of a key in a nested dictionary or list.
246
+
247
+ Parameters
248
+ ----------
249
+ config : dict or list
250
+ The configuration to search.
251
+ what : str
252
+ The key to search for.
253
+ result : list, optional
254
+ The list to store results, by default None.
255
+
256
+ Returns
257
+ -------
258
+ list
259
+ The list of found values.
260
+ """
134
261
  if result is None:
135
262
  result = []
136
263
 
@@ -149,7 +276,16 @@ def _find(config, what, result=None):
149
276
  return result
150
277
 
151
278
 
152
- def _merge_dicts(a, b):
279
+ def _merge_dicts(a: dict, b: dict) -> None:
280
+ """Merge two dictionaries recursively.
281
+
282
+ Parameters
283
+ ----------
284
+ a : dict
285
+ The first dictionary.
286
+ b : dict
287
+ The second dictionary.
288
+ """
153
289
  for k, v in b.items():
154
290
  if k in a and isinstance(a[k], dict) and isinstance(v, dict):
155
291
  _merge_dicts(a[k], v)
@@ -157,7 +293,16 @@ def _merge_dicts(a, b):
157
293
  a[k] = v
158
294
 
159
295
 
160
- def _set_defaults(a, b):
296
+ def _set_defaults(a: dict, b: dict) -> None:
297
+ """Set default values in a dictionary.
298
+
299
+ Parameters
300
+ ----------
301
+ a : dict
302
+ The dictionary to set defaults in.
303
+ b : dict
304
+ The dictionary with default values.
305
+ """
161
306
  for k, v in b.items():
162
307
  if k in a and isinstance(a[k], dict) and isinstance(v, dict):
163
308
  _set_defaults(a[k], v)
@@ -165,7 +310,19 @@ def _set_defaults(a, b):
165
310
  a.setdefault(k, v)
166
311
 
167
312
 
168
- def config_path(name="settings.toml"):
313
+ def config_path(name: str = "settings.toml") -> str:
314
+ """Get the path to a configuration file.
315
+
316
+ Parameters
317
+ ----------
318
+ name : str, optional
319
+ The name of the configuration file, by default "settings.toml".
320
+
321
+ Returns
322
+ -------
323
+ str
324
+ The path to the configuration file.
325
+ """
169
326
  global QUIET
170
327
 
171
328
  if name.startswith("/") or name.startswith("."):
@@ -197,7 +354,7 @@ def config_path(name="settings.toml"):
197
354
  return full
198
355
 
199
356
 
200
- def load_any_dict_format(path) -> dict:
357
+ def load_any_dict_format(path: str) -> dict:
201
358
  """Load a configuration file in any supported format: JSON, YAML and TOML.
202
359
 
203
360
  Parameters
@@ -247,8 +404,27 @@ def load_any_dict_format(path) -> dict:
247
404
  return open(path).read()
248
405
 
249
406
 
250
- def _load_config(name="settings.toml", secrets=None, defaults=None):
407
+ def _load_config(
408
+ name: str = "settings.toml",
409
+ secrets: Optional[Union[str, list[str]]] = None,
410
+ defaults: Optional[Union[str, dict]] = None,
411
+ ) -> DotDict:
412
+ """Load a configuration file.
251
413
 
414
+ Parameters
415
+ ----------
416
+ name : str, optional
417
+ The name of the configuration file, by default "settings.toml".
418
+ secrets : str or list, optional
419
+ The name of the secrets file, by default None.
420
+ defaults : str or dict, optional
421
+ The name of the defaults file, by default None.
422
+
423
+ Returns
424
+ -------
425
+ DotDict
426
+ The loaded configuration.
427
+ """
252
428
  key = json.dumps((name, secrets, defaults), sort_keys=True, default=str)
253
429
  if key in CONFIG:
254
430
  return CONFIG[key]
@@ -287,7 +463,16 @@ def _load_config(name="settings.toml", secrets=None, defaults=None):
287
463
  return CONFIG[key]
288
464
 
289
465
 
290
- def _save_config(name, data) -> None:
466
+ def _save_config(name: str, data: Any) -> None:
467
+ """Save a configuration file.
468
+
469
+ Parameters
470
+ ----------
471
+ name : str
472
+ The name of the configuration file.
473
+ data : Any
474
+ The data to save.
475
+ """
291
476
  CONFIG.pop(name, None)
292
477
 
293
478
  conf = config_path(name)
@@ -309,7 +494,7 @@ def _save_config(name, data) -> None:
309
494
  f.write(data)
310
495
 
311
496
 
312
- def save_config(name, data) -> None:
497
+ def save_config(name: str, data: Any) -> None:
313
498
  """Save a configuration file.
314
499
 
315
500
  Parameters
@@ -319,13 +504,16 @@ def save_config(name, data) -> None:
319
504
 
320
505
  data : Any
321
506
  The data to save.
322
-
323
507
  """
324
508
  with CONFIG_LOCK:
325
509
  _save_config(name, data)
326
510
 
327
511
 
328
- def load_config(name="settings.toml", secrets=None, defaults=None) -> DotDict | str:
512
+ def load_config(
513
+ name: str = "settings.toml",
514
+ secrets: Optional[Union[str, list[str]]] = None,
515
+ defaults: Optional[Union[str, dict]] = None,
516
+ ) -> DotDict | str:
329
517
  """Read a configuration file.
330
518
 
331
519
  Parameters
@@ -347,8 +535,21 @@ def load_config(name="settings.toml", secrets=None, defaults=None) -> DotDict |
347
535
  return _load_config(name, secrets, defaults)
348
536
 
349
537
 
350
- def load_raw_config(name, default=None) -> DotDict | str:
538
+ def load_raw_config(name: str, default: Any = None) -> Union[DotDict, str]:
539
+ """Load a raw configuration file.
540
+
541
+ Parameters
542
+ ----------
543
+ name : str
544
+ The name of the configuration file.
545
+ default : Any, optional
546
+ The default value if the file does not exist, by default None.
351
547
 
548
+ Returns
549
+ -------
550
+ DotDict or str
551
+ The loaded configuration or the default value.
552
+ """
352
553
  path = config_path(name)
353
554
  if os.path.exists(path):
354
555
  return load_any_dict_format(path)
@@ -356,7 +557,7 @@ def load_raw_config(name, default=None) -> DotDict | str:
356
557
  return default
357
558
 
358
559
 
359
- def check_config_mode(name="settings.toml", secrets_name=None, secrets=None) -> None:
560
+ def check_config_mode(name: str = "settings.toml", secrets_name: str = None, secrets: list[str] = None) -> None:
360
561
  """Check that a configuration file is secure.
361
562
 
362
563
  Parameters
@@ -393,7 +594,25 @@ def check_config_mode(name="settings.toml", secrets_name=None, secrets=None) ->
393
594
  CHECKED[name] = True
394
595
 
395
596
 
396
- def find(metadata, what, result=None, *, select: callable = None):
597
+ def find(metadata: Union[dict, list], what: str, result: list = None, *, select: callable = None) -> list:
598
+ """Find all occurrences of a key in a nested dictionary or list with an optional selector.
599
+
600
+ Parameters
601
+ ----------
602
+ metadata : dict or list
603
+ The metadata to search.
604
+ what : str
605
+ The key to search for.
606
+ result : list, optional
607
+ The list to store results, by default None.
608
+ select : callable, optional
609
+ A function to filter the results, by default None.
610
+
611
+ Returns
612
+ -------
613
+ list
614
+ The list of found values.
615
+ """
397
616
  if result is None:
398
617
  result = []
399
618
 
@@ -413,7 +632,19 @@ def find(metadata, what, result=None, *, select: callable = None):
413
632
  return result
414
633
 
415
634
 
416
- def merge_configs(*configs):
635
+ def merge_configs(*configs: dict) -> dict:
636
+ """Merge multiple configuration dictionaries.
637
+
638
+ Parameters
639
+ ----------
640
+ *configs : dict
641
+ The configuration dictionaries to merge.
642
+
643
+ Returns
644
+ -------
645
+ dict
646
+ The merged configuration dictionary.
647
+ """
417
648
  result = {}
418
649
  for config in configs:
419
650
  _merge_dicts(result, config)