hpcflow-new2 0.2.0a189__py3-none-any.whl → 0.2.0a199__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hpcflow/__pyinstaller/hook-hpcflow.py +9 -6
- hpcflow/_version.py +1 -1
- hpcflow/app.py +1 -0
- hpcflow/data/scripts/bad_script.py +2 -0
- hpcflow/data/scripts/do_nothing.py +2 -0
- hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
- hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
- hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
- hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
- hpcflow/data/scripts/input_file_generator_basic.py +3 -0
- hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
- hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
- hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
- hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
- hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
- hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
- hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
- hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
- hpcflow/data/scripts/output_file_parser_basic.py +3 -0
- hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
- hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
- hpcflow/data/scripts/script_exit_test.py +5 -0
- hpcflow/data/template_components/environments.yaml +1 -1
- hpcflow/sdk/__init__.py +26 -15
- hpcflow/sdk/app.py +2192 -768
- hpcflow/sdk/cli.py +506 -296
- hpcflow/sdk/cli_common.py +105 -7
- hpcflow/sdk/config/__init__.py +1 -1
- hpcflow/sdk/config/callbacks.py +115 -43
- hpcflow/sdk/config/cli.py +126 -103
- hpcflow/sdk/config/config.py +674 -318
- hpcflow/sdk/config/config_file.py +131 -95
- hpcflow/sdk/config/errors.py +125 -84
- hpcflow/sdk/config/types.py +148 -0
- hpcflow/sdk/core/__init__.py +25 -1
- hpcflow/sdk/core/actions.py +1771 -1059
- hpcflow/sdk/core/app_aware.py +24 -0
- hpcflow/sdk/core/cache.py +139 -79
- hpcflow/sdk/core/command_files.py +263 -287
- hpcflow/sdk/core/commands.py +145 -112
- hpcflow/sdk/core/element.py +828 -535
- hpcflow/sdk/core/enums.py +192 -0
- hpcflow/sdk/core/environment.py +74 -93
- hpcflow/sdk/core/errors.py +455 -52
- hpcflow/sdk/core/execute.py +207 -0
- hpcflow/sdk/core/json_like.py +540 -272
- hpcflow/sdk/core/loop.py +751 -347
- hpcflow/sdk/core/loop_cache.py +164 -47
- hpcflow/sdk/core/object_list.py +370 -207
- hpcflow/sdk/core/parameters.py +1100 -627
- hpcflow/sdk/core/rule.py +59 -41
- hpcflow/sdk/core/run_dir_files.py +21 -37
- hpcflow/sdk/core/skip_reason.py +7 -0
- hpcflow/sdk/core/task.py +1649 -1339
- hpcflow/sdk/core/task_schema.py +308 -196
- hpcflow/sdk/core/test_utils.py +191 -114
- hpcflow/sdk/core/types.py +440 -0
- hpcflow/sdk/core/utils.py +485 -309
- hpcflow/sdk/core/validation.py +82 -9
- hpcflow/sdk/core/workflow.py +2544 -1178
- hpcflow/sdk/core/zarr_io.py +98 -137
- hpcflow/sdk/data/workflow_spec_schema.yaml +2 -0
- hpcflow/sdk/demo/cli.py +53 -33
- hpcflow/sdk/helper/cli.py +18 -15
- hpcflow/sdk/helper/helper.py +75 -63
- hpcflow/sdk/helper/watcher.py +61 -28
- hpcflow/sdk/log.py +122 -71
- hpcflow/sdk/persistence/__init__.py +8 -31
- hpcflow/sdk/persistence/base.py +1360 -606
- hpcflow/sdk/persistence/defaults.py +6 -0
- hpcflow/sdk/persistence/discovery.py +38 -0
- hpcflow/sdk/persistence/json.py +568 -188
- hpcflow/sdk/persistence/pending.py +382 -179
- hpcflow/sdk/persistence/store_resource.py +39 -23
- hpcflow/sdk/persistence/types.py +318 -0
- hpcflow/sdk/persistence/utils.py +14 -11
- hpcflow/sdk/persistence/zarr.py +1337 -433
- hpcflow/sdk/runtime.py +44 -41
- hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
- hpcflow/sdk/submission/jobscript.py +1651 -692
- hpcflow/sdk/submission/schedulers/__init__.py +167 -39
- hpcflow/sdk/submission/schedulers/direct.py +121 -81
- hpcflow/sdk/submission/schedulers/sge.py +170 -129
- hpcflow/sdk/submission/schedulers/slurm.py +291 -268
- hpcflow/sdk/submission/schedulers/utils.py +12 -2
- hpcflow/sdk/submission/shells/__init__.py +14 -15
- hpcflow/sdk/submission/shells/base.py +150 -29
- hpcflow/sdk/submission/shells/bash.py +283 -173
- hpcflow/sdk/submission/shells/os_version.py +31 -30
- hpcflow/sdk/submission/shells/powershell.py +228 -170
- hpcflow/sdk/submission/submission.py +1014 -335
- hpcflow/sdk/submission/types.py +140 -0
- hpcflow/sdk/typing.py +182 -12
- hpcflow/sdk/utils/arrays.py +71 -0
- hpcflow/sdk/utils/deferred_file.py +55 -0
- hpcflow/sdk/utils/hashing.py +16 -0
- hpcflow/sdk/utils/patches.py +12 -0
- hpcflow/sdk/utils/strings.py +33 -0
- hpcflow/tests/api/test_api.py +32 -0
- hpcflow/tests/conftest.py +27 -6
- hpcflow/tests/data/multi_path_sequences.yaml +29 -0
- hpcflow/tests/data/workflow_test_run_abort.yaml +34 -35
- hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
- hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
- hpcflow/tests/scripts/test_input_file_generators.py +282 -0
- hpcflow/tests/scripts/test_main_scripts.py +866 -85
- hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
- hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
- hpcflow/tests/shells/wsl/test_wsl_submission.py +12 -4
- hpcflow/tests/unit/test_action.py +262 -75
- hpcflow/tests/unit/test_action_rule.py +9 -4
- hpcflow/tests/unit/test_app.py +33 -6
- hpcflow/tests/unit/test_cache.py +46 -0
- hpcflow/tests/unit/test_cli.py +134 -1
- hpcflow/tests/unit/test_command.py +71 -54
- hpcflow/tests/unit/test_config.py +142 -16
- hpcflow/tests/unit/test_config_file.py +21 -18
- hpcflow/tests/unit/test_element.py +58 -62
- hpcflow/tests/unit/test_element_iteration.py +50 -1
- hpcflow/tests/unit/test_element_set.py +29 -19
- hpcflow/tests/unit/test_group.py +4 -2
- hpcflow/tests/unit/test_input_source.py +116 -93
- hpcflow/tests/unit/test_input_value.py +29 -24
- hpcflow/tests/unit/test_jobscript_unit.py +757 -0
- hpcflow/tests/unit/test_json_like.py +44 -35
- hpcflow/tests/unit/test_loop.py +1396 -84
- hpcflow/tests/unit/test_meta_task.py +325 -0
- hpcflow/tests/unit/test_multi_path_sequences.py +229 -0
- hpcflow/tests/unit/test_object_list.py +17 -12
- hpcflow/tests/unit/test_parameter.py +29 -7
- hpcflow/tests/unit/test_persistence.py +237 -42
- hpcflow/tests/unit/test_resources.py +20 -18
- hpcflow/tests/unit/test_run.py +117 -6
- hpcflow/tests/unit/test_run_directories.py +29 -0
- hpcflow/tests/unit/test_runtime.py +2 -1
- hpcflow/tests/unit/test_schema_input.py +23 -15
- hpcflow/tests/unit/test_shell.py +23 -2
- hpcflow/tests/unit/test_slurm.py +8 -7
- hpcflow/tests/unit/test_submission.py +38 -89
- hpcflow/tests/unit/test_task.py +352 -247
- hpcflow/tests/unit/test_task_schema.py +33 -20
- hpcflow/tests/unit/test_utils.py +9 -11
- hpcflow/tests/unit/test_value_sequence.py +15 -12
- hpcflow/tests/unit/test_workflow.py +114 -83
- hpcflow/tests/unit/test_workflow_template.py +0 -1
- hpcflow/tests/unit/utils/test_arrays.py +40 -0
- hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
- hpcflow/tests/unit/utils/test_hashing.py +65 -0
- hpcflow/tests/unit/utils/test_patches.py +5 -0
- hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
- hpcflow/tests/workflows/__init__.py +0 -0
- hpcflow/tests/workflows/test_directory_structure.py +31 -0
- hpcflow/tests/workflows/test_jobscript.py +334 -1
- hpcflow/tests/workflows/test_run_status.py +198 -0
- hpcflow/tests/workflows/test_skip_downstream.py +696 -0
- hpcflow/tests/workflows/test_submission.py +140 -0
- hpcflow/tests/workflows/test_workflows.py +160 -15
- hpcflow/tests/workflows/test_zip.py +18 -0
- hpcflow/viz_demo.ipynb +6587 -3
- {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/METADATA +8 -4
- hpcflow_new2-0.2.0a199.dist-info/RECORD +221 -0
- hpcflow/sdk/core/parallel.py +0 -21
- hpcflow_new2-0.2.0a189.dist-info/RECORD +0 -158
- {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/LICENSE +0 -0
- {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/WHEEL +0 -0
- {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/entry_points.txt +0 -0
hpcflow/sdk/core/utils.py
CHANGED
@@ -2,12 +2,17 @@
 Miscellaneous utilities.
 """

+from __future__ import annotations
+from collections import Counter
+from asyncio import events
+import contextvars
+import contextlib
 import copy
 import enum
-from functools import wraps
-import contextlib
+import functools
 import hashlib
 from itertools import accumulate, islice
+from importlib import resources
 import json
 import keyword
 import os
@@ -17,10 +22,11 @@ import re
 import socket
 import string
 import subprocess
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 import sys
-
-import fsspec
+import traceback
+from typing import Literal, cast, overload, TypeVar, TYPE_CHECKING
+import fsspec  # type: ignore
 import numpy as np

 from ruamel.yaml import YAML
@@ -28,27 +34,30 @@ from watchdog.utils.dirsnapshot import DirectorySnapshot

 from hpcflow.sdk.core.errors import (
     ContainerKeyError,
-    FromSpecMissingObjectError,
     InvalidIdentifier,
     MissingVariableSubstitutionError,
 )
 from hpcflow.sdk.log import TimeIt
-from hpcflow.sdk.typing import PathLike
-
-
-
-
-
-
-
-
-
-
-
-
-
-def make_workflow_id():
+from hpcflow.sdk.utils.deferred_file import DeferredFileWriter
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
+    from contextlib import AbstractContextManager
+    from types import ModuleType
+    from typing import Any, IO, Iterator
+    from typing_extensions import TypeAlias
+    from numpy.typing import NDArray
+    from ..typing import PathLike
+
+T = TypeVar("T")
+T2 = TypeVar("T2")
+T3 = TypeVar("T3")
+TList: TypeAlias = "T | list[TList]"
+TD = TypeVar("TD", bound="Mapping[str, Any]")
+E = TypeVar("E", bound=enum.Enum)
+
+
+def make_workflow_id() -> str:
     """
     Generate a random ID for a workflow.
     """
@@ -57,14 +66,14 @@ def make_workflow_id():
     return "".join(random.choices(chars, k=length))


-def get_time_stamp():
+def get_time_stamp() -> str:
     """
     Get the current time in standard string form.
     """
     return datetime.now(timezone.utc).astimezone().strftime("%Y.%m.%d_%H:%M:%S_%z")


-def get_duplicate_items(lst):
+def get_duplicate_items(lst: Iterable[T]) -> list[T]:
     """Get a list of all items in an iterable that appear more than once, assuming items
     are hashable.

@@ -77,11 +86,10 @@ def get_duplicate_items(lst):
     []

     >>> get_duplicate_items([1, 2, 3, 3, 3, 2])
-    [2, 3
+    [2, 3]

     """
-    seen = []
-    return list(set(x for x in lst if x in seen or seen.append(x)))
+    return [x for x, y in Counter(lst).items() if y > 1]


 def check_valid_py_identifier(name: str) -> str:
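The switch to `collections.Counter` also pins down the result order: duplicates are now reported in first-appearance order (guaranteed by insertion-ordered dicts), which the old `set`-based one-liner did not. A usage sketch (not part of the diff; assumes the 0.2.0a199 wheel is installed):

    from hpcflow.sdk.core.utils import get_duplicate_items

    # duplicates reported in first-appearance order
    assert get_duplicate_items([1, 2, 3, 3, 3, 2]) == [2, 3]
    assert get_duplicate_items("abcb") == ["b"]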
@@ -105,29 +113,42 @@ def check_valid_py_identifier(name: str) -> str:
     - `Loop.name`

     """
-    exc = InvalidIdentifier(f"Invalid string for identifier: {name!r}")
     try:
         trial_name = name[1:].replace("_", "")  # "internal" underscores are allowed
     except TypeError:
-        raise exc
+        raise InvalidIdentifier(name) from None
+    except KeyError as e:
+        raise KeyError(f"unexpected name type {name}") from e
     if (
         not name
         or not (name[0].isalpha() and ((trial_name[1:] or "a").isalnum()))
         or keyword.iskeyword(name)
     ):
-        raise exc
+        raise InvalidIdentifier(name)

     return name


-def group_by_dict_key_values(lst, *keys):
+@overload
+def group_by_dict_key_values(  # type: ignore[overload-overlap]
+    lst: list[dict[T, T2]], key: T
+) -> list[list[dict[T, T2]]]:
+    ...
+
+
+@overload
+def group_by_dict_key_values(lst: list[TD], key: str) -> list[list[TD]]:
+    ...
+
+
+def group_by_dict_key_values(lst: list, key):
     """Group a list of dicts according to specified equivalent key-values.

     Parameters
     ----------
     lst : list of dict
         The list of dicts to group together.
-
+    key : key value
         Dicts that have identical values for all of these keys will be grouped together
         into a sub-list.

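`InvalidIdentifier` is now raised with just the offending name; the message is built inside the reworked exception class (see `hpcflow/sdk/core/errors.py`, +455/-52 above). Behaviour sketch, assuming the new wheel is installed:

    from hpcflow.sdk.core.utils import check_valid_py_identifier
    from hpcflow.sdk.core.errors import InvalidIdentifier

    assert check_valid_py_identifier("task_1") == "task_1"
    try:
        check_valid_py_identifier("1task")  # identifiers must start with a letter
    except InvalidIdentifier:
        pass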
@@ -146,10 +167,10 @@ def group_by_dict_key_values(lst, *keys):
     for lst_item in lst[1:]:
         for group_idx, group in enumerate(grouped):
             try:
-                is_vals_equal =
+                is_vals_equal = lst_item[key] == group[0][key]

             except KeyError:
-                # dicts that do not have
+                # dicts that do not have the `key` will be in their own group:
                 is_vals_equal = False

             if is_vals_equal:
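Note the breaking signature change from `*keys` to a single `key`: callers can now group on one key only. Sketch of the new behaviour:

    from hpcflow.sdk.core.utils import group_by_dict_key_values

    items = [{"a": 1, "b": 10}, {"a": 1, "b": 20}, {"a": 2}]
    # dicts with a differing (or missing) value for `key` form their own group:
    assert group_by_dict_key_values(items, "a") == [
        [{"a": 1, "b": 10}, {"a": 1, "b": 20}],
        [{"a": 2}],
    ]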
@@ -162,7 +183,7 @@ def group_by_dict_key_values(lst, *keys):
     return grouped


-def swap_nested_dict_keys(dct, inner_key):
+def swap_nested_dict_keys(dct: dict[T, dict[T2, T3]], inner_key: T2):
     """Return a copy where top-level keys have been swapped with a second-level inner key.

     Examples:
@@ -181,16 +202,35 @@ def swap_nested_dict_keys(dct, inner_key):
     }

     """
-    out = {}
+    out: dict[T3, dict[T, dict[T2, T3]]] = {}
     for k, v in copy.deepcopy(dct or {}).items():
-        inner_val = v.pop(inner_key)
-        if inner_val not in out:
-            out[inner_val] = {}
-        out[inner_val][k] = v
+        out.setdefault(v.pop(inner_key), {})[k] = v
     return out


-def get_in_container(cont, path, cast_indices=False, allow_getattr=False):
+def _ensure_int(path_comp: Any, cur_data: Any, cast_indices: bool) -> int:
+    """
+    Helper for get_in_container() and set_in_container()
+    """
+    if isinstance(path_comp, int):
+        return path_comp
+    if not cast_indices:
+        raise TypeError(
+            f"Path component {path_comp!r} must be an integer index "
+            f"since data is a sequence: {cur_data!r}."
+        )
+    try:
+        return int(path_comp)
+    except (TypeError, ValueError) as e:
+        raise TypeError(
+            f"Path component {path_comp!r} must be an integer index "
+            f"since data is a sequence: {cur_data!r}."
+        ) from e
+
+
+def get_in_container(
+    cont, path: Sequence, cast_indices: bool = False, allow_getattr: bool = False
+):
     """
     Follow a path (sequence of indices of appropriate type) into a container to obtain
     a "leaf" value. Containers can be lists, tuples, dicts,
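The four-line inner-key swap collapses to a single `setdefault`, and the index-coercion logic previously duplicated in `get_in_container`/`set_in_container` is hoisted into the new `_ensure_int` helper. `swap_nested_dict_keys` sketch:

    from hpcflow.sdk.core.utils import swap_nested_dict_keys

    params = {
        "p1": {"format": "direct", "all_iterations": True},
        "p2": {"format": "json"},
    }
    assert swap_nested_dict_keys(params, inner_key="format") == {
        "direct": {"p1": {"all_iterations": True}},
        "json": {"p2": {}},
    }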
@@ -203,24 +243,12 @@ def get_in_container(cont, path, cast_indices=False, allow_getattr=False):
     )
     for idx, path_comp in enumerate(path):
         if isinstance(cur_data, (list, tuple)):
-
-            msg = (
-                f"Path component {path_comp!r} must be an integer index "
-                f"since data is a sequence: {cur_data!r}."
-            )
-            if cast_indices:
-                try:
-                    path_comp = int(path_comp)
-                except TypeError:
-                    raise TypeError(msg)
-            else:
-                raise TypeError(msg)
-            cur_data = cur_data[path_comp]
-        elif isinstance(cur_data, dict):
+            cur_data = cur_data[_ensure_int(path_comp, cur_data, cast_indices)]
+        elif isinstance(cur_data, dict) or hasattr(cur_data, "__getitem__"):
             try:
                 cur_data = cur_data[path_comp]
             except KeyError:
-                raise ContainerKeyError(path=path[: idx + 1])
+                raise ContainerKeyError(path=cast("list[str]", path[: idx + 1]))
         elif allow_getattr:
             try:
                 cur_data = getattr(cur_data, path_comp)
@@ -235,7 +263,9 @@ def get_in_container(cont, path, cast_indices=False, allow_getattr=False):
     return cur_data


-def set_in_container(cont, path, value, ensure_path=False, cast_indices=False):
+def set_in_container(
+    cont, path: Sequence, value, ensure_path=False, cast_indices=False
+) -> None:
     """
     Follow a path (sequence of indices of appropriate type) into a container to update
     a "leaf" value. Containers can be lists, tuples or dicts.
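Besides the `_ensure_int` refactor, `get_in_container` now falls through to any object exposing `__getitem__`, not just `dict`. Usage sketch:

    from hpcflow.sdk.core.utils import get_in_container, set_in_container

    data = {"a": [{"b": 1}, {"b": 2}]}
    # string indices into sequences are cast when cast_indices=True:
    assert get_in_container(data, ["a", "1", "b"], cast_indices=True) == 2
    set_in_container(data, ["a", 0, "b"], 99)
    assert data["a"][0]["b"] == 99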
@@ -258,22 +288,11 @@ def set_in_container(cont, path, value, ensure_path=False, cast_indices=False):
     sub_data = get_in_container(cont, path[:-1], cast_indices=cast_indices)
     path_comp = path[-1]
     if isinstance(sub_data, (list, tuple)):
-
-        msg = (
-            f"Path component {path_comp!r} must be an integer index "
-            f"since data is a sequence: {sub_data!r}."
-        )
-        if cast_indices:
-            try:
-                path_comp = int(path_comp)
-            except ValueError:
-                raise ValueError(msg)
-        else:
-            raise ValueError(msg)
+        path_comp = _ensure_int(path_comp, sub_data, cast_indices)
     sub_data[path_comp] = value


-def get_relative_path(path1, path2):
+def get_relative_path(path1: Sequence[T], path2: Sequence[T]) -> Sequence[T]:
     """Get relative path components between two paths.

     Parameters
@@ -308,79 +327,39 @@ def get_relative_path(path1, path2):
     """

     len_path2 = len(path2)
-    msg = f"{path1!r} is not in the subpath of {path2!r}."
-
-    if len(path1) < len_path2:
-        raise ValueError(msg)
-
-    for i, j in zip(path1[:len_path2], path2):
-        if i != j:
-            raise ValueError(msg)
+    if len(path1) < len_path2 or any(i != j for i, j in zip(path1[:len_path2], path2)):
+        raise ValueError(f"{path1!r} is not in the subpath of {path2!r}.")

     return path1[len_path2:]


-def search_dir_files_by_regex(
+def search_dir_files_by_regex(
+    pattern: str | re.Pattern[str], directory: str | os.PathLike = "."
+) -> list[str]:
     """Search recursively for files in a directory by a regex pattern and return matching
     file paths, relative to the given directory."""
-
-
-
-
-
-
-        match = match_groups[group]
-        vals.append(str(i.relative_to(directory)))
-    return vals
+    dir_ = Path(directory)
+    return [
+        str(entry.relative_to(dir_))
+        for entry in dir_.rglob("*")
+        if re.search(pattern, entry.name)
+    ]


-class
-    """
-    Simple class property decorator.
-    """
-
-    def __init__(self, f):
-        self.f = f
-
-    def __get__(self, obj, owner):
-        return self.f(owner)
-
-
-class PrettyPrinter(object):
+class PrettyPrinter:
     """
     A class that produces a nice readable version of itself with ``str()``.
     Intended to be subclassed.
     """

-    def __str__(self):
+    def __str__(self) -> str:
         lines = [self.__class__.__name__ + ":"]
         for key, val in vars(self).items():
-            lines
+            lines.extend(f"{key}: {val}".split("\n"))
         return "\n ".join(lines)


-class Singleton(type):
-    """
-    Metaclass that enforces that only one instance can exist of the classes to which it
-    is applied.
-    """
-
-    _instances = {}
-
-    def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
-        elif args or kwargs:
-            # if existing instance, make the point that new arguments don't do anything!
-            raise ValueError(
-                f"{cls.__name__!r} is a singleton class and cannot be instantiated with new "
-                f"arguments. The positional arguments {args!r} and keyword-arguments "
-                f"{kwargs!r} have been ignored."
-            )
-        return cls._instances[cls]
-
-
-def capitalise_first_letter(chars):
+def capitalise_first_letter(chars: str) -> str:
     """
     Convert the first character of a string to upper case (if that makes sense).
     The rest of the string is unchanged.
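Two further API notes here: `search_dir_files_by_regex` loses its old `group` parameter and now matches the pattern against each entry's file name (recursively, via `Path.rglob`), and the unused `classproperty`-style and `Singleton` helpers are dropped outright. Sketch with hypothetical file names:

    from hpcflow.sdk.core.utils import search_dir_files_by_regex

    # given ./logs/run_1.out and ./logs/sub/run_2.out (hypothetical tree):
    matches = search_dir_files_by_regex(r"run_\d+\.out", directory="logs")
    # -> ['run_1.out', 'sub/run_2.out'] (OS-specific separators; rglob order)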
@@ -388,30 +367,11 @@ def capitalise_first_letter(chars):
     return chars[0].upper() + chars[1:]


-
-    """Decorator factory for the various `from_spec` class methods that have attributes
-    that should be replaced by an object from an object list."""
-
-    def decorator(func):
-        @wraps(func)
-        def wrap(*args, **kwargs):
-            spec = args[spec_pos]
-            obj_list = args[obj_list_pos]
-            if spec[spec_name] not in obj_list:
-                cls_name = args[0].__name__
-                raise FromSpecMissingObjectError(
-                    f"A {spec_name!r} object required to instantiate the {cls_name!r} "
-                    f"object is missing."
-                )
-            return func(*args, **kwargs)
-
-        return wrap
-
-    return decorator
+_STRING_VARS_RE = re.compile(r"\<\<var:(.*?)(?:\[(.*)\])?\>\>")


 @TimeIt.decorator
-def substitute_string_vars(string, variables: Dict[str, str] = None):
+def substitute_string_vars(string: str, variables: dict[str, str]):
     """
     Scan ``string`` and substitute sequences like ``<<var:ABC>>`` with the value
     looked up in the supplied dictionary (with ``ABC`` as the key).
@@ -424,15 +384,14 @@ def substitute_string_vars(string, variables: Dict[str, str] = None):
     >>> substitute_string_vars("abc <<var:def>> ghi", {"def": "123"})
     "abc 123 def"
     """
-    variables = variables or {}

-    def var_repl(match_obj):
-        kwargs = {}
-        var_name
+    def var_repl(match_obj: re.Match[str]) -> str:
+        kwargs: dict[str, str] = {}
+        var_name: str = match_obj[1]
+        kwargs_str: str | None = match_obj[2]
         if kwargs_str:
-
-
-                k, v = i.strip().split("=")
+            for i in kwargs_str.split(","):
+                k, v = i.split("=")
                 kwargs[k.strip()] = v.strip()
         try:
             out = str(variables[var_name])
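`variables` is now a required mapping (the old `= None` default is gone), and the pattern is pre-compiled as `_STRING_VARS_RE`; the optional bracket section (`<<var:NAME[k=v, ...]>>`) is parsed into kwargs by `var_repl`. Sketch (note the docstring's `"abc 123 def"` output appears to be a typo for `"abc 123 ghi"`):

    from hpcflow.sdk.core.utils import substitute_string_vars

    out = substitute_string_vars("abc <<var:def>> ghi", {"def": "123"})
    assert out == "abc 123 ghi"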
@@ -444,65 +403,69 @@ def substitute_string_vars(string, variables: Dict[str, str] = None):
                 f"variable {var_name!r}."
             )
         else:
-            raise MissingVariableSubstitutionError(
-                f"The variable {var_name!r} referenced in the string does not match "
-                f"any of the provided variables: {list(variables)!r}."
-            )
+            raise MissingVariableSubstitutionError(var_name, variables)
         return out

-    new_str = re.sub(
-        pattern=r"\<\<var:(.*?)(?:\[(.*)\])?\>\>",
+    return _STRING_VARS_RE.sub(
         repl=var_repl,
         string=string,
     )
-    return new_str


 @TimeIt.decorator
-def read_YAML_str(
+def read_YAML_str(
+    yaml_str: str, typ="safe", variables: dict[str, str] | None = None
+) -> Any:
     """Load a YAML string. This will produce basic objects."""
-    if variables is not
+    if variables is not None and "<<var:" in yaml_str:
         yaml_str = substitute_string_vars(yaml_str, variables=variables)
     yaml = YAML(typ=typ)
     return yaml.load(yaml_str)


 @TimeIt.decorator
-def read_YAML_file(
+def read_YAML_file(
+    path: PathLike, typ="safe", variables: dict[str, str] | None = None
+) -> Any:
     """Load a YAML file. This will produce basic objects."""
     with fsspec.open(path, "rt") as f:
-        yaml_str = f.read()
+        yaml_str: str = f.read()
     return read_YAML_str(yaml_str, typ=typ, variables=variables)


-def write_YAML_file(obj, path:
+def write_YAML_file(obj, path: str | Path, typ: str = "safe") -> None:
     """Write a basic object to a YAML file."""
     yaml = YAML(typ=typ)
     with Path(path).open("wt") as fp:
         yaml.dump(obj, fp)


-def read_JSON_string(json_str: str, variables:
+def read_JSON_string(json_str: str, variables: dict[str, str] | None = None) -> Any:
     """Load a JSON string. This will produce basic objects."""
-    if variables is not
+    if variables is not None and "<<var:" in json_str:
         json_str = substitute_string_vars(json_str, variables=variables)
     return json.loads(json_str)


-def read_JSON_file(path, variables:
+def read_JSON_file(path, variables: dict[str, str] | None = None) -> Any:
     """Load a JSON file. This will produce basic objects."""
     with fsspec.open(path, "rt") as f:
-        json_str = f.read()
+        json_str: str = f.read()
     return read_JSON_string(json_str, variables=variables)


-def write_JSON_file(obj, path:
+def write_JSON_file(obj, path: str | Path) -> None:
     """Write a basic object to a JSON file."""
     with Path(path).open("wt") as fp:
         json.dump(obj, fp)


-def get_item_repeat_index(
+def get_item_repeat_index(
+    lst: Sequence[T],
+    *,
+    distinguish_singular: bool = False,
+    item_callable: Callable[[T], Hashable] | None = None,
+):
     """Get the repeat index for each item in a list.

     Parameters
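The YAML/JSON readers now skip substitution entirely unless the `<<var:` marker is actually present, avoiding a pointless regex pass on large documents. Sketch:

    from hpcflow.sdk.core.utils import read_YAML_str

    doc = "num_cores: <<var:cores>>"
    assert read_YAML_str(doc, variables={"cores": "4"}) == {"num_cores": 4}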
@@ -510,10 +473,10 @@ def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
     lst : list
         Must contain hashable items, or hashable objects that are returned via `callable`
         called on each item.
-    distinguish_singular : bool
+    distinguish_singular : bool
         If True, items that are not repeated will have a repeat index of 0, and items that
         are repeated will have repeat indices starting from 1.
-    item_callable : callable
+    item_callable : callable
         If specified, comparisons are made on the output of this callable on each item.

     Returns
@@ -523,16 +486,16 @@ def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):

     """

-    idx = {}
-
-
-
-
-
-
+    idx: dict[Any, list[int]] = {}
+    if item_callable:
+        for i_idx, item in enumerate(lst):
+            idx.setdefault(item_callable(item), []).append(i_idx)
+    else:
+        for i_idx, item in enumerate(lst):
+            idx.setdefault(item, []).append(i_idx)

-    rep_idx = [
-    for
+    rep_idx = [0] * len(lst)
+    for v in idx.values():
         start = len(v) > 1 if distinguish_singular else 0
         for i_idx, i in enumerate(v, start):
             rep_idx[i] = i_idx
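`distinguish_singular` and `item_callable` are now keyword-only. Behaviour sketch:

    from hpcflow.sdk.core.utils import get_item_repeat_index

    assert get_item_repeat_index(["a", "b", "a", "a"]) == [0, 0, 1, 2]
    # singular items pinned to 0; repeated items count from 1:
    assert get_item_repeat_index(
        ["a", "b", "a", "a"], distinguish_singular=True
    ) == [1, 0, 2, 3]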
@@ -540,7 +503,7 @@ def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
     return rep_idx


-def get_process_stamp():
+def get_process_stamp() -> str:
     """
     Return a globally unique string identifying this process.

@@ -555,15 +518,17 @@ def get_process_stamp():
     )


-def remove_ansi_escape_sequences(string):
+_ANSI_ESCAPE_RE = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
+
+
+def remove_ansi_escape_sequences(string: str) -> str:
     """
     Strip ANSI terminal escape codes from a string.
     """
-    ansi_escape = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
-    return ansi_escape.sub("", string)
+    return _ANSI_ESCAPE_RE.sub("", string)


-def get_md5_hash(obj):
+def get_md5_hash(obj) -> str:
     """
     Compute the MD5 hash of an object.
     This is the hash of the JSON of the object (with sorted keys) as a hex string.
@@ -572,7 +537,9 @@ def get_md5_hash(obj):
     return hashlib.md5(json_str.encode("utf-8")).hexdigest()


-def get_nested_indices(idx, size, nest_levels, raise_on_rollover=False):
+def get_nested_indices(
+    idx: int, size: int, nest_levels: int, raise_on_rollover: bool = False
+) -> list[int]:
     """Generate the set of nested indices of length `n` that correspond to a global
     `idx`.

@@ -618,34 +585,34 @@ def get_nested_indices(idx, size, nest_levels, raise_on_rollover=False):
     return [(idx // (size ** (nest_levels - (i + 1)))) % size for i in range(nest_levels)]


-def ensure_in(item, lst) -> int:
+def ensure_in(item: T, lst: list[T]) -> int:
     """Get the index of an item in a list and append the item if it is not in the
     list."""
     # TODO: add tests
     try:
-        idx = lst.index(item)
+        return lst.index(item)
     except ValueError:
         lst.append(item)
-        idx = len(lst) - 1
-    return idx
+        return len(lst) - 1


-def list_to_dict(
+def list_to_dict(
+    lst: Sequence[Mapping[T, T2]], exclude: Iterable[T] | None = None
+) -> dict[T, list[T2]]:
     """
     Convert a list of dicts to a dict of lists.
     """
-    #
-
-    dct = {k: [] for k in lst[0]
-    for
-        for k, v in
-            if k not in
+    # TODO: test
+    exc = frozenset(exclude or ())
+    dct: dict[T, list[T2]] = {k: [] for k in lst[0] if k not in exc}
+    for d in lst:
+        for k, v in d.items():
+            if k not in exc:
                 dct[k].append(v)
-
     return dct


-def bisect_slice(selection: slice, len_A: int):
+def bisect_slice(selection: slice, len_A: int) -> tuple[slice, slice]:
     """Given two sequences (the first of which of known length), get the two slices that
     are equivalent to a given slice if the two sequences were combined."""

@@ -660,32 +627,33 @@ def bisect_slice(selection: slice, len_A: int):
         B_stop = B_start
     else:
         B_stop = selection.stop - len_A
-    B_idx = (B_start, B_stop, selection.step)
-    A_slice = slice(*A_idx)
-    B_slice = slice(*B_idx)

-    return A_slice, B_slice
+    return slice(*A_idx), slice(B_start, B_stop, selection.step)


-def replace_items(lst, start, end, repl):
+def replace_items(lst: list[T], start: int, end: int, repl: list[T]) -> list[T]:
     """Replaced a range of items in a list with items in another list."""
-    if end <= start:
+    # Convert to actual indices for our safety checks; handles end-relative addressing
+    real_start, real_end, _ = slice(start, end).indices(len(lst))
+    if real_end <= real_start:
         raise ValueError(
             f"`end` ({end}) must be greater than or equal to `start` ({start})."
         )
-    if start >= len(lst):
+    if real_start >= len(lst):
         raise ValueError(f"`start` ({start}) must be less than length ({len(lst)}).")
-    if end > len(lst):
+    if real_end > len(lst):
         raise ValueError(
             f"`end` ({end}) must be less than or equal to length ({len(lst)})."
         )

-
-
-    return
+    lst = list(lst)
+    lst[start:end] = repl
+    return lst


-def flatten(
+def flatten(
+    lst: list[int] | list[list[int]] | list[list[list[int]]],
+) -> tuple[list[int], tuple[list[int], ...]]:
     """Flatten an arbitrarily (but of uniform depth) nested list and return shape
     information to enable un-flattening.

@@ -700,116 +668,173 @@ def flatten(lst):

     """

-    def _flatten(
-
-
-
-
-
+    def _flatten(
+        lst: list[int] | list[list[int]] | list[list[list[int]]], depth=0
+    ) -> list[int]:
+        out: list[int] = []
+        for item in lst:
+            if isinstance(item, list):
+                out.extend(_flatten(item, depth + 1))
+                all_lens[depth].append(len(item))
             else:
-                out.append(
+                out.append(item)
         return out

-    def _get_max_depth(lst):
-
+    def _get_max_depth(lst: list[int] | list[list[int]] | list[list[list[int]]]) -> int:
+        val: Any = lst
         max_depth = 0
-        while isinstance(
+        while isinstance(val, list):
             max_depth += 1
             try:
-                lst = lst[0]
+                val = val[0]
             except IndexError:
                 # empty list, assume this is max depth
                 break
         return max_depth

     max_depth = _get_max_depth(lst) - 1
-    all_lens = tuple([] for _ in range(max_depth))
+    all_lens: tuple[list[int], ...] = tuple([] for _ in range(max_depth))

     return _flatten(lst), all_lens


-def reshape(lst, lens):
+def reshape(lst: Sequence[T], lens: Sequence[Sequence[int]]) -> list[TList[T]]:
     """
     Reverse the destructuring of the :py:func:`flatten` function.
     """

-    def _reshape(lst, lens):
-        lens_acc = [0
-
-        return lst_rs
+    def _reshape(lst: list[T2], lens: Sequence[int]) -> list[list[T2]]:
+        lens_acc = [0, *accumulate(lens)]
+        return [lst[lens_acc[idx] : lens_acc[idx + 1]] for idx in range(len(lens))]

+    result: list[TList[T]] = list(lst)
     for lens_i in lens[::-1]:
-        lst = _reshape(lst, lens_i)
+        result = cast("list[TList[T]]", _reshape(result, lens_i))

-    return lst
+    return result
+
+
+@overload
+def remap(
+    lst: list[int], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+) -> list[T]:
+    ...
+
+
+@overload
+def remap(
+    lst: list[list[int]], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+) -> list[list[T]]:
+    ...
+
+
+@overload
+def remap(
+    lst: list[list[list[int]]], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+) -> list[list[list[T]]]:
+    ...
+
+
+def remap(lst, mapping_func):
+    """
+    Apply a mapping to a structure of lists with ints (typically indices) as leaves to
+    get a structure of lists with some objects as leaves.
+
+    Parameters
+    ----------
+    lst: list[int] | list[list[int]] | list[list[list[int]]]
+        The structure to remap.
+    mapping_func: Callable[[Sequence[int]], Sequence[T]]
+        The mapping function from sequences of ints to sequences of objects.
+
+    Returns
+    -------
+    list[T] | list[list[T]] | list[list[list[T]]]
+        Nested list structure in same form as input, with leaves remapped.
+    """
+    x, y = flatten(lst)
+    return reshape(mapping_func(x), y)
+
+
+_FSSPEC_URL_RE = re.compile(r"(?:[a-z0-9]+:{1,2})+\/\/")


 def is_fsspec_url(url: str) -> bool:
     """
     Test if a URL appears to be one that can be understood by fsspec.
     """
-    return bool(
+    return bool(_FSSPEC_URL_RE.match(url))


 class JSONLikeDirSnapShot(DirectorySnapshot):
     """
     Overridden DirectorySnapshot from watchdog to allow saving and loading from JSON.
-    """

-
-
+    Parameters
+    ----------
+    root_path: str
+        Where to take the snapshot based at.
+    data: dict[str, list]
+        Serialised snapshot to reload from.
+        See :py:meth:`to_json_like`.
+    """

-
-
-    root_path: str
-
-
-
-
+    def __init__(
+        self,
+        root_path: str | None = None,
+        data: dict[str, list] | None = None,
+        use_strings: bool = False,
+    ):
+        """
+        Create an empty snapshot or load from JSON-like data.
         """

         #: Where to take the snapshot based at.
         self.root_path = root_path
-        self._stat_info = {}
-        self._inode_to_path = {}
+        self._stat_info: dict[str, os.stat_result] = {}
+        self._inode_to_path: dict[tuple, str] = {}

         if data:
-
+            assert root_path
+            for name, item in data.items():
                 # add root path
-
-
-
-                self.
+                full_name = str(PurePath(root_path) / PurePath(name))
+                item = [int(i) for i in item] if use_strings else item
+                stat_dat, inode_key = item[:-2], item[-2:]
+                self._stat_info[full_name] = os.stat_result(stat_dat)
+                self._inode_to_path[tuple(inode_key)] = full_name

-    def take(self, *args, **kwargs):
+    def take(self, *args, **kwargs) -> None:
         """Take the snapshot."""
         super().__init__(*args, **kwargs)

-    def to_json_like(self):
+    def to_json_like(self, use_strings: bool = False) -> dict[str, Any]:
         """Export to a dict that is JSON-compatible and can be later reloaded.

         The last two integers in `data` for each path are the keys in
         `self._inode_to_path`.

         """
-
         # first key is the root path:
-        root_path = next(iter(self._stat_info
+        root_path = next(iter(self._stat_info))

         # store efficiently:
         inode_invert = {v: k for k, v in self._inode_to_path.items()}
-        data = {
-
-
-
+        data: dict[str, list] = {
+            str(PurePath(k).relative_to(root_path)): [
+                str(i) if use_strings else i for i in [*v, *inode_invert[k]]
+            ]
+            for k, v in self._stat_info.items()
+        }

         return {
             "root_path": root_path,
             "data": data,
+            "use_strings": use_strings,
         }


-def open_file(filename):
+def open_file(filename: str | Path):
     """Open a file or directory using the default system application."""
     if sys.platform == "win32":
         os.startfile(filename)
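The new `remap` is simply `flatten` -> map -> `reshape` in one call. Round-trip sketch:

    from hpcflow.sdk.core.utils import flatten, remap, reshape

    nested = [[0, 1], [2], [3, 4, 5]]
    flat, lens = flatten(nested)
    assert flat == [0, 1, 2, 3, 4, 5]
    assert reshape(flat, lens) == nested
    assert remap(nested, lambda xs: [x * 10 for x in xs]) == [[0, 10], [20], [30, 40, 50]]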
@@ -818,59 +843,76 @@ def open_file(filename):
         subprocess.call([opener, filename])


-def get_enum_by_name_or_val(enum_cls, key):
+@overload
+def get_enum_by_name_or_val(enum_cls: type[E], key: None) -> None:
+    ...
+
+
+@overload
+def get_enum_by_name_or_val(enum_cls: type[E], key: str | int | float | E) -> E:
+    ...
+
+
+def get_enum_by_name_or_val(
+    enum_cls: type[E], key: str | int | float | E | None
+) -> E | None:
     """Retrieve an enum by name or value, assuming uppercase names and integer values."""
-    err = f"Unknown enum key or value {key!r} for class {enum_cls!r}"
     if key is None or isinstance(key, enum_cls):
         return key
     elif isinstance(key, (int, float)):
         return enum_cls(int(key))  # retrieve by value
     elif isinstance(key, str):
         try:
-            return getattr(enum_cls, key.upper())  # retrieve by name
+            return cast("E", getattr(enum_cls, key.upper()))  # retrieve by name
         except AttributeError:
-
-
-
+            pass
+    raise ValueError(f"Unknown enum key or value {key!r} for class {enum_cls!r}")
+
+
+_PARAM_SPLIT_RE = re.compile(r"((?:\w|\.)+)(?:\[(\w+)\])?")


-def split_param_label(param_path: str) ->
+def split_param_label(param_path: str) -> tuple[str, str] | tuple[None, None]:
     """Split a parameter path into the path and the label, if present."""
-
-
-
+    if match := _PARAM_SPLIT_RE.match(param_path):
+        return match[1], match[2]
+    else:
+        return None, None


-def process_string_nodes(data, str_processor):
+def process_string_nodes(data: T, str_processor: Callable[[str], str]) -> T:
     """Walk through a nested data structure and process string nodes using a provided
     callable."""

     if isinstance(data, dict):
-
-
+        return cast(
+            "T", {k: process_string_nodes(v, str_processor) for k, v in data.items()}
+        )

-    elif isinstance(data, (list, tuple, set)):
-        _data =
+    elif isinstance(data, (list, tuple, set, frozenset)):
+        _data = (process_string_nodes(i, str_processor) for i in data)
         if isinstance(data, tuple):
-
+            return cast("T", tuple(_data))
         elif isinstance(data, set):
-
+            return cast("T", set(_data))
+        elif isinstance(data, frozenset):
+            return cast("T", frozenset(_data))
         else:
-
+            return cast("T", list(_data))

     elif isinstance(data, str):
-
+        return cast("T", str_processor(data))

     return data


 def linspace_rect(
-    start:
-    stop:
-    num:
-    include:
+    start: Sequence[float],
+    stop: Sequence[float],
+    num: Sequence[int],
+    include: Sequence[str] | None = None,
     **kwargs,
-):
+) -> NDArray:
     """Generate a linear space around a rectangle.

     Parameters
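`get_enum_by_name_or_val` gains `@overload`s so that `None` propagates through type-checking, and `split_param_label` is rewritten on the pre-compiled `_PARAM_SPLIT_RE`. Sketch:

    from hpcflow.sdk.core.utils import split_param_label

    assert split_param_label("inputs.p1[one]") == ("inputs.p1", "one")
    assert split_param_label("inputs.p1") == ("inputs.p1", None)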
@@ -892,19 +934,18 @@ def linspace_rect(

     """

-    if num[0]
+    if num[0] <= 1 or num[1] <= 1:
         raise ValueError("Both values in `num` must be greater than 1.")

-    if
-        include = ["top", "right", "bottom", "left"]
+    inc = set(include) if include else {"top", "right", "bottom", "left"}

     c0_range = np.linspace(start=start[0], stop=stop[0], num=num[0], **kwargs)
     c1_range_all = np.linspace(start=start[1], stop=stop[1], num=num[1], **kwargs)

     c1_range = c1_range_all
-    if "bottom" in
+    if "bottom" in inc:
         c1_range = c1_range[1:]
-    if "top" in
+    if "top" in inc:
         c1_range = c1_range[:-1]

     c0_range_c1_start = np.vstack([c0_range, np.repeat(start[1], num[0])])
@@ -914,20 +955,21 @@ def linspace_rect(
     c1_range_c0_stop = np.vstack([np.repeat(c0_range[-1], len(c1_range)), c1_range])

     stacked = []
-    if "top" in
+    if "top" in inc:
         stacked.append(c0_range_c1_stop)
-    if "right" in
+    if "right" in inc:
         stacked.append(c1_range_c0_stop)
-    if "bottom" in
+    if "bottom" in inc:
         stacked.append(c0_range_c1_start)
-    if "left" in
+    if "left" in inc:
         stacked.append(c1_range_c0_start)

-    rect = np.hstack(stacked)
-    return rect
+    return np.hstack(stacked)


-def dict_values_process_flat(
+def dict_values_process_flat(
+    d: Mapping[T, T2 | list[T2]], callable: Callable[[list[T2]], list[T3]]
+) -> Mapping[T, T3 | list[T3]]:
     """
     Return a copy of a dict, where the values are processed by a callable that is to
     be called only once, and where the values may be single items or lists of items.
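`include` is normalised to a set (`inc`), defaulting to all four edges; corners are not double-counted because the side columns drop their end points before stacking. Sketch:

    from hpcflow.sdk.core.utils import linspace_rect

    # perimeter of a 4 x 3 grid: 2*4 + 2*3 - 4 == 10 points
    rect = linspace_rect(start=[0, 0], stop=[3, 2], num=[4, 3])
    assert rect.shape == (2, 10)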
@@ -939,32 +981,34 @@ def dict_values_process_flat(d, callable):
     {'a': 1, 'b': [2, 3], 'c': 6}

     """
-    flat = []  # values of `d`, flattened
-    is_multi
+    flat: list[T2] = []  # values of `d`, flattened
+    is_multi: list[
+        tuple[bool, int]
+    ] = []  # whether a list, and the number of items to process
     for i in d.values():
-
-            flat.extend(i)
+        if isinstance(i, list):
+            flat.extend(cast("list[T2]", i))
             is_multi.append((True, len(i)))
-
-            flat.append(i)
+        else:
+            flat.append(cast("T2", i))
             is_multi.append((False, 1))

     processed = callable(flat)

-    out = {}
-    for idx_i, (m, k) in enumerate(zip(is_multi, d
-
+    out: dict[T, T3 | list[T3]] = {}
+    for idx_i, (m, k) in enumerate(zip(is_multi, d)):
         start_idx = sum(i[1] for i in is_multi[:idx_i])
         end_idx = start_idx + m[1]
-        proc_idx_k = processed[
+        proc_idx_k = processed[start_idx:end_idx]
         if not m[0]:
-
-
+            out[k] = proc_idx_k[0]
+        else:
+            out[k] = proc_idx_k

     return out


-def nth_key(dct, n):
+def nth_key(dct: Iterable[T], n: int) -> T:
     """
     Given a dict in some order, get the n'th key of that dict.
     """
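`dict_values_process_flat` flattens all values (scalars and lists) into one list so that `callable` runs exactly once, then rebuilds the original shapes, as in its docstring example:

    from hpcflow.sdk.core.utils import dict_values_process_flat

    d = {"a": 0, "b": [1, 2], "c": 5}
    assert dict_values_process_flat(d, lambda vals: [v + 1 for v in vals]) == {
        "a": 1,
        "b": [2, 3],
        "c": 6,
    }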
@@ -973,8 +1017,140 @@ def nth_key(dct, n):
     return next(it)


-def nth_value(dct, n):
+def nth_value(dct: dict[Any, T], n: int) -> T:
     """
     Given a dict in some order, get the n'th value of that dict.
     """
     return dct[nth_key(dct, n)]
+
+
+def normalise_timestamp(timestamp: datetime) -> datetime:
+    """
+    Force a timestamp to have UTC as its timezone,
+    then convert to use the local timezone.
+    """
+    return timestamp.replace(tzinfo=timezone.utc).astimezone()
+
+
+def parse_timestamp(timestamp: str | datetime, ts_fmt: str) -> datetime:
+    """
+    Standard timestamp parsing.
+    Ensures that timestamps are internally all UTC.
+    """
+    return normalise_timestamp(
+        timestamp
+        if isinstance(timestamp, datetime)
+        else datetime.strptime(timestamp, ts_fmt)
+    )
+
+
+def current_timestamp() -> datetime:
+    """
+    Get a UTC timestamp for the current time
+    """
+    return datetime.now(timezone.utc)
+
+
+def timedelta_format(td: timedelta) -> str:
+    """
+    Convert time delta to string in standard form.
+    """
+    days, seconds = td.days, td.seconds
+    hours = seconds // (60 * 60)
+    seconds -= hours * (60 * 60)
+    minutes = seconds // 60
+    seconds -= minutes * 60
+    return f"{days}-{hours:02}:{minutes:02}:{seconds:02}"
+
+
+_TD_RE = re.compile(r"(\d+)-(\d+):(\d+):(\d+)")
+
+
+def timedelta_parse(td_str: str) -> timedelta:
+    """
+    Parse a string in standard form as a time delta.
+    """
+    if not (m := _TD_RE.fullmatch(td_str)):
+        raise ValueError("not a supported timedelta form")
+    days, hours, mins, secs = map(int, m.groups())
+    return timedelta(days=days, hours=hours, minutes=mins, seconds=secs)
+
+
+def open_text_resource(package: ModuleType | str, resource: str) -> IO[str]:
+    """
+    Open a file in a package.
+    """
+    try:
+        return resources.files(package).joinpath(resource).open("r")
+    except AttributeError:
+        # < python 3.9; `resources.open_text` deprecated since 3.11
+        return resources.open_text(package, resource)
+
+
+def get_file_context(
+    package: ModuleType | str, src: str | None = None
+) -> AbstractContextManager[Path]:
+    """
+    Find a file or directory in a package.
+    """
+    try:
+        files = resources.files(package)
+        return resources.as_file(files.joinpath(src) if src else files)
+        # raises ModuleNotFoundError
+    except AttributeError:
+        # < python 3.9
+        return resources.path(package, src or "")
+
+
+@contextlib.contextmanager
+def redirect_std_to_file(
+    file,
+    mode: Literal["w", "a"] = "a",
+    ignore: Callable[[BaseException], Literal[True] | int] | None = None,
+) -> Iterator[None]:
+    """Temporarily redirect both stdout and stderr to a file, and if an exception is
+    raised, catch it, print the traceback to that file, and exit.
+
+    File creation is deferred until an actual write is required.
+
+    Parameters
+    ----------
+    ignore
+        Callable to test if a given exception should be ignored. If an exception is
+        not ignored, its traceback will be printed to `file` and the program will
+        exit with exit code 1. The callable should accept one parameter, the
+        exception, and should return True if that exception should be ignored, or
+        an integer representing the exit code to exit the program with if that
+        exception should not be ignored. By default, no exceptions are ignored.
+
+    """
+    ignore = ignore or (lambda _: 1)
+    with DeferredFileWriter(file, mode=mode) as fp:
+        with contextlib.redirect_stdout(fp):
+            with contextlib.redirect_stderr(fp):
+                try:
+                    yield
+                except BaseException as exc:
+                    ignore_ret = ignore(exc)
+                    if ignore_ret is not True:
+                        traceback.print_exc()
+                        sys.exit(ignore_ret)
+
+
+async def to_thread(func, /, *args, **kwargs):
+    """Copied from https://github.com/python/cpython/blob/4b4227b907a262446b9d276c274feda2590a4e6e/Lib/asyncio/threads.py
+    to support Python 3.8, which does not have `asyncio.to_thread`.
+
+    Asynchronously run function *func* in a separate thread.
+
+    Any *args and **kwargs supplied for this function are directly passed
+    to *func*. Also, the current :class:`contextvars.Context` is propagated,
+    allowing context variables from the main thread to be accessed in the
+    separate thread.
+
+    Return a coroutine that can be awaited to get the eventual result of *func*.
+    """
+    loop = events.get_running_loop()
+    ctx = contextvars.copy_context()
+    func_call = functools.partial(ctx.run, func, *args, **kwargs)
+    return await loop.run_in_executor(None, func_call)