anemoi-utils 0.4.12__py3-none-any.whl → 0.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of anemoi-utils might be problematic; see the registry's page for this release for more details.

anemoi/utils/logs.py CHANGED
@@ -19,17 +19,45 @@ thread_local = threading.local()
19
19
  LOGGER = logging.getLogger(__name__)
20
20
 
21
21
 
22
- def set_logging_name(name):
22
+ def set_logging_name(name: str) -> None:
23
+ """Set the logging name for the current thread.
24
+
25
+ Parameters
26
+ ----------
27
+ name : str
28
+ The name to set for logging.
29
+ """
23
30
  thread_local.logging_name = name
24
31
 
25
32
 
26
33
  class ThreadCustomFormatter(logging.Formatter):
27
- def format(self, record):
34
+ """Custom logging formatter that includes thread-specific logging names."""
35
+
36
+ def format(self, record: logging.LogRecord) -> str:
37
+ """Format the log record to include the thread-specific logging name.
38
+
39
+ Parameters
40
+ ----------
41
+ record : logging.LogRecord
42
+ The log record to format.
43
+
44
+ Returns
45
+ -------
46
+ str
47
+ The formatted log record.
48
+ """
28
49
  record.logging_name = thread_local.logging_name
29
50
  return super().format(record)
30
51
 
31
52
 
32
- def enable_logging_name(name="main"):
53
+ def enable_logging_name(name: str = "main") -> None:
54
+ """Enable logging with a thread-specific logging name.
55
+
56
+ Parameters
57
+ ----------
58
+ name : str, optional
59
+ The default logging name to set, by default "main".
60
+ """
33
61
  thread_local.logging_name = name
34
62
 
35
63
  formatter = ThreadCustomFormatter("%(asctime)s - %(logging_name)s - %(levelname)s - %(message)s")
@@ -10,13 +10,16 @@
10
10
 
11
11
  """Utilities for working with Mars requests.
12
12
 
13
- Has some konwledge of how certain streams are organised in Mars.
14
-
13
+ Has some knowledge of how certain streams are organised in Mars.
15
14
  """
16
15
 
17
16
  import datetime
18
17
  import logging
19
18
  import os
19
+ from typing import Any
20
+ from typing import Dict
21
+ from typing import Optional
22
+ from typing import Tuple
20
23
 
21
24
  import yaml
22
25
 
@@ -30,16 +33,18 @@ DEFAULT_MARS_LABELLING = {
30
33
  }
31
34
 
32
35
 
33
- def _expand_mars_labelling(request):
36
+ def _expand_mars_labelling(request: Dict[str, Any]) -> Dict[str, Any]:
34
37
  """Expand the request with the default Mars labelling.
35
38
 
36
- The default Mars labelling is:
37
-
38
- {'class': 'od',
39
- 'type': 'an',
40
- 'stream': 'oper',
41
- 'expver': '0001'}
39
+ Parameters
40
+ ----------
41
+ request : dict
42
+ The original Mars request.
42
43
 
44
+ Returns
45
+ -------
46
+ dict
47
+ The Mars request expanded with default labelling.
43
48
  """
44
49
  result = DEFAULT_MARS_LABELLING.copy()
45
50
  result.update(request)
@@ -49,7 +54,19 @@ def _expand_mars_labelling(request):
49
54
  STREAMS = None
50
55
 
51
56
 
52
- def _lookup_mars_stream(request):
57
+ def _lookup_mars_stream(request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
58
+ """Look up the Mars stream information for a given request.
59
+
60
+ Parameters
61
+ ----------
62
+ request : dict
63
+ The Mars request.
64
+
65
+ Returns
66
+ -------
67
+ dict or None
68
+ The stream information if a match is found, otherwise None.
69
+ """
53
70
  global STREAMS
54
71
 
55
72
  if STREAMS is None:
@@ -64,8 +81,25 @@ def _lookup_mars_stream(request):
64
81
  return s["info"]
65
82
 
66
83
 
67
- def recenter(date, center, members):
68
-
84
+ def recenter(
85
+ date: datetime.datetime, center: Dict[str, Any], members: Dict[str, Any]
86
+ ) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
87
+ """Recenter the given date with the specified center and members.
88
+
89
+ Parameters
90
+ ----------
91
+ date : datetime.datetime
92
+ The date to recenter.
93
+ center : dict
94
+ The center request information.
95
+ members : dict
96
+ The members request information.
97
+
98
+ Returns
99
+ -------
100
+ tuple
101
+ A tuple containing the recentered center and members information.
102
+ """
69
103
  center = _lookup_mars_stream(center)
70
104
  members = _lookup_mars_stream(members)
71
105
 
@@ -6,9 +6,23 @@
6
6
  # nor does it submit to any jurisdiction.
7
7
 
8
8
  import sys
9
+ from typing import Any
10
+ from typing import Dict
11
+ from typing import TextIO
9
12
 
10
13
 
11
- def print_request(verb, request, file=sys.stdout):
14
+ def print_request(verb: str, request: Dict[str, Any], file: TextIO = sys.stdout) -> None:
15
+ """Prints a formatted request.
16
+
17
+ Parameters
18
+ ----------
19
+ verb : str
20
+ A mars verb
21
+ request : Dict[str, Any]
22
+ The request parameters.
23
+ file : TextIO, optional
24
+ The file to which the request is printed, by default sys.stdout.
25
+ """
12
26
  r = [verb]
13
27
  for k, v in request.items():
14
28
  if not isinstance(v, (list, tuple, set)):
@@ -14,7 +14,6 @@
14
14
  - The versions of the modules which are currently loaded
15
15
  - The git information for the modules which are currently loaded from a git repository
16
16
  - ...
17
-
18
17
  """
19
18
 
20
19
  import datetime
@@ -25,11 +24,29 @@ import subprocess
25
24
  import sys
26
25
  import sysconfig
27
26
  from functools import cache
27
+ from typing import Any
28
+ from typing import Dict
29
+ from typing import List
30
+ from typing import Optional
31
+ from typing import Tuple
32
+ from typing import Union
28
33
 
29
34
  LOG = logging.getLogger(__name__)
30
35
 
31
36
 
32
- def lookup_git_repo(path):
37
+ def lookup_git_repo(path: str) -> Optional[Any]:
38
+ """Lookup the git repository for a given path.
39
+
40
+ Parameters
41
+ ----------
42
+ path : str
43
+ The path to lookup.
44
+
45
+ Returns
46
+ -------
47
+ Repo, optional
48
+ The git repository if found, otherwise None.
49
+ """
33
50
  from git import InvalidGitRepositoryError
34
51
  from git import Repo
35
52
 
@@ -42,7 +59,21 @@ def lookup_git_repo(path):
42
59
  return None
43
60
 
44
61
 
45
- def _check_for_git(paths, full):
62
+ def _check_for_git(paths: List[Tuple[str, str]], full: bool) -> Dict[str, Any]:
63
+ """Check for git information for the given paths.
64
+
65
+ Parameters
66
+ ----------
67
+ paths : list of tuple
68
+ The list of paths to check.
69
+ full : bool
70
+ Whether to collect full information.
71
+
72
+ Returns
73
+ -------
74
+ dict
75
+ The git information for the given paths.
76
+ """
46
77
  versions = {}
47
78
  for name, path in paths:
48
79
  repo = lookup_git_repo(path)
@@ -77,7 +108,28 @@ def _check_for_git(paths, full):
77
108
  return versions
78
109
 
79
110
 
80
- def version(versions, name, module, roots, namespaces, paths, full):
111
+ def version(
112
+ versions: Dict[str, Any], name: str, module: Any, roots: Dict[str, str], namespaces: set, paths: set, full: bool
113
+ ) -> None:
114
+ """Collect version information for a module.
115
+
116
+ Parameters
117
+ ----------
118
+ versions : dict
119
+ The dictionary to store the version information.
120
+ name : str
121
+ The name of the module.
122
+ module : Any
123
+ The module to collect information for.
124
+ roots : dict
125
+ The dictionary of root paths.
126
+ namespaces : set
127
+ The set of namespaces.
128
+ paths : set
129
+ The set of paths.
130
+ full : bool
131
+ Whether to collect full information.
132
+ """
81
133
  path = None
82
134
 
83
135
  if hasattr(module, "__file__"):
@@ -119,7 +171,19 @@ def version(versions, name, module, roots, namespaces, paths, full):
119
171
  versions[name] = str(module)
120
172
 
121
173
 
122
- def _module_versions(full):
174
+ def _module_versions(full: bool) -> Tuple[Dict[str, Any], set]:
175
+ """Collect version information for all loaded modules.
176
+
177
+ Parameters
178
+ ----------
179
+ full : bool
180
+ Whether to collect full information.
181
+
182
+ Returns
183
+ -------
184
+ tuple of dict and set
185
+ The version information and the set of paths.
186
+ """
123
187
  # https://docs.python.org/3/library/sysconfig.html
124
188
 
125
189
  roots = {}
@@ -149,7 +213,14 @@ def _module_versions(full):
149
213
 
150
214
 
151
215
  @cache
152
- def package_distributions() -> dict[str, list[str]]:
216
+ def package_distributions() -> Dict[str, List[str]]:
217
+ """Get the package distributions.
218
+
219
+ Returns
220
+ -------
221
+ dict
222
+ The package distributions.
223
+ """
153
224
  # Takes a significant amount of time to run
154
225
  # so cache the result
155
226
  from importlib import metadata
@@ -161,7 +232,20 @@ def package_distributions() -> dict[str, list[str]]:
161
232
  return metadata.packages_distributions()
162
233
 
163
234
 
164
- def import_name_to_distribution_name(packages: list):
235
+ def import_name_to_distribution_name(packages: List[str]) -> Dict[str, str]:
236
+ """Convert import names to distribution names.
237
+
238
+ Parameters
239
+ ----------
240
+ packages : list of str
241
+ The list of import names.
242
+
243
+ Returns
244
+ -------
245
+ dict
246
+ The dictionary mapping import names to distribution names.
247
+ """
248
+
165
249
  distribution_names = {}
166
250
  package_distribution_names = package_distributions()
167
251
 
@@ -179,13 +263,37 @@ def import_name_to_distribution_name(packages: list):
179
263
  return distribution_names
180
264
 
181
265
 
182
- def module_versions(full):
266
+ def module_versions(full: bool) -> Tuple[Dict[str, Any], Dict[str, Any]]:
267
+ """Collect version information for all loaded modules and their git information.
268
+
269
+ Parameters
270
+ ----------
271
+ full : bool
272
+ Whether to collect full information.
273
+
274
+ Returns
275
+ -------
276
+ tuple of dict and dict
277
+ The version information and the git information.
278
+ """
183
279
  versions, paths = _module_versions(full)
184
280
  git_versions = _check_for_git(paths, full)
185
281
  return versions, git_versions
186
282
 
187
283
 
188
- def _name(obj):
284
+ def _name(obj: Any) -> str:
285
+ """Get the name of an object.
286
+
287
+ Parameters
288
+ ----------
289
+ obj : Any
290
+ The object to get the name of.
291
+
292
+ Returns
293
+ -------
294
+ str
295
+ The name of the object.
296
+ """
189
297
  if hasattr(obj, "__name__"):
190
298
  if hasattr(obj, "__module__"):
191
299
  return f"{obj.__module__}.{obj.__name__}"
@@ -195,8 +303,19 @@ def _name(obj):
195
303
  return str(obj)
196
304
 
197
305
 
198
- def _paths(path_or_object):
306
+ def _paths(path_or_object: Union[None, str, List[str], Tuple[str], Any]) -> List[Tuple[str, str]]:
307
+ """Get the paths for a given path or object.
199
308
 
309
+ Parameters
310
+ ----------
311
+ path_or_object : str, list, tuple, or object
312
+ The path or object to get the paths for.
313
+
314
+ Returns
315
+ -------
316
+ list of tuple
317
+ The list of paths.
318
+ """
200
319
  if path_or_object is None:
201
320
  _, paths = _module_versions(full=False)
202
321
  return paths
@@ -235,7 +354,7 @@ def _paths(path_or_object):
235
354
  return paths
236
355
 
237
356
 
238
- def git_check(*args) -> dict:
357
+ def git_check(*args: Any) -> Dict[str, Any]:
239
358
  """Return the git information for the given arguments.
240
359
 
241
360
  Arguments can be:
@@ -255,18 +374,18 @@ def git_check(*args) -> dict:
255
374
  dict
256
375
  An object with the git information for the given arguments.
257
376
 
258
- >>> {
259
- "anemoi.utils": {
260
- "sha1": "c999d83ae283bcbb99f68d92c42d24315922129f",
261
- "remotes": [
262
- "git@github.com:ecmwf/anemoi-utils.git"
263
- ],
264
- "modified_files": [
265
- "anemoi/utils/checkpoints.py"
266
- ],
267
- "untracked_files": []
377
+ >>> {
378
+ "anemoi.utils": {
379
+ "sha1": "c999d83ae283bcbb99f68d92c42d24315922129f",
380
+ "remotes": [
381
+ "git@github.com:ecmwf/anemoi-utils.git"
382
+ ],
383
+ "modified_files": [
384
+ "anemoi/utils/checkpoints.py"
385
+ ],
386
+ "untracked_files": []
387
+ }
268
388
  }
269
- }
270
389
  """
271
390
  paths = _paths(args if len(args) > 0 else None)
272
391
 
@@ -278,7 +397,14 @@ def git_check(*args) -> dict:
278
397
  return result
279
398
 
280
399
 
281
- def platform_info():
400
+ def platform_info() -> Dict[str, Any]:
401
+ """Get the platform information.
402
+
403
+ Returns
404
+ -------
405
+ dict
406
+ The platform information.
407
+ """
282
408
  import platform
283
409
 
284
410
  r = {}
@@ -300,7 +426,14 @@ def platform_info():
300
426
  return r
301
427
 
302
428
 
303
- def gpu_info():
429
+ def gpu_info() -> Union[str, List[Dict[str, Any]]]:
430
+ """Get the GPU information.
431
+
432
+ Returns
433
+ -------
434
+ str or list of dict
435
+ The GPU information or an error message.
436
+ """
304
437
  import nvsmi
305
438
 
306
439
  if not nvsmi.is_nvidia_smi_on_path():
@@ -312,7 +445,19 @@ def gpu_info():
312
445
  return e.output.decode("utf-8").strip()
313
446
 
314
447
 
315
- def path_md5(path):
448
+ def path_md5(path: str) -> str:
449
+ """Calculate the MD5 hash of a file.
450
+
451
+ Parameters
452
+ ----------
453
+ path : str
454
+ The path to the file.
455
+
456
+ Returns
457
+ -------
458
+ str
459
+ The MD5 hash of the file.
460
+ """
316
461
  import hashlib
317
462
 
318
463
  hash = hashlib.md5()
@@ -322,7 +467,19 @@ def path_md5(path):
322
467
  return hash.hexdigest()
323
468
 
324
469
 
325
- def assets_info(paths):
470
+ def assets_info(paths: List[str]) -> Dict[str, Any]:
471
+ """Get information about the given assets.
472
+
473
+ Parameters
474
+ ----------
475
+ paths : list of str
476
+ The list of paths to the assets.
477
+
478
+ Returns
479
+ -------
480
+ dict
481
+ The information about the assets.
482
+ """
326
483
  result = {}
327
484
 
328
485
  for path in paths:
@@ -351,8 +508,8 @@ def assets_info(paths):
351
508
  return result
352
509
 
353
510
 
354
- def gather_provenance_info(assets=[], full=False) -> dict:
355
- """Gather information about the current environment
511
+ def gather_provenance_info(assets: List[str] = [], full: bool = False) -> Dict[str, Any]:
512
+ """Gather information about the current environment.
356
513
 
357
514
  Parameters
358
515
  ----------
anemoi/utils/registry.py CHANGED
@@ -12,6 +12,11 @@ import importlib
12
12
  import logging
13
13
  import os
14
14
  import sys
15
+ from typing import Any
16
+ from typing import Callable
17
+ from typing import Dict
18
+ from typing import Optional
19
+ from typing import Union
15
20
 
16
21
  import entrypoints
17
22
 
@@ -19,13 +24,33 @@ LOG = logging.getLogger(__name__)
19
24
 
20
25
 
21
26
  class Wrapper:
22
- """A wrapper for the registry"""
27
+ """A wrapper for the registry.
23
28
 
24
- def __init__(self, name, registry):
29
+ Parameters
30
+ ----------
31
+ name : str
32
+ The name of the wrapper.
33
+ registry : Registry
34
+ The registry to wrap.
35
+ """
36
+
37
+ def __init__(self, name: str, registry: "Registry"):
25
38
  self.name = name
26
39
  self.registry = registry
27
40
 
28
- def __call__(self, factory):
41
+ def __call__(self, factory: Callable) -> Callable:
42
+ """Register a factory with the registry.
43
+
44
+ Parameters
45
+ ----------
46
+ factory : Callable
47
+ The factory to register.
48
+
49
+ Returns
50
+ -------
51
+ Callable
52
+ The registered factory.
53
+ """
29
54
  self.registry.register(self.name, factory)
30
55
  return factory
31
56
 
@@ -34,10 +59,17 @@ _BY_KIND = {}
34
59
 
35
60
 
36
61
  class Registry:
37
- """A registry of factories"""
62
+ """A registry of factories.
38
63
 
39
- def __init__(self, package, key="_type"):
64
+ Parameters
65
+ ----------
66
+ package : str
67
+ The package name.
68
+ key : str, optional
69
+ The key to use for the registry, by default "_type".
70
+ """
40
71
 
72
+ def __init__(self, package: str, key: str = "_type"):
41
73
  self.package = package
42
74
  self.registered = {}
43
75
  self.kind = package.split(".")[-1]
@@ -45,11 +77,36 @@ class Registry:
45
77
  _BY_KIND[self.kind] = self
46
78
 
47
79
  @classmethod
48
- def lookup_kind(cls, kind: str):
80
+ def lookup_kind(cls, kind: str) -> Optional["Registry"]:
81
+ """Lookup a registry by kind.
82
+
83
+ Parameters
84
+ ----------
85
+ kind : str
86
+ The kind of the registry.
87
+
88
+ Returns
89
+ -------
90
+ Registry, optional
91
+ The registry if found, otherwise None.
92
+ """
49
93
  return _BY_KIND.get(kind)
50
94
 
51
- def register(self, name: str, factory: callable = None):
52
-
95
+ def register(self, name: str, factory: Optional[Callable] = None) -> Optional[Wrapper]:
96
+ """Register a factory with the registry.
97
+
98
+ Parameters
99
+ ----------
100
+ name : str
101
+ The name of the factory.
102
+ factory : Callable, optional
103
+ The factory to register, by default None.
104
+
105
+ Returns
106
+ -------
107
+ Wrapper, optional
108
+ A wrapper if the factory is None, otherwise None.
109
+ """
53
110
  if factory is None:
54
111
  return Wrapper(name, self)
55
112
 
@@ -58,15 +115,35 @@ class Registry:
58
115
  # def registered(self, name: str):
59
116
  # return name in self.registered
60
117
 
61
- def _load(self, file):
118
+ def _load(self, file: str) -> None:
119
+ """Load a module from a file.
120
+
121
+ Parameters
122
+ ----------
123
+ file : str
124
+ The file to load.
125
+ """
62
126
  name, _ = os.path.splitext(file)
63
127
  try:
64
128
  importlib.import_module(f".{name}", package=self.package)
65
129
  except Exception:
66
130
  LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True)
67
131
 
68
- def lookup(self, name: str, *, return_none=False) -> callable:
69
-
132
+ def lookup(self, name: str, *, return_none: bool = False) -> Optional[Callable]:
133
+ """Lookup a factory by name.
134
+
135
+ Parameters
136
+ ----------
137
+ name : str
138
+ The name of the factory.
139
+ return_none : bool, optional
140
+ Whether to return None if the factory is not found, by default False.
141
+
142
+ Returns
143
+ -------
144
+ Callable, optional
145
+ The factory if found, otherwise None.
146
+ """
70
147
  # print('✅✅✅✅✅✅✅✅✅✅✅✅✅', name, self.registered)
71
148
  if name in self.registered:
72
149
  return self.registered[name]
@@ -110,14 +187,46 @@ class Registry:
110
187
 
111
188
  return self.registered[name]
112
189
 
113
- def create(self, name: str, *args, **kwargs):
190
+ def create(self, name: str, *args: Any, **kwargs: Any) -> Any:
191
+ """Create an instance using a factory.
192
+
193
+ Parameters
194
+ ----------
195
+ name : str
196
+ The name of the factory.
197
+ *args : Any
198
+ Positional arguments for the factory.
199
+ **kwargs : Any
200
+ Keyword arguments for the factory.
201
+
202
+ Returns
203
+ -------
204
+ Any
205
+ The created instance.
206
+ """
114
207
  factory = self.lookup(name)
115
208
  return factory(*args, **kwargs)
116
209
 
117
210
  # def __call__(self, name: str, *args, **kwargs):
118
211
  # return self.create(name, *args, **kwargs)
119
212
 
120
- def from_config(self, config, *args, **kwargs):
213
+ def from_config(self, config: Union[str, Dict[str, Any]], *args: Any, **kwargs: Any) -> Any:
214
+ """Create an instance from a configuration.
215
+
216
+ Parameters
217
+ ----------
218
+ config : str or dict
219
+ The configuration.
220
+ *args : Any
221
+ Positional arguments for the factory.
222
+ **kwargs : Any
223
+ Keyword arguments for the factory.
224
+
225
+ Returns
226
+ -------
227
+ Any
228
+ The created instance.
229
+ """
121
230
  if isinstance(config, str):
122
231
  config = {config: {}}
123
232