anemoi-utils 0.4.11__py3-none-any.whl → 0.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of anemoi-utils might be problematic. Click here for more details.
- anemoi/utils/__init__.py +1 -0
- anemoi/utils/__main__.py +12 -2
- anemoi/utils/_version.py +9 -4
- anemoi/utils/caching.py +138 -13
- anemoi/utils/checkpoints.py +81 -13
- anemoi/utils/cli.py +83 -7
- anemoi/utils/commands/__init__.py +4 -0
- anemoi/utils/commands/config.py +19 -2
- anemoi/utils/commands/requests.py +24 -4
- anemoi/utils/compatibility.py +6 -5
- anemoi/utils/config.py +254 -23
- anemoi/utils/dates.py +216 -55
- anemoi/utils/devtools.py +68 -7
- anemoi/utils/grib.py +30 -9
- anemoi/utils/grids.py +85 -8
- anemoi/utils/hindcasts.py +25 -8
- anemoi/utils/humanize.py +357 -52
- anemoi/utils/logs.py +31 -3
- anemoi/utils/mars/__init__.py +46 -12
- anemoi/utils/mars/requests.py +15 -1
- anemoi/utils/provenance.py +185 -28
- anemoi/utils/registry.py +122 -13
- anemoi/utils/remote/__init__.py +386 -38
- anemoi/utils/remote/s3.py +252 -29
- anemoi/utils/remote/ssh.py +140 -8
- anemoi/utils/s3.py +77 -4
- anemoi/utils/sanitise.py +52 -7
- anemoi/utils/text.py +218 -54
- anemoi/utils/timer.py +91 -15
- {anemoi_utils-0.4.11.dist-info → anemoi_utils-0.4.13.dist-info}/LICENSE +1 -1
- {anemoi_utils-0.4.11.dist-info → anemoi_utils-0.4.13.dist-info}/METADATA +7 -4
- anemoi_utils-0.4.13.dist-info/RECORD +37 -0
- {anemoi_utils-0.4.11.dist-info → anemoi_utils-0.4.13.dist-info}/WHEEL +1 -1
- anemoi_utils-0.4.11.dist-info/RECORD +0 -37
- {anemoi_utils-0.4.11.dist-info → anemoi_utils-0.4.13.dist-info}/entry_points.txt +0 -0
- {anemoi_utils-0.4.11.dist-info → anemoi_utils-0.4.13.dist-info}/top_level.txt +0 -0
anemoi/utils/logs.py
CHANGED
|
@@ -19,17 +19,45 @@ thread_local = threading.local()
|
|
|
19
19
|
LOGGER = logging.getLogger(__name__)
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
def set_logging_name(name):
|
|
22
|
+
def set_logging_name(name: str) -> None:
|
|
23
|
+
"""Set the logging name for the current thread.
|
|
24
|
+
|
|
25
|
+
Parameters
|
|
26
|
+
----------
|
|
27
|
+
name : str
|
|
28
|
+
The name to set for logging.
|
|
29
|
+
"""
|
|
23
30
|
thread_local.logging_name = name
|
|
24
31
|
|
|
25
32
|
|
|
26
33
|
class ThreadCustomFormatter(logging.Formatter):
|
|
27
|
-
|
|
34
|
+
"""Custom logging formatter that includes thread-specific logging names."""
|
|
35
|
+
|
|
36
|
+
def format(self, record: logging.LogRecord) -> str:
|
|
37
|
+
"""Format the log record to include the thread-specific logging name.
|
|
38
|
+
|
|
39
|
+
Parameters
|
|
40
|
+
----------
|
|
41
|
+
record : logging.LogRecord
|
|
42
|
+
The log record to format.
|
|
43
|
+
|
|
44
|
+
Returns
|
|
45
|
+
-------
|
|
46
|
+
str
|
|
47
|
+
The formatted log record.
|
|
48
|
+
"""
|
|
28
49
|
record.logging_name = thread_local.logging_name
|
|
29
50
|
return super().format(record)
|
|
30
51
|
|
|
31
52
|
|
|
32
|
-
def enable_logging_name(name="main"):
|
|
53
|
+
def enable_logging_name(name: str = "main") -> None:
|
|
54
|
+
"""Enable logging with a thread-specific logging name.
|
|
55
|
+
|
|
56
|
+
Parameters
|
|
57
|
+
----------
|
|
58
|
+
name : str, optional
|
|
59
|
+
The default logging name to set, by default "main".
|
|
60
|
+
"""
|
|
33
61
|
thread_local.logging_name = name
|
|
34
62
|
|
|
35
63
|
formatter = ThreadCustomFormatter("%(asctime)s - %(logging_name)s - %(levelname)s - %(message)s")
|
anemoi/utils/mars/__init__.py
CHANGED
|
@@ -10,13 +10,16 @@
|
|
|
10
10
|
|
|
11
11
|
"""Utilities for working with Mars requests.
|
|
12
12
|
|
|
13
|
-
Has some
|
|
14
|
-
|
|
13
|
+
Has some knowledge of how certain streams are organised in Mars.
|
|
15
14
|
"""
|
|
16
15
|
|
|
17
16
|
import datetime
|
|
18
17
|
import logging
|
|
19
18
|
import os
|
|
19
|
+
from typing import Any
|
|
20
|
+
from typing import Dict
|
|
21
|
+
from typing import Optional
|
|
22
|
+
from typing import Tuple
|
|
20
23
|
|
|
21
24
|
import yaml
|
|
22
25
|
|
|
@@ -30,16 +33,18 @@ DEFAULT_MARS_LABELLING = {
|
|
|
30
33
|
}
|
|
31
34
|
|
|
32
35
|
|
|
33
|
-
def _expand_mars_labelling(request):
|
|
36
|
+
def _expand_mars_labelling(request: Dict[str, Any]) -> Dict[str, Any]:
|
|
34
37
|
"""Expand the request with the default Mars labelling.
|
|
35
38
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
'stream': 'oper',
|
|
41
|
-
'expver': '0001'}
|
|
39
|
+
Parameters
|
|
40
|
+
----------
|
|
41
|
+
request : dict
|
|
42
|
+
The original Mars request.
|
|
42
43
|
|
|
44
|
+
Returns
|
|
45
|
+
-------
|
|
46
|
+
dict
|
|
47
|
+
The Mars request expanded with default labelling.
|
|
43
48
|
"""
|
|
44
49
|
result = DEFAULT_MARS_LABELLING.copy()
|
|
45
50
|
result.update(request)
|
|
@@ -49,7 +54,19 @@ def _expand_mars_labelling(request):
|
|
|
49
54
|
STREAMS = None
|
|
50
55
|
|
|
51
56
|
|
|
52
|
-
def _lookup_mars_stream(request):
|
|
57
|
+
def _lookup_mars_stream(request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
|
58
|
+
"""Look up the Mars stream information for a given request.
|
|
59
|
+
|
|
60
|
+
Parameters
|
|
61
|
+
----------
|
|
62
|
+
request : dict
|
|
63
|
+
The Mars request.
|
|
64
|
+
|
|
65
|
+
Returns
|
|
66
|
+
-------
|
|
67
|
+
dict or None
|
|
68
|
+
The stream information if a match is found, otherwise None.
|
|
69
|
+
"""
|
|
53
70
|
global STREAMS
|
|
54
71
|
|
|
55
72
|
if STREAMS is None:
|
|
@@ -64,8 +81,25 @@ def _lookup_mars_stream(request):
|
|
|
64
81
|
return s["info"]
|
|
65
82
|
|
|
66
83
|
|
|
67
|
-
def recenter(
|
|
68
|
-
|
|
84
|
+
def recenter(
|
|
85
|
+
date: datetime.datetime, center: Dict[str, Any], members: Dict[str, Any]
|
|
86
|
+
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
|
87
|
+
"""Recenter the given date with the specified center and members.
|
|
88
|
+
|
|
89
|
+
Parameters
|
|
90
|
+
----------
|
|
91
|
+
date : datetime.datetime
|
|
92
|
+
The date to recenter.
|
|
93
|
+
center : dict
|
|
94
|
+
The center request information.
|
|
95
|
+
members : dict
|
|
96
|
+
The members request information.
|
|
97
|
+
|
|
98
|
+
Returns
|
|
99
|
+
-------
|
|
100
|
+
tuple
|
|
101
|
+
A tuple containing the recentered center and members information.
|
|
102
|
+
"""
|
|
69
103
|
center = _lookup_mars_stream(center)
|
|
70
104
|
members = _lookup_mars_stream(members)
|
|
71
105
|
|
anemoi/utils/mars/requests.py
CHANGED
|
@@ -6,9 +6,23 @@
|
|
|
6
6
|
# nor does it submit to any jurisdiction.
|
|
7
7
|
|
|
8
8
|
import sys
|
|
9
|
+
from typing import Any
|
|
10
|
+
from typing import Dict
|
|
11
|
+
from typing import TextIO
|
|
9
12
|
|
|
10
13
|
|
|
11
|
-
def print_request(verb, request, file=sys.stdout):
|
|
14
|
+
def print_request(verb: str, request: Dict[str, Any], file: TextIO = sys.stdout) -> None:
|
|
15
|
+
"""Prints a formatted request.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
verb : str
|
|
20
|
+
A mars verb
|
|
21
|
+
request : Dict[str, Any]
|
|
22
|
+
The request parameters.
|
|
23
|
+
file : TextIO, optional
|
|
24
|
+
The file to which the request is printed, by default sys.stdout.
|
|
25
|
+
"""
|
|
12
26
|
r = [verb]
|
|
13
27
|
for k, v in request.items():
|
|
14
28
|
if not isinstance(v, (list, tuple, set)):
|
anemoi/utils/provenance.py
CHANGED
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
- The versions of the modules which are currently loaded
|
|
15
15
|
- The git information for the modules which are currently loaded from a git repository
|
|
16
16
|
- ...
|
|
17
|
-
|
|
18
17
|
"""
|
|
19
18
|
|
|
20
19
|
import datetime
|
|
@@ -25,11 +24,29 @@ import subprocess
|
|
|
25
24
|
import sys
|
|
26
25
|
import sysconfig
|
|
27
26
|
from functools import cache
|
|
27
|
+
from typing import Any
|
|
28
|
+
from typing import Dict
|
|
29
|
+
from typing import List
|
|
30
|
+
from typing import Optional
|
|
31
|
+
from typing import Tuple
|
|
32
|
+
from typing import Union
|
|
28
33
|
|
|
29
34
|
LOG = logging.getLogger(__name__)
|
|
30
35
|
|
|
31
36
|
|
|
32
|
-
def lookup_git_repo(path):
|
|
37
|
+
def lookup_git_repo(path: str) -> Optional[Any]:
|
|
38
|
+
"""Lookup the git repository for a given path.
|
|
39
|
+
|
|
40
|
+
Parameters
|
|
41
|
+
----------
|
|
42
|
+
path : str
|
|
43
|
+
The path to lookup.
|
|
44
|
+
|
|
45
|
+
Returns
|
|
46
|
+
-------
|
|
47
|
+
Repo, optional
|
|
48
|
+
The git repository if found, otherwise None.
|
|
49
|
+
"""
|
|
33
50
|
from git import InvalidGitRepositoryError
|
|
34
51
|
from git import Repo
|
|
35
52
|
|
|
@@ -42,7 +59,21 @@ def lookup_git_repo(path):
|
|
|
42
59
|
return None
|
|
43
60
|
|
|
44
61
|
|
|
45
|
-
def _check_for_git(paths, full):
|
|
62
|
+
def _check_for_git(paths: List[Tuple[str, str]], full: bool) -> Dict[str, Any]:
|
|
63
|
+
"""Check for git information for the given paths.
|
|
64
|
+
|
|
65
|
+
Parameters
|
|
66
|
+
----------
|
|
67
|
+
paths : list of tuple
|
|
68
|
+
The list of paths to check.
|
|
69
|
+
full : bool
|
|
70
|
+
Whether to collect full information.
|
|
71
|
+
|
|
72
|
+
Returns
|
|
73
|
+
-------
|
|
74
|
+
dict
|
|
75
|
+
The git information for the given paths.
|
|
76
|
+
"""
|
|
46
77
|
versions = {}
|
|
47
78
|
for name, path in paths:
|
|
48
79
|
repo = lookup_git_repo(path)
|
|
@@ -77,7 +108,28 @@ def _check_for_git(paths, full):
|
|
|
77
108
|
return versions
|
|
78
109
|
|
|
79
110
|
|
|
80
|
-
def version(versions, name, module, roots, namespaces, paths, full):
|
|
111
|
+
def version(
|
|
112
|
+
versions: Dict[str, Any], name: str, module: Any, roots: Dict[str, str], namespaces: set, paths: set, full: bool
|
|
113
|
+
) -> None:
|
|
114
|
+
"""Collect version information for a module.
|
|
115
|
+
|
|
116
|
+
Parameters
|
|
117
|
+
----------
|
|
118
|
+
versions : dict
|
|
119
|
+
The dictionary to store the version information.
|
|
120
|
+
name : str
|
|
121
|
+
The name of the module.
|
|
122
|
+
module : Any
|
|
123
|
+
The module to collect information for.
|
|
124
|
+
roots : dict
|
|
125
|
+
The dictionary of root paths.
|
|
126
|
+
namespaces : set
|
|
127
|
+
The set of namespaces.
|
|
128
|
+
paths : set
|
|
129
|
+
The set of paths.
|
|
130
|
+
full : bool
|
|
131
|
+
Whether to collect full information.
|
|
132
|
+
"""
|
|
81
133
|
path = None
|
|
82
134
|
|
|
83
135
|
if hasattr(module, "__file__"):
|
|
@@ -119,7 +171,19 @@ def version(versions, name, module, roots, namespaces, paths, full):
|
|
|
119
171
|
versions[name] = str(module)
|
|
120
172
|
|
|
121
173
|
|
|
122
|
-
def _module_versions(full):
|
|
174
|
+
def _module_versions(full: bool) -> Tuple[Dict[str, Any], set]:
|
|
175
|
+
"""Collect version information for all loaded modules.
|
|
176
|
+
|
|
177
|
+
Parameters
|
|
178
|
+
----------
|
|
179
|
+
full : bool
|
|
180
|
+
Whether to collect full information.
|
|
181
|
+
|
|
182
|
+
Returns
|
|
183
|
+
-------
|
|
184
|
+
tuple of dict and set
|
|
185
|
+
The version information and the set of paths.
|
|
186
|
+
"""
|
|
123
187
|
# https://docs.python.org/3/library/sysconfig.html
|
|
124
188
|
|
|
125
189
|
roots = {}
|
|
@@ -149,7 +213,14 @@ def _module_versions(full):
|
|
|
149
213
|
|
|
150
214
|
|
|
151
215
|
@cache
|
|
152
|
-
def package_distributions() -> dict[str, list[str]]:
|
|
216
|
+
def package_distributions() -> Dict[str, List[str]]:
|
|
217
|
+
"""Get the package distributions.
|
|
218
|
+
|
|
219
|
+
Returns
|
|
220
|
+
-------
|
|
221
|
+
dict
|
|
222
|
+
The package distributions.
|
|
223
|
+
"""
|
|
153
224
|
# Takes a significant amount of time to run
|
|
154
225
|
# so cache the result
|
|
155
226
|
from importlib import metadata
|
|
@@ -161,7 +232,20 @@ def package_distributions() -> dict[str, list[str]]:
|
|
|
161
232
|
return metadata.packages_distributions()
|
|
162
233
|
|
|
163
234
|
|
|
164
|
-
def import_name_to_distribution_name(packages: list):
|
|
235
|
+
def import_name_to_distribution_name(packages: List[str]) -> Dict[str, str]:
|
|
236
|
+
"""Convert import names to distribution names.
|
|
237
|
+
|
|
238
|
+
Parameters
|
|
239
|
+
----------
|
|
240
|
+
packages : list of str
|
|
241
|
+
The list of import names.
|
|
242
|
+
|
|
243
|
+
Returns
|
|
244
|
+
-------
|
|
245
|
+
dict
|
|
246
|
+
The dictionary mapping import names to distribution names.
|
|
247
|
+
"""
|
|
248
|
+
|
|
165
249
|
distribution_names = {}
|
|
166
250
|
package_distribution_names = package_distributions()
|
|
167
251
|
|
|
@@ -179,13 +263,37 @@ def import_name_to_distribution_name(packages: list):
|
|
|
179
263
|
return distribution_names
|
|
180
264
|
|
|
181
265
|
|
|
182
|
-
def module_versions(full):
|
|
266
|
+
def module_versions(full: bool) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
267
|
+
"""Collect version information for all loaded modules and their git information.
|
|
268
|
+
|
|
269
|
+
Parameters
|
|
270
|
+
----------
|
|
271
|
+
full : bool
|
|
272
|
+
Whether to collect full information.
|
|
273
|
+
|
|
274
|
+
Returns
|
|
275
|
+
-------
|
|
276
|
+
tuple of dict and dict
|
|
277
|
+
The version information and the git information.
|
|
278
|
+
"""
|
|
183
279
|
versions, paths = _module_versions(full)
|
|
184
280
|
git_versions = _check_for_git(paths, full)
|
|
185
281
|
return versions, git_versions
|
|
186
282
|
|
|
187
283
|
|
|
188
|
-
def _name(obj):
|
|
284
|
+
def _name(obj: Any) -> str:
|
|
285
|
+
"""Get the name of an object.
|
|
286
|
+
|
|
287
|
+
Parameters
|
|
288
|
+
----------
|
|
289
|
+
obj : Any
|
|
290
|
+
The object to get the name of.
|
|
291
|
+
|
|
292
|
+
Returns
|
|
293
|
+
-------
|
|
294
|
+
str
|
|
295
|
+
The name of the object.
|
|
296
|
+
"""
|
|
189
297
|
if hasattr(obj, "__name__"):
|
|
190
298
|
if hasattr(obj, "__module__"):
|
|
191
299
|
return f"{obj.__module__}.{obj.__name__}"
|
|
@@ -195,8 +303,19 @@ def _name(obj):
|
|
|
195
303
|
return str(obj)
|
|
196
304
|
|
|
197
305
|
|
|
198
|
-
def _paths(path_or_object):
|
|
306
|
+
def _paths(path_or_object: Union[None, str, List[str], Tuple[str], Any]) -> List[Tuple[str, str]]:
|
|
307
|
+
"""Get the paths for a given path or object.
|
|
199
308
|
|
|
309
|
+
Parameters
|
|
310
|
+
----------
|
|
311
|
+
path_or_object : str, list, tuple, or object
|
|
312
|
+
The path or object to get the paths for.
|
|
313
|
+
|
|
314
|
+
Returns
|
|
315
|
+
-------
|
|
316
|
+
list of tuple
|
|
317
|
+
The list of paths.
|
|
318
|
+
"""
|
|
200
319
|
if path_or_object is None:
|
|
201
320
|
_, paths = _module_versions(full=False)
|
|
202
321
|
return paths
|
|
@@ -235,7 +354,7 @@ def _paths(path_or_object):
|
|
|
235
354
|
return paths
|
|
236
355
|
|
|
237
356
|
|
|
238
|
-
def git_check(*args) -> dict:
|
|
357
|
+
def git_check(*args: Any) -> Dict[str, Any]:
|
|
239
358
|
"""Return the git information for the given arguments.
|
|
240
359
|
|
|
241
360
|
Arguments can be:
|
|
@@ -255,18 +374,18 @@ def git_check(*args) -> dict:
|
|
|
255
374
|
dict
|
|
256
375
|
An object with the git information for the given arguments.
|
|
257
376
|
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
377
|
+
>>> {
|
|
378
|
+
"anemoi.utils": {
|
|
379
|
+
"sha1": "c999d83ae283bcbb99f68d92c42d24315922129f",
|
|
380
|
+
"remotes": [
|
|
381
|
+
"git@github.com:ecmwf/anemoi-utils.git"
|
|
382
|
+
],
|
|
383
|
+
"modified_files": [
|
|
384
|
+
"anemoi/utils/checkpoints.py"
|
|
385
|
+
],
|
|
386
|
+
"untracked_files": []
|
|
387
|
+
}
|
|
268
388
|
}
|
|
269
|
-
}
|
|
270
389
|
"""
|
|
271
390
|
paths = _paths(args if len(args) > 0 else None)
|
|
272
391
|
|
|
@@ -278,7 +397,14 @@ def git_check(*args) -> dict:
|
|
|
278
397
|
return result
|
|
279
398
|
|
|
280
399
|
|
|
281
|
-
def platform_info():
|
|
400
|
+
def platform_info() -> Dict[str, Any]:
|
|
401
|
+
"""Get the platform information.
|
|
402
|
+
|
|
403
|
+
Returns
|
|
404
|
+
-------
|
|
405
|
+
dict
|
|
406
|
+
The platform information.
|
|
407
|
+
"""
|
|
282
408
|
import platform
|
|
283
409
|
|
|
284
410
|
r = {}
|
|
@@ -300,7 +426,14 @@ def platform_info():
|
|
|
300
426
|
return r
|
|
301
427
|
|
|
302
428
|
|
|
303
|
-
def gpu_info():
|
|
429
|
+
def gpu_info() -> Union[str, List[Dict[str, Any]]]:
|
|
430
|
+
"""Get the GPU information.
|
|
431
|
+
|
|
432
|
+
Returns
|
|
433
|
+
-------
|
|
434
|
+
str or list of dict
|
|
435
|
+
The GPU information or an error message.
|
|
436
|
+
"""
|
|
304
437
|
import nvsmi
|
|
305
438
|
|
|
306
439
|
if not nvsmi.is_nvidia_smi_on_path():
|
|
@@ -312,7 +445,19 @@ def gpu_info():
|
|
|
312
445
|
return e.output.decode("utf-8").strip()
|
|
313
446
|
|
|
314
447
|
|
|
315
|
-
def path_md5(path):
|
|
448
|
+
def path_md5(path: str) -> str:
|
|
449
|
+
"""Calculate the MD5 hash of a file.
|
|
450
|
+
|
|
451
|
+
Parameters
|
|
452
|
+
----------
|
|
453
|
+
path : str
|
|
454
|
+
The path to the file.
|
|
455
|
+
|
|
456
|
+
Returns
|
|
457
|
+
-------
|
|
458
|
+
str
|
|
459
|
+
The MD5 hash of the file.
|
|
460
|
+
"""
|
|
316
461
|
import hashlib
|
|
317
462
|
|
|
318
463
|
hash = hashlib.md5()
|
|
@@ -322,7 +467,19 @@ def path_md5(path):
|
|
|
322
467
|
return hash.hexdigest()
|
|
323
468
|
|
|
324
469
|
|
|
325
|
-
def assets_info(paths):
|
|
470
|
+
def assets_info(paths: List[str]) -> Dict[str, Any]:
|
|
471
|
+
"""Get information about the given assets.
|
|
472
|
+
|
|
473
|
+
Parameters
|
|
474
|
+
----------
|
|
475
|
+
paths : list of str
|
|
476
|
+
The list of paths to the assets.
|
|
477
|
+
|
|
478
|
+
Returns
|
|
479
|
+
-------
|
|
480
|
+
dict
|
|
481
|
+
The information about the assets.
|
|
482
|
+
"""
|
|
326
483
|
result = {}
|
|
327
484
|
|
|
328
485
|
for path in paths:
|
|
@@ -351,8 +508,8 @@ def assets_info(paths):
|
|
|
351
508
|
return result
|
|
352
509
|
|
|
353
510
|
|
|
354
|
-
def gather_provenance_info(assets=[], full=False) ->
|
|
355
|
-
"""Gather information about the current environment
|
|
511
|
+
def gather_provenance_info(assets: List[str] = [], full: bool = False) -> Dict[str, Any]:
|
|
512
|
+
"""Gather information about the current environment.
|
|
356
513
|
|
|
357
514
|
Parameters
|
|
358
515
|
----------
|
anemoi/utils/registry.py
CHANGED
|
@@ -12,6 +12,11 @@ import importlib
|
|
|
12
12
|
import logging
|
|
13
13
|
import os
|
|
14
14
|
import sys
|
|
15
|
+
from typing import Any
|
|
16
|
+
from typing import Callable
|
|
17
|
+
from typing import Dict
|
|
18
|
+
from typing import Optional
|
|
19
|
+
from typing import Union
|
|
15
20
|
|
|
16
21
|
import entrypoints
|
|
17
22
|
|
|
@@ -19,13 +24,33 @@ LOG = logging.getLogger(__name__)
|
|
|
19
24
|
|
|
20
25
|
|
|
21
26
|
class Wrapper:
|
|
22
|
-
"""A wrapper for the registry
|
|
27
|
+
"""A wrapper for the registry.
|
|
23
28
|
|
|
24
|
-
|
|
29
|
+
Parameters
|
|
30
|
+
----------
|
|
31
|
+
name : str
|
|
32
|
+
The name of the wrapper.
|
|
33
|
+
registry : Registry
|
|
34
|
+
The registry to wrap.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def __init__(self, name: str, registry: "Registry"):
|
|
25
38
|
self.name = name
|
|
26
39
|
self.registry = registry
|
|
27
40
|
|
|
28
|
-
def __call__(self, factory):
|
|
41
|
+
def __call__(self, factory: Callable) -> Callable:
|
|
42
|
+
"""Register a factory with the registry.
|
|
43
|
+
|
|
44
|
+
Parameters
|
|
45
|
+
----------
|
|
46
|
+
factory : Callable
|
|
47
|
+
The factory to register.
|
|
48
|
+
|
|
49
|
+
Returns
|
|
50
|
+
-------
|
|
51
|
+
Callable
|
|
52
|
+
The registered factory.
|
|
53
|
+
"""
|
|
29
54
|
self.registry.register(self.name, factory)
|
|
30
55
|
return factory
|
|
31
56
|
|
|
@@ -34,10 +59,17 @@ _BY_KIND = {}
|
|
|
34
59
|
|
|
35
60
|
|
|
36
61
|
class Registry:
|
|
37
|
-
"""A registry of factories
|
|
62
|
+
"""A registry of factories.
|
|
38
63
|
|
|
39
|
-
|
|
64
|
+
Parameters
|
|
65
|
+
----------
|
|
66
|
+
package : str
|
|
67
|
+
The package name.
|
|
68
|
+
key : str, optional
|
|
69
|
+
The key to use for the registry, by default "_type".
|
|
70
|
+
"""
|
|
40
71
|
|
|
72
|
+
def __init__(self, package: str, key: str = "_type"):
|
|
41
73
|
self.package = package
|
|
42
74
|
self.registered = {}
|
|
43
75
|
self.kind = package.split(".")[-1]
|
|
@@ -45,11 +77,36 @@ class Registry:
|
|
|
45
77
|
_BY_KIND[self.kind] = self
|
|
46
78
|
|
|
47
79
|
@classmethod
|
|
48
|
-
def lookup_kind(cls, kind: str):
|
|
80
|
+
def lookup_kind(cls, kind: str) -> Optional["Registry"]:
|
|
81
|
+
"""Lookup a registry by kind.
|
|
82
|
+
|
|
83
|
+
Parameters
|
|
84
|
+
----------
|
|
85
|
+
kind : str
|
|
86
|
+
The kind of the registry.
|
|
87
|
+
|
|
88
|
+
Returns
|
|
89
|
+
-------
|
|
90
|
+
Registry, optional
|
|
91
|
+
The registry if found, otherwise None.
|
|
92
|
+
"""
|
|
49
93
|
return _BY_KIND.get(kind)
|
|
50
94
|
|
|
51
|
-
def register(self, name: str, factory:
|
|
52
|
-
|
|
95
|
+
def register(self, name: str, factory: Optional[Callable] = None) -> Optional[Wrapper]:
|
|
96
|
+
"""Register a factory with the registry.
|
|
97
|
+
|
|
98
|
+
Parameters
|
|
99
|
+
----------
|
|
100
|
+
name : str
|
|
101
|
+
The name of the factory.
|
|
102
|
+
factory : Callable, optional
|
|
103
|
+
The factory to register, by default None.
|
|
104
|
+
|
|
105
|
+
Returns
|
|
106
|
+
-------
|
|
107
|
+
Wrapper, optional
|
|
108
|
+
A wrapper if the factory is None, otherwise None.
|
|
109
|
+
"""
|
|
53
110
|
if factory is None:
|
|
54
111
|
return Wrapper(name, self)
|
|
55
112
|
|
|
@@ -58,15 +115,35 @@ class Registry:
|
|
|
58
115
|
# def registered(self, name: str):
|
|
59
116
|
# return name in self.registered
|
|
60
117
|
|
|
61
|
-
def _load(self, file):
|
|
118
|
+
def _load(self, file: str) -> None:
|
|
119
|
+
"""Load a module from a file.
|
|
120
|
+
|
|
121
|
+
Parameters
|
|
122
|
+
----------
|
|
123
|
+
file : str
|
|
124
|
+
The file to load.
|
|
125
|
+
"""
|
|
62
126
|
name, _ = os.path.splitext(file)
|
|
63
127
|
try:
|
|
64
128
|
importlib.import_module(f".{name}", package=self.package)
|
|
65
129
|
except Exception:
|
|
66
130
|
LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True)
|
|
67
131
|
|
|
68
|
-
def lookup(self, name: str, *, return_none=False) ->
|
|
69
|
-
|
|
132
|
+
def lookup(self, name: str, *, return_none: bool = False) -> Optional[Callable]:
|
|
133
|
+
"""Lookup a factory by name.
|
|
134
|
+
|
|
135
|
+
Parameters
|
|
136
|
+
----------
|
|
137
|
+
name : str
|
|
138
|
+
The name of the factory.
|
|
139
|
+
return_none : bool, optional
|
|
140
|
+
Whether to return None if the factory is not found, by default False.
|
|
141
|
+
|
|
142
|
+
Returns
|
|
143
|
+
-------
|
|
144
|
+
Callable, optional
|
|
145
|
+
The factory if found, otherwise None.
|
|
146
|
+
"""
|
|
70
147
|
# print('✅✅✅✅✅✅✅✅✅✅✅✅✅', name, self.registered)
|
|
71
148
|
if name in self.registered:
|
|
72
149
|
return self.registered[name]
|
|
@@ -110,14 +187,46 @@ class Registry:
|
|
|
110
187
|
|
|
111
188
|
return self.registered[name]
|
|
112
189
|
|
|
113
|
-
def create(self, name: str, *args, **kwargs):
|
|
190
|
+
def create(self, name: str, *args: Any, **kwargs: Any) -> Any:
|
|
191
|
+
"""Create an instance using a factory.
|
|
192
|
+
|
|
193
|
+
Parameters
|
|
194
|
+
----------
|
|
195
|
+
name : str
|
|
196
|
+
The name of the factory.
|
|
197
|
+
*args : Any
|
|
198
|
+
Positional arguments for the factory.
|
|
199
|
+
**kwargs : Any
|
|
200
|
+
Keyword arguments for the factory.
|
|
201
|
+
|
|
202
|
+
Returns
|
|
203
|
+
-------
|
|
204
|
+
Any
|
|
205
|
+
The created instance.
|
|
206
|
+
"""
|
|
114
207
|
factory = self.lookup(name)
|
|
115
208
|
return factory(*args, **kwargs)
|
|
116
209
|
|
|
117
210
|
# def __call__(self, name: str, *args, **kwargs):
|
|
118
211
|
# return self.create(name, *args, **kwargs)
|
|
119
212
|
|
|
120
|
-
def from_config(self, config, *args, **kwargs):
|
|
213
|
+
def from_config(self, config: Union[str, Dict[str, Any]], *args: Any, **kwargs: Any) -> Any:
|
|
214
|
+
"""Create an instance from a configuration.
|
|
215
|
+
|
|
216
|
+
Parameters
|
|
217
|
+
----------
|
|
218
|
+
config : str or dict
|
|
219
|
+
The configuration.
|
|
220
|
+
*args : Any
|
|
221
|
+
Positional arguments for the factory.
|
|
222
|
+
**kwargs : Any
|
|
223
|
+
Keyword arguments for the factory.
|
|
224
|
+
|
|
225
|
+
Returns
|
|
226
|
+
-------
|
|
227
|
+
Any
|
|
228
|
+
The created instance.
|
|
229
|
+
"""
|
|
121
230
|
if isinstance(config, str):
|
|
122
231
|
config = {config: {}}
|
|
123
232
|
|