hpcflow-new2 0.2.0a188__py3-none-any.whl → 0.2.0a190__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- hpcflow/__pyinstaller/hook-hpcflow.py +8 -6
- hpcflow/_version.py +1 -1
- hpcflow/app.py +1 -0
- hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
- hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
- hpcflow/sdk/__init__.py +21 -15
- hpcflow/sdk/app.py +2133 -770
- hpcflow/sdk/cli.py +281 -250
- hpcflow/sdk/cli_common.py +6 -2
- hpcflow/sdk/config/__init__.py +1 -1
- hpcflow/sdk/config/callbacks.py +77 -42
- hpcflow/sdk/config/cli.py +126 -103
- hpcflow/sdk/config/config.py +578 -311
- hpcflow/sdk/config/config_file.py +131 -95
- hpcflow/sdk/config/errors.py +112 -85
- hpcflow/sdk/config/types.py +145 -0
- hpcflow/sdk/core/actions.py +1054 -994
- hpcflow/sdk/core/app_aware.py +24 -0
- hpcflow/sdk/core/cache.py +81 -63
- hpcflow/sdk/core/command_files.py +275 -185
- hpcflow/sdk/core/commands.py +111 -107
- hpcflow/sdk/core/element.py +724 -503
- hpcflow/sdk/core/enums.py +192 -0
- hpcflow/sdk/core/environment.py +74 -93
- hpcflow/sdk/core/errors.py +398 -51
- hpcflow/sdk/core/json_like.py +540 -272
- hpcflow/sdk/core/loop.py +380 -334
- hpcflow/sdk/core/loop_cache.py +160 -43
- hpcflow/sdk/core/object_list.py +370 -207
- hpcflow/sdk/core/parameters.py +728 -600
- hpcflow/sdk/core/rule.py +59 -41
- hpcflow/sdk/core/run_dir_files.py +33 -22
- hpcflow/sdk/core/task.py +1546 -1325
- hpcflow/sdk/core/task_schema.py +240 -196
- hpcflow/sdk/core/test_utils.py +126 -88
- hpcflow/sdk/core/types.py +387 -0
- hpcflow/sdk/core/utils.py +410 -305
- hpcflow/sdk/core/validation.py +82 -9
- hpcflow/sdk/core/workflow.py +1192 -1028
- hpcflow/sdk/core/zarr_io.py +98 -137
- hpcflow/sdk/demo/cli.py +46 -33
- hpcflow/sdk/helper/cli.py +18 -16
- hpcflow/sdk/helper/helper.py +75 -63
- hpcflow/sdk/helper/watcher.py +61 -28
- hpcflow/sdk/log.py +83 -59
- hpcflow/sdk/persistence/__init__.py +8 -31
- hpcflow/sdk/persistence/base.py +988 -586
- hpcflow/sdk/persistence/defaults.py +6 -0
- hpcflow/sdk/persistence/discovery.py +38 -0
- hpcflow/sdk/persistence/json.py +408 -153
- hpcflow/sdk/persistence/pending.py +158 -123
- hpcflow/sdk/persistence/store_resource.py +37 -22
- hpcflow/sdk/persistence/types.py +307 -0
- hpcflow/sdk/persistence/utils.py +14 -11
- hpcflow/sdk/persistence/zarr.py +477 -420
- hpcflow/sdk/runtime.py +44 -41
- hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
- hpcflow/sdk/submission/jobscript.py +444 -404
- hpcflow/sdk/submission/schedulers/__init__.py +133 -40
- hpcflow/sdk/submission/schedulers/direct.py +97 -71
- hpcflow/sdk/submission/schedulers/sge.py +132 -126
- hpcflow/sdk/submission/schedulers/slurm.py +263 -268
- hpcflow/sdk/submission/schedulers/utils.py +7 -2
- hpcflow/sdk/submission/shells/__init__.py +14 -15
- hpcflow/sdk/submission/shells/base.py +102 -29
- hpcflow/sdk/submission/shells/bash.py +72 -55
- hpcflow/sdk/submission/shells/os_version.py +31 -30
- hpcflow/sdk/submission/shells/powershell.py +37 -29
- hpcflow/sdk/submission/submission.py +203 -257
- hpcflow/sdk/submission/types.py +143 -0
- hpcflow/sdk/typing.py +163 -12
- hpcflow/tests/conftest.py +8 -6
- hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
- hpcflow/tests/scripts/test_main_scripts.py +60 -30
- hpcflow/tests/shells/wsl/test_wsl_submission.py +6 -4
- hpcflow/tests/unit/test_action.py +86 -75
- hpcflow/tests/unit/test_action_rule.py +9 -4
- hpcflow/tests/unit/test_app.py +13 -6
- hpcflow/tests/unit/test_cli.py +1 -1
- hpcflow/tests/unit/test_command.py +71 -54
- hpcflow/tests/unit/test_config.py +20 -15
- hpcflow/tests/unit/test_config_file.py +21 -18
- hpcflow/tests/unit/test_element.py +58 -62
- hpcflow/tests/unit/test_element_iteration.py +3 -1
- hpcflow/tests/unit/test_element_set.py +29 -19
- hpcflow/tests/unit/test_group.py +4 -2
- hpcflow/tests/unit/test_input_source.py +116 -93
- hpcflow/tests/unit/test_input_value.py +29 -24
- hpcflow/tests/unit/test_json_like.py +44 -35
- hpcflow/tests/unit/test_loop.py +65 -58
- hpcflow/tests/unit/test_object_list.py +17 -12
- hpcflow/tests/unit/test_parameter.py +16 -7
- hpcflow/tests/unit/test_persistence.py +48 -35
- hpcflow/tests/unit/test_resources.py +20 -18
- hpcflow/tests/unit/test_run.py +8 -3
- hpcflow/tests/unit/test_runtime.py +2 -1
- hpcflow/tests/unit/test_schema_input.py +23 -15
- hpcflow/tests/unit/test_shell.py +3 -2
- hpcflow/tests/unit/test_slurm.py +8 -7
- hpcflow/tests/unit/test_submission.py +39 -19
- hpcflow/tests/unit/test_task.py +352 -247
- hpcflow/tests/unit/test_task_schema.py +33 -20
- hpcflow/tests/unit/test_utils.py +9 -11
- hpcflow/tests/unit/test_value_sequence.py +15 -12
- hpcflow/tests/unit/test_workflow.py +114 -83
- hpcflow/tests/unit/test_workflow_template.py +0 -1
- hpcflow/tests/workflows/test_jobscript.py +2 -1
- hpcflow/tests/workflows/test_workflows.py +18 -13
- {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/METADATA +2 -1
- hpcflow_new2-0.2.0a190.dist-info/RECORD +165 -0
- hpcflow/sdk/core/parallel.py +0 -21
- hpcflow_new2-0.2.0a188.dist-info/RECORD +0 -158
- {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/LICENSE +0 -0
- {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/WHEEL +0 -0
- {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/entry_points.txt +0 -0
hpcflow/sdk/core/utils.py
CHANGED
@@ -2,12 +2,13 @@
 Miscellaneous utilities.
 """
 
+from __future__ import annotations
+from collections import Counter
 import copy
 import enum
-from functools import wraps
-import contextlib
 import hashlib
 from itertools import accumulate, islice
+from importlib import resources
 import json
 import keyword
 import os
@@ -17,10 +18,10 @@ import re
 import socket
 import string
 import subprocess
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 import sys
-from typing import
-import fsspec
+from typing import cast, overload, TypeVar, TYPE_CHECKING
+import fsspec  # type: ignore
 import numpy as np
 
 from ruamel.yaml import YAML
@@ -28,27 +29,29 @@ from watchdog.utils.dirsnapshot import DirectorySnapshot
 
 from hpcflow.sdk.core.errors import (
     ContainerKeyError,
-    FromSpecMissingObjectError,
     InvalidIdentifier,
     MissingVariableSubstitutionError,
 )
 from hpcflow.sdk.log import TimeIt
-from hpcflow.sdk.typing import PathLike
 
+if TYPE_CHECKING:
+    from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
+    from contextlib import AbstractContextManager
+    from types import ModuleType
+    from typing import Any, IO
+    from typing_extensions import TypeAlias
+    from numpy.typing import NDArray
+    from ..typing import PathLike
 
-
-
+T = TypeVar("T")
+T2 = TypeVar("T2")
+T3 = TypeVar("T3")
+TList: TypeAlias = "T | list[TList]"
+TD = TypeVar("TD", bound="Mapping[str, Any]")
+E = TypeVar("E", bound=enum.Enum)
 
-    @wraps(func)
-    def wrapper(self, *args, **kwargs):
-        if not self.is_config_loaded:
-            self.load_config()
-        return func(self, *args, **kwargs)
 
-
-
-
-def make_workflow_id():
+def make_workflow_id() -> str:
     """
     Generate a random ID for a workflow.
     """
@@ -57,14 +60,14 @@ def make_workflow_id():
     return "".join(random.choices(chars, k=length))
 
 
-def get_time_stamp():
+def get_time_stamp() -> str:
     """
     Get the current time in standard string form.
     """
     return datetime.now(timezone.utc).astimezone().strftime("%Y.%m.%d_%H:%M:%S_%z")
 
 
-def get_duplicate_items(lst):
+def get_duplicate_items(lst: Iterable[T]) -> list[T]:
    """Get a list of all items in an iterable that appear more than once, assuming items
    are hashable.
 
@@ -77,11 +80,10 @@ def get_duplicate_items(lst):
    []
 
    >>> get_duplicate_items([1, 2, 3, 3, 3, 2])
-    [2, 3
+    [2, 3]
 
    """
-    seen = []
-    return list(set(x for x in lst if x in seen or seen.append(x)))
+    return [x for x, y in Counter(lst).items() if y > 1]
 
 
 def check_valid_py_identifier(name: str) -> str:
@@ -105,29 +107,40 @@ def check_valid_py_identifier(name: str) -> str:
     - `Loop.name`
 
     """
-    exc = InvalidIdentifier(f"Invalid string for identifier: {name!r}")
     try:
         trial_name = name[1:].replace("_", "")  # "internal" underscores are allowed
     except TypeError:
-        raise
+        raise InvalidIdentifier(name) from None
+    except KeyError as e:
+        raise KeyError(f"unexpected name type {name}") from e
     if (
         not name
         or not (name[0].isalpha() and ((trial_name[1:] or "a").isalnum()))
         or keyword.iskeyword(name)
     ):
-        raise
+        raise InvalidIdentifier(name)
 
     return name
 
 
-def group_by_dict_key_values(lst, *keys):
+@overload
+def group_by_dict_key_values(lst: list[dict[T, T2]], key: T) -> list[list[dict[T, T2]]]:
+    ...
+
+
+@overload
+def group_by_dict_key_values(lst: list[TD], key: str) -> list[list[TD]]:
+    ...
+
+
+def group_by_dict_key_values(lst: list, key):
    """Group a list of dicts according to specified equivalent key-values.
 
    Parameters
    ----------
    lst : list of dict
        The list of dicts to group together.
-
+    key : key value
        Dicts that have identical values for all of these keys will be grouped together
        into a sub-list.
 
@@ -146,10 +159,10 @@ def group_by_dict_key_values(lst, *keys):
     for lst_item in lst[1:]:
         for group_idx, group in enumerate(grouped):
             try:
-                is_vals_equal =
+                is_vals_equal = lst_item[key] == group[0][key]
 
             except KeyError:
-                # dicts that do not have
+                # dicts that do not have the `key` will be in their own group:
                 is_vals_equal = False
 
             if is_vals_equal:
@@ -162,7 +175,7 @@ def group_by_dict_key_values(lst, *keys):
     return grouped
 
 
-def swap_nested_dict_keys(dct, inner_key):
+def swap_nested_dict_keys(dct: dict[T, dict[T2, T3]], inner_key: T2):
     """Return a copy where top-level keys have been swapped with a second-level inner key.
 
     Examples:
@@ -181,16 +194,35 @@ def swap_nested_dict_keys(dct, inner_key):
     }
 
     """
-    out = {}
+    out: dict[T3, dict[T, dict[T2, T3]]] = {}
     for k, v in copy.deepcopy(dct or {}).items():
-        inner_val = v.pop(inner_key)
-        if inner_val not in out:
-            out[inner_val] = {}
-        out[inner_val][k] = v
+        out.setdefault(v.pop(inner_key), {})[k] = v
     return out
 
 
-def get_in_container(cont, path, cast_indices=False, allow_getattr=False):
+def _ensure_int(path_comp: Any, cur_data: Any, cast_indices: bool) -> int:
+    """
+    Helper for get_in_container() and set_in_container()
+    """
+    if isinstance(path_comp, int):
+        return path_comp
+    if not cast_indices:
+        raise TypeError(
+            f"Path component {path_comp!r} must be an integer index "
+            f"since data is a sequence: {cur_data!r}."
+        )
+    try:
+        return int(path_comp)
+    except (TypeError, ValueError) as e:
+        raise TypeError(
+            f"Path component {path_comp!r} must be an integer index "
+            f"since data is a sequence: {cur_data!r}."
+        ) from e
+
+
+def get_in_container(
+    cont, path: Sequence, cast_indices: bool = False, allow_getattr: bool = False
+):
     """
     Follow a path (sequence of indices of appropriate type) into a container to obtain
     a "leaf" value. Containers can be lists, tuples, dicts,
@@ -203,24 +235,12 @@ def get_in_container(cont, path, cast_indices=False, allow_getattr=False):
     )
     for idx, path_comp in enumerate(path):
         if isinstance(cur_data, (list, tuple)):
-
-            msg = (
-                f"Path component {path_comp!r} must be an integer index "
-                f"since data is a sequence: {cur_data!r}."
-            )
-            if cast_indices:
-                try:
-                    path_comp = int(path_comp)
-                except TypeError:
-                    raise TypeError(msg)
-            else:
-                raise TypeError(msg)
-            cur_data = cur_data[path_comp]
-        elif isinstance(cur_data, dict):
+            cur_data = cur_data[_ensure_int(path_comp, cur_data, cast_indices)]
+        elif isinstance(cur_data, dict) or hasattr(cur_data, "__getitem__"):
             try:
                 cur_data = cur_data[path_comp]
             except KeyError:
-                raise ContainerKeyError(path=path[: idx + 1])
+                raise ContainerKeyError(path=cast("list[str]", path[: idx + 1]))
         elif allow_getattr:
             try:
                 cur_data = getattr(cur_data, path_comp)
@@ -235,7 +255,9 @@ def get_in_container(cont, path, cast_indices=False, allow_getattr=False):
     return cur_data
 
 
-def set_in_container(cont, path, value, ensure_path=False, cast_indices=False):
+def set_in_container(
+    cont, path: Sequence, value, ensure_path=False, cast_indices=False
+) -> None:
     """
     Follow a path (sequence of indices of appropriate type) into a container to update
     a "leaf" value. Containers can be lists, tuples or dicts.
@@ -258,22 +280,11 @@ def set_in_container(cont, path, value, ensure_path=False, cast_indices=False):
     sub_data = get_in_container(cont, path[:-1], cast_indices=cast_indices)
     path_comp = path[-1]
     if isinstance(sub_data, (list, tuple)):
-
-        msg = (
-            f"Path component {path_comp!r} must be an integer index "
-            f"since data is a sequence: {sub_data!r}."
-        )
-        if cast_indices:
-            try:
-                path_comp = int(path_comp)
-            except ValueError:
-                raise ValueError(msg)
-        else:
-            raise ValueError(msg)
+        path_comp = _ensure_int(path_comp, sub_data, cast_indices)
     sub_data[path_comp] = value
 
 
-def get_relative_path(path1, path2):
+def get_relative_path(path1: Sequence[T], path2: Sequence[T]) -> Sequence[T]:
     """Get relative path components between two paths.
 
     Parameters
@@ -308,79 +319,39 @@ def get_relative_path(path1, path2):
     """
 
     len_path2 = len(path2)
-
-    msg = f"{path1!r} is not in the subpath of {path2!r}."
-    if len(path1) < len_path2:
-        raise ValueError(msg)
-
-    for i, j in zip(path1[:len_path2], path2):
-        if i != j:
-            raise ValueError(msg)
+    if len(path1) < len_path2 or any(i != j for i, j in zip(path1[:len_path2], path2)):
+        raise ValueError(f"{path1!r} is not in the subpath of {path2!r}.")
 
     return path1[len_path2:]
 
 
-def search_dir_files_by_regex(
+def search_dir_files_by_regex(
+    pattern: str | re.Pattern[str], directory: str = "."
+) -> list[str]:
     """Search recursively for files in a directory by a regex pattern and return matching
     file paths, relative to the given directory."""
-
-
-
-
-
-
-        match = match_groups[group]
-        vals.append(str(i.relative_to(directory)))
-    return vals
-
-
-class classproperty(object):
-    """
-    Simple class property decorator.
-    """
-
-    def __init__(self, f):
-        self.f = f
+    dir_ = Path(directory)
+    return [
+        str(entry.relative_to(dir_))
+        for entry in dir_.rglob("*")
+        if re.search(pattern, entry.name)
+    ]
 
-    def __get__(self, obj, owner):
-        return self.f(owner)
 
-
-class PrettyPrinter(object):
+class PrettyPrinter:
     """
     A class that produces a nice readable version of itself with ``str()``.
     Intended to be subclassed.
     """
 
-    def __str__(self):
+    def __str__(self) -> str:
         lines = [self.__class__.__name__ + ":"]
         for key, val in vars(self).items():
-            lines
+            lines.extend(f"{key}: {val}".split("\n"))
         return "\n ".join(lines)
 
 
-class Singleton(type):
-    """
-    Metaclass that enforces that only one instance can exist of the classes to which it
-    is applied.
-    """
-
-    _instances = {}
-
-    def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
-        elif args or kwargs:
-            # if existing instance, make the point that new arguments don't do anything!
-            raise ValueError(
-                f"{cls.__name__!r} is a singleton class and cannot be instantiated with new "
-                f"arguments. The positional arguments {args!r} and keyword-arguments "
-                f"{kwargs!r} have been ignored."
-            )
-        return cls._instances[cls]
-
-
-def capitalise_first_letter(chars):
+def capitalise_first_letter(chars: str) -> str:
     """
     Convert the first character of a string to upper case (if that makes sense).
     The rest of the string is unchanged.
@@ -388,30 +359,11 @@ def capitalise_first_letter(chars):
     return chars[0].upper() + chars[1:]
 
 
-
-    """Decorator factory for the various `from_spec` class methods that have attributes
-    that should be replaced by an object from an object list."""
-
-    def decorator(func):
-        @wraps(func)
-        def wrap(*args, **kwargs):
-            spec = args[spec_pos]
-            obj_list = args[obj_list_pos]
-            if spec[spec_name] not in obj_list:
-                cls_name = args[0].__name__
-                raise FromSpecMissingObjectError(
-                    f"A {spec_name!r} object required to instantiate the {cls_name!r} "
-                    f"object is missing."
-                )
-            return func(*args, **kwargs)
-
-        return wrap
-
-    return decorator
+_STRING_VARS_RE = re.compile(r"\<\<var:(.*?)(?:\[(.*)\])?\>\>")
 
 
 @TimeIt.decorator
-def substitute_string_vars(string, variables: Dict[str, str] = None):
+def substitute_string_vars(string: str, variables: dict[str, str]):
     """
     Scan ``string`` and substitute sequences like ``<<var:ABC>>`` with the value
     looked up in the supplied dictionary (with ``ABC`` as the key).
@@ -424,15 +376,14 @@ def substitute_string_vars(string, variables: Dict[str, str] = None):
     >>> substitute_string_vars("abc <<var:def>> ghi", {"def": "123"})
     "abc 123 def"
     """
-    variables = variables or {}
 
-    def var_repl(match_obj):
-        kwargs = {}
-        var_name
+    def var_repl(match_obj: re.Match[str]) -> str:
+        kwargs: dict[str, str] = {}
+        var_name: str = match_obj[1]
+        kwargs_str: str | None = match_obj[2]
         if kwargs_str:
-
-            for i in kwargs_str.split(","):
-                k, v = i.strip().split("=")
+            for i in kwargs_str.split(","):
+                k, v = i.split("=")
                 kwargs[k.strip()] = v.strip()
         try:
             out = str(variables[var_name])
@@ -444,65 +395,69 @@ def substitute_string_vars(string, variables: Dict[str, str] = None):
                 f"variable {var_name!r}."
             )
         else:
-            raise MissingVariableSubstitutionError(
-                f"The variable {var_name!r} referenced in the string does not match "
-                f"any of the provided variables: {list(variables)!r}."
-            )
+            raise MissingVariableSubstitutionError(var_name, variables)
         return out
 
-    new_str = re.sub(
-        pattern=r"\<\<var:(.*?)(?:\[(.*)\])?\>\>",
+    return _STRING_VARS_RE.sub(
         repl=var_repl,
        string=string,
    )
-    return new_str
 
 
 @TimeIt.decorator
-def read_YAML_str(
+def read_YAML_str(
+    yaml_str: str, typ="safe", variables: dict[str, str] | None = None
+) -> Any:
     """Load a YAML string. This will produce basic objects."""
-    if variables is not
+    if variables is not None and "<<var:" in yaml_str:
         yaml_str = substitute_string_vars(yaml_str, variables=variables)
     yaml = YAML(typ=typ)
     return yaml.load(yaml_str)
 
 
 @TimeIt.decorator
-def read_YAML_file(
+def read_YAML_file(
+    path: PathLike, typ="safe", variables: dict[str, str] | None = None
+) -> Any:
     """Load a YAML file. This will produce basic objects."""
     with fsspec.open(path, "rt") as f:
-        yaml_str = f.read()
+        yaml_str: str = f.read()
     return read_YAML_str(yaml_str, typ=typ, variables=variables)
 
 
-def write_YAML_file(obj, path:
+def write_YAML_file(obj, path: str | Path, typ: str = "safe") -> None:
     """Write a basic object to a YAML file."""
     yaml = YAML(typ=typ)
     with Path(path).open("wt") as fp:
         yaml.dump(obj, fp)
 
 
-def read_JSON_string(json_str: str, variables:
+def read_JSON_string(json_str: str, variables: dict[str, str] | None = None) -> Any:
     """Load a JSON string. This will produce basic objects."""
-    if variables is not
+    if variables is not None and "<<var:" in json_str:
         json_str = substitute_string_vars(json_str, variables=variables)
     return json.loads(json_str)
 
 
-def read_JSON_file(path, variables:
+def read_JSON_file(path, variables: dict[str, str] | None = None) -> Any:
     """Load a JSON file. This will produce basic objects."""
     with fsspec.open(path, "rt") as f:
-        json_str = f.read()
+        json_str: str = f.read()
     return read_JSON_string(json_str, variables=variables)
 
 
-def write_JSON_file(obj, path:
+def write_JSON_file(obj, path: str | Path) -> None:
     """Write a basic object to a JSON file."""
     with Path(path).open("wt") as fp:
         json.dump(obj, fp)
 
 
-def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
+def get_item_repeat_index(
+    lst: Sequence[T],
+    *,
+    distinguish_singular: bool = False,
+    item_callable: Callable[[T], Hashable] | None = None,
+):
     """Get the repeat index for each item in a list.
 
     Parameters
@@ -510,10 +465,10 @@ def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
     lst : list
         Must contain hashable items, or hashable objects that are returned via `callable`
         called on each item.
-    distinguish_singular : bool
+    distinguish_singular : bool
         If True, items that are not repeated will have a repeat index of 0, and items that
         are repeated will have repeat indices starting from 1.
-    item_callable : callable
+    item_callable : callable
         If specified, comparisons are made on the output of this callable on each item.
@@ -523,16 +478,16 @@ def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
 
     """
 
-    idx = {}
-
-
-
-
-
-
+    idx: dict[Any, list[int]] = {}
+    if item_callable:
+        for i_idx, item in enumerate(lst):
+            idx.setdefault(item_callable(item), []).append(i_idx)
+    else:
+        for i_idx, item in enumerate(lst):
+            idx.setdefault(item, []).append(i_idx)
 
-    rep_idx = [
-    for
+    rep_idx = [0] * len(lst)
+    for v in idx.values():
         start = len(v) > 1 if distinguish_singular else 0
         for i_idx, i in enumerate(v, start):
             rep_idx[i] = i_idx
@@ -540,7 +495,7 @@ def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
     return rep_idx
 
 
-def get_process_stamp():
+def get_process_stamp() -> str:
     """
     Return a globally unique string identifying this process.
 
@@ -555,15 +510,17 @@ def get_process_stamp():
     )
 
 
-def remove_ansi_escape_sequences(string):
+_ANSI_ESCAPE_RE = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
+
+
+def remove_ansi_escape_sequences(string: str) -> str:
     """
     Strip ANSI terminal escape codes from a string.
     """
-    ansi_escape = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
-    return ansi_escape.sub("", string)
+    return _ANSI_ESCAPE_RE.sub("", string)
 
 
-def get_md5_hash(obj):
+def get_md5_hash(obj) -> str:
     """
     Compute the MD5 hash of an object.
     This is the hash of the JSON of the object (with sorted keys) as a hex string.
@@ -572,7 +529,9 @@ def get_md5_hash(obj):
     return hashlib.md5(json_str.encode("utf-8")).hexdigest()
 
 
-def get_nested_indices(idx, size, nest_levels, raise_on_rollover=False):
+def get_nested_indices(
+    idx: int, size: int, nest_levels: int, raise_on_rollover: bool = False
+) -> list[int]:
     """Generate the set of nested indices of length `n` that correspond to a global
     `idx`.
 
@@ -618,34 +577,34 @@ def get_nested_indices(idx, size, nest_levels, raise_on_rollover=False):
     return [(idx // (size ** (nest_levels - (i + 1)))) % size for i in range(nest_levels)]
 
 
-def ensure_in(item, lst) -> int:
+def ensure_in(item: T, lst: list[T]) -> int:
     """Get the index of an item in a list and append the item if it is not in the
     list."""
     # TODO: add tests
     try:
-        idx = lst.index(item)
+        return lst.index(item)
     except ValueError:
         lst.append(item)
-        idx = len(lst) - 1
-    return idx
+        return len(lst) - 1
 
 
-def list_to_dict(
+def list_to_dict(
+    lst: Sequence[Mapping[T, T2]], exclude: Iterable[T] | None = None
+) -> dict[T, list[T2]]:
     """
     Convert a list of dicts to a dict of lists.
     """
-    #
-
-    dct = {k: [] for k in lst[0]
-    for
-        for k, v in
-            if k not in
+    # TODO: test
+    exc = frozenset(exclude or ())
+    dct: dict[T, list[T2]] = {k: [] for k in lst[0] if k not in exc}
+    for d in lst:
+        for k, v in d.items():
+            if k not in exc:
                 dct[k].append(v)
-
     return dct
 
 
-def bisect_slice(selection: slice, len_A: int):
+def bisect_slice(selection: slice, len_A: int) -> tuple[slice, slice]:
     """Given two sequences (the first of which of known length), get the two slices that
     are equivalent to a given slice if the two sequences were combined."""
@@ -660,32 +619,33 @@ def bisect_slice(selection: slice, len_A: int):
         B_stop = B_start
     else:
         B_stop = selection.stop - len_A
-    B_idx = (B_start, B_stop, selection.step)
-    A_slice = slice(*A_idx)
-    B_slice = slice(*B_idx)
 
-    return
+    return slice(*A_idx), slice(B_start, B_stop, selection.step)
 
 
-def replace_items(lst, start, end, repl):
+def replace_items(lst: list[T], start: int, end: int, repl: list[T]) -> list[T]:
     """Replaced a range of items in a list with items in another list."""
-
+    # Convert to actual indices for our safety checks; handles end-relative addressing
+    real_start, real_end, _ = slice(start, end).indices(len(lst))
+    if real_end <= real_start:
         raise ValueError(
             f"`end` ({end}) must be greater than or equal to `start` ({start})."
         )
-    if
+    if real_start >= len(lst):
         raise ValueError(f"`start` ({start}) must be less than length ({len(lst)}).")
-    if
+    if real_end > len(lst):
         raise ValueError(
             f"`end` ({end}) must be less than or equal to length ({len(lst)})."
         )
 
-
-
-    return
+    lst = list(lst)
+    lst[start:end] = repl
+    return lst
 
 
-def flatten(lst):
+def flatten(
+    lst: list[int] | list[list[int]] | list[list[list[int]]],
+) -> tuple[list[int], tuple[list[int], ...]]:
     """Flatten an arbitrarily (but of uniform depth) nested list and return shape
     information to enable un-flattening.
 
@@ -700,108 +660,156 @@ def flatten(lst):
 
     """
 
-    def _flatten(
-
-
-
-
-
+    def _flatten(
+        lst: list[int] | list[list[int]] | list[list[list[int]]], depth=0
+    ) -> list[int]:
+        out: list[int] = []
+        for item in lst:
+            if isinstance(item, list):
+                out.extend(_flatten(item, depth + 1))
+                all_lens[depth].append(len(item))
             else:
-                out.append(
+                out.append(item)
         return out
 
-    def _get_max_depth(lst):
-
+    def _get_max_depth(lst: list[int] | list[list[int]] | list[list[list[int]]]) -> int:
+        val: Any = lst
         max_depth = 0
-        while isinstance(
+        while isinstance(val, list):
             max_depth += 1
             try:
-                lst = lst[0]
+                val = val[0]
             except IndexError:
                 # empty list, assume this is max depth
                 break
         return max_depth
 
     max_depth = _get_max_depth(lst) - 1
-    all_lens = tuple([] for _ in range(max_depth))
+    all_lens: tuple[list[int], ...] = tuple([] for _ in range(max_depth))
 
     return _flatten(lst), all_lens
 
 
-def reshape(lst, lens):
+def reshape(lst: Sequence[T], lens: Sequence[Sequence[int]]) -> list[TList[T]]:
     """
     Reverse the destructuring of the :py:func:`flatten` function.
     """
 
-    def _reshape(lst, lens):
-        lens_acc = [0
-
-        return lst_rs
+    def _reshape(lst: list[T2], lens: Sequence[int]) -> list[list[T2]]:
+        lens_acc = [0, *accumulate(lens)]
+        return [lst[lens_acc[idx] : lens_acc[idx + 1]] for idx in range(len(lens))]
 
+    result: list[TList[T]] = list(lst)
     for lens_i in lens[::-1]:
-        lst = _reshape(lst, lens_i)
+        result = cast("list[TList[T]]", _reshape(result, lens_i))
 
-    return
+    return result
+
+
+@overload
+def remap(
+    lst: list[int], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+) -> list[T]:
+    ...
+
+
+@overload
+def remap(
+    lst: list[list[int]], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+) -> list[list[T]]:
+    ...
+
+
+@overload
+def remap(
+    lst: list[list[list[int]]], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+) -> list[list[list[T]]]:
+    ...
+
+
+def remap(lst, mapping_func):
+    """
+    Apply a mapping to a structure of lists with ints (typically indices) as leaves to
+    get a structure of lists with some objects as leaves.
+
+    Parameters
+    ----------
+    lst: list[int] | list[list[int]] | list[list[list[int]]]
+        The structure to remap.
+    mapping_func: Callable[[Sequence[int]], Sequence[T]]
+        The mapping function from sequences of ints to sequences of objects.
+
+    Returns
+    -------
+    list[T] | list[list[T]] | list[list[list[T]]]
+        Nested list structure in same form as input, with leaves remapped.
+    """
+    x, y = flatten(lst)
+    return reshape(mapping_func(x), y)
+
+
+_FSSPEC_URL_RE = re.compile(r"(?:[a-z0-9]+:{1,2})+\/\/")
 
 
 def is_fsspec_url(url: str) -> bool:
     """
     Test if a URL appears to be one that can be understood by fsspec.
     """
-    return bool(
+    return bool(_FSSPEC_URL_RE.match(url))
 
 
 class JSONLikeDirSnapShot(DirectorySnapshot):
     """
     Overridden DirectorySnapshot from watchdog to allow saving and loading from JSON.
-    """
 
-
-
+    Parameters
+    ----------
+    root_path: str
+        Where to take the snapshot based at.
+    data: dict[str, list]
+        Serialised snapshot to reload from.
+        See :py:meth:`to_json_like`.
+    """
 
-
-
-
-            Where to take the snapshot based at.
-        data: dict
-            Serialised snapshot to reload from.
-            See :py:meth:`to_json_like`.
+    def __init__(self, root_path: str | None = None, data: dict[str, list] | None = None):
+        """
+        Create an empty snapshot or load from JSON-like data.
         """
 
         #: Where to take the snapshot based at.
         self.root_path = root_path
-        self._stat_info = {}
-        self._inode_to_path = {}
+        self._stat_info: dict[str, os.stat_result] = {}
+        self._inode_to_path: dict[tuple, str] = {}
 
         if data:
-            for name, item in data.items():
+            assert root_path
+            for name, item in data.items():
                 # add root path
-
-                stat_dat, inode_key =
-                self._stat_info[
-                self._inode_to_path[tuple(inode_key)] =
+                full_name = str(PurePath(root_path) / PurePath(name))
+                stat_dat, inode_key = item[:-2], item[-2:]
+                self._stat_info[full_name] = os.stat_result(stat_dat)
+                self._inode_to_path[tuple(inode_key)] = full_name
 
-    def take(self, *args, **kwargs):
+    def take(self, *args, **kwargs) -> None:
         """Take the snapshot."""
         super().__init__(*args, **kwargs)
 
-    def to_json_like(self):
+    def to_json_like(self) -> dict[str, Any]:
         """Export to a dict that is JSON-compatible and can be later reloaded.
 
         The last two integers in `data` for each path are the keys in
         `self._inode_to_path`.
 
         """
-
         # first key is the root path:
-        root_path = next(iter(self._stat_info
+        root_path = next(iter(self._stat_info))
 
         # store efficiently:
         inode_invert = {v: k for k, v in self._inode_to_path.items()}
-        data = {
-
-
-
+        data: dict[str, list] = {
+            str(PurePath(k).relative_to(root_path)): [*v, *inode_invert[k]]
+            for k, v in self._stat_info.items()
+        }
 
         return {
             "root_path": root_path,
@@ -809,7 +817,7 @@ class JSONLikeDirSnapShot(DirectorySnapshot):
         }
 
 
-def open_file(filename):
+def open_file(filename: str | Path):
     """Open a file or directory using the default system application."""
     if sys.platform == "win32":
         os.startfile(filename)
@@ -818,59 +826,76 @@ def open_file(filename):
         subprocess.call([opener, filename])
 
 
-
+@overload
+def get_enum_by_name_or_val(enum_cls: type[E], key: None) -> None:
+    ...
+
+
+@overload
+def get_enum_by_name_or_val(enum_cls: type[E], key: str | int | float | E) -> E:
+    ...
+
+
+def get_enum_by_name_or_val(
+    enum_cls: type[E], key: str | int | float | E | None
+) -> E | None:
     """Retrieve an enum by name or value, assuming uppercase names and integer values."""
-    err = f"Unknown enum key or value {key!r} for class {enum_cls!r}"
     if key is None or isinstance(key, enum_cls):
         return key
     elif isinstance(key, (int, float)):
         return enum_cls(int(key))  # retrieve by value
     elif isinstance(key, str):
         try:
-            return getattr(enum_cls, key.upper())  # retrieve by name
+            return cast("E", getattr(enum_cls, key.upper()))  # retrieve by name
         except AttributeError:
-
-
-
+            pass
+    raise ValueError(f"Unknown enum key or value {key!r} for class {enum_cls!r}")
+
 
+_PARAM_SPLIT_RE = re.compile(r"((?:\w|\.)+)(?:\[(\w+)\])?")
 
-
+
+def split_param_label(param_path: str) -> tuple[str, str] | tuple[None, None]:
     """Split a parameter path into the path and the label, if present."""
-
-
-
+    if match := _PARAM_SPLIT_RE.match(param_path):
+        return match[1], match[2]
+    else:
+        return None, None
 
 
-def process_string_nodes(data, str_processor):
+def process_string_nodes(data: T, str_processor: Callable[[str], str]) -> T:
     """Walk through a nested data structure and process string nodes using a provided
     callable."""
 
     if isinstance(data, dict):
-
-
+        return cast(
+            "T", {k: process_string_nodes(v, str_processor) for k, v in data.items()}
+        )
 
-    elif isinstance(data, (list, tuple, set)):
-        _data =
+    elif isinstance(data, (list, tuple, set, frozenset)):
+        _data = (process_string_nodes(i, str_processor) for i in data)
         if isinstance(data, tuple):
-
+            return cast("T", tuple(_data))
         elif isinstance(data, set):
-
+            return cast("T", set(_data))
+        elif isinstance(data, frozenset):
+            return cast("T", frozenset(_data))
         else:
-
+            return cast("T", list(_data))
 
     elif isinstance(data, str):
-
+        return cast("T", str_processor(data))
 
     return data
 
 
 def linspace_rect(
-    start:
-    stop:
-    num:
-    include:
+    start: Sequence[float],
+    stop: Sequence[float],
+    num: Sequence[int],
+    include: Sequence[str] | None = None,
     **kwargs,
-):
+) -> NDArray:
     """Generate a linear space around a rectangle.
 
     Parameters
@@ -892,19 +917,18 @@ def linspace_rect(
 
     """
 
-    if num[0]
+    if num[0] <= 1 or num[1] <= 1:
         raise ValueError("Both values in `num` must be greater than 1.")
 
-    if
-        include = ["top", "right", "bottom", "left"]
+    inc = set(include) if include else {"top", "right", "bottom", "left"}
 
     c0_range = np.linspace(start=start[0], stop=stop[0], num=num[0], **kwargs)
     c1_range_all = np.linspace(start=start[1], stop=stop[1], num=num[1], **kwargs)
 
     c1_range = c1_range_all
-    if "bottom" in
+    if "bottom" in inc:
         c1_range = c1_range[1:]
-    if "top" in
+    if "top" in inc:
         c1_range = c1_range[:-1]
 
     c0_range_c1_start = np.vstack([c0_range, np.repeat(start[1], num[0])])
@@ -914,20 +938,21 @@ def linspace_rect(
     c1_range_c0_stop = np.vstack([np.repeat(c0_range[-1], len(c1_range)), c1_range])
 
     stacked = []
-    if "top" in
+    if "top" in inc:
         stacked.append(c0_range_c1_stop)
-    if "right" in
+    if "right" in inc:
         stacked.append(c1_range_c0_stop)
-    if "bottom" in
+    if "bottom" in inc:
         stacked.append(c0_range_c1_start)
-    if "left" in
+    if "left" in inc:
         stacked.append(c1_range_c0_start)
 
-    rect = np.hstack(stacked)
-    return rect
+    return np.hstack(stacked)
 
 
-def dict_values_process_flat(
+def dict_values_process_flat(
+    d: Mapping[T, T2 | list[T2]], callable: Callable[[list[T2]], list[T3]]
+) -> Mapping[T, T3 | list[T3]]:
     """
     Return a copy of a dict, where the values are processed by a callable that is to
     be called only once, and where the values may be single items or lists of items.
@@ -939,32 +964,34 @@ def dict_values_process_flat(d, callable):
     {'a': 1, 'b': [2, 3], 'c': 6}
 
     """
-    flat = []  # values of `d`, flattened
-    is_multi
+    flat: list[T2] = []  # values of `d`, flattened
+    is_multi: list[
+        tuple[bool, int]
+    ] = []  # whether a list, and the number of items to process
     for i in d.values():
-        if isinstance(i, list):
-            flat.extend(i)
+        if isinstance(i, list):
+            flat.extend(cast("list[T2]", i))
             is_multi.append((True, len(i)))
-        else:
-            flat.append(i)
+        else:
+            flat.append(cast("T2", i))
             is_multi.append((False, 1))
 
     processed = callable(flat)
 
-    out = {}
-    for idx_i, (m, k) in enumerate(zip(is_multi, d
-
+    out: dict[T, T3 | list[T3]] = {}
+    for idx_i, (m, k) in enumerate(zip(is_multi, d)):
         start_idx = sum(i[1] for i in is_multi[:idx_i])
         end_idx = start_idx + m[1]
-        proc_idx_k = processed[
+        proc_idx_k = processed[start_idx:end_idx]
         if not m[0]:
-
-
+            out[k] = proc_idx_k[0]
+        else:
+            out[k] = proc_idx_k
 
     return out
 
 
-def nth_key(dct, n):
+def nth_key(dct: Iterable[T], n: int) -> T:
     """
     Given a dict in some order, get the n'th key of that dict.
     """
@@ -973,8 +1000,86 @@ def nth_key(dct, n):
     return next(it)
 
 
-def nth_value(dct, n):
+def nth_value(dct: dict[Any, T], n: int) -> T:
     """
     Given a dict in some order, get the n'th value of that dict.
     """
     return dct[nth_key(dct, n)]
+
+
+def normalise_timestamp(timestamp: datetime) -> datetime:
+    """
+    Force a timestamp to have UTC as its timezone,
+    then convert to use the local timezone.
+    """
+    return timestamp.replace(tzinfo=timezone.utc).astimezone()
+
+
+def parse_timestamp(timestamp: str | datetime, ts_fmt: str) -> datetime:
+    """
+    Standard timestamp parsing.
+    Ensures that timestamps are internally all UTC.
+    """
+    return normalise_timestamp(
+        timestamp
+        if isinstance(timestamp, datetime)
+        else datetime.strptime(timestamp, ts_fmt)
+    )
+
+
+def current_timestamp() -> datetime:
+    """
+    Get a UTC timestamp for the current time
+    """
+    return datetime.now(timezone.utc)
+
+
+def timedelta_format(td: timedelta) -> str:
+    """
+    Convert time delta to string in standard form.
+    """
+    days, seconds = td.days, td.seconds
+    hours = seconds // (60 * 60)
+    seconds -= hours * (60 * 60)
+    minutes = seconds // 60
+    seconds -= minutes * 60
+    return f"{days}-{hours:02}:{minutes:02}:{seconds:02}"
+
+
+_TD_RE = re.compile(r"(\d+)-(\d+):(\d+):(\d+)")
+
+
+def timedelta_parse(td_str: str) -> timedelta:
+    """
+    Parse a string in standard form as a time delta.
+    """
+    if not (m := _TD_RE.fullmatch(td_str)):
+        raise ValueError("not a supported timedelta form")
+    days, hours, mins, secs = map(int, m.groups())
+    return timedelta(days=days, hours=hours, minutes=mins, seconds=secs)
+
+
+def open_text_resource(package: ModuleType | str, resource: str) -> IO[str]:
+    """
+    Open a file in a package.
+    """
+    try:
+        return resources.files(package).joinpath(resource).open("r")
+    except AttributeError:
+        # < python 3.9; `resource.open_text` deprecated since 3.11
+        return resources.open_text(package, resource)
+
+
+def get_file_context(
+    package: ModuleType | str, src: str | None = None
+) -> AbstractContextManager[Path]:
+    """
+    Find a file or directory in a package.
+    """
+    try:
+        files = resources.files(package)
+        return resources.as_file(files.joinpath(src) if src else files)
+        # raises ModuleNotFoundError
+    except AttributeError:
+        # < python 3.9
+        return resources.path(package, src or "")