hpcflow-new2 0.2.0a189__py3-none-any.whl → 0.2.0a199__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. hpcflow/__pyinstaller/hook-hpcflow.py +9 -6
  2. hpcflow/_version.py +1 -1
  3. hpcflow/app.py +1 -0
  4. hpcflow/data/scripts/bad_script.py +2 -0
  5. hpcflow/data/scripts/do_nothing.py +2 -0
  6. hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
  7. hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
  8. hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
  9. hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
  10. hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
  11. hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
  12. hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
  13. hpcflow/data/scripts/input_file_generator_basic.py +3 -0
  14. hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
  15. hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
  16. hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
  17. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
  18. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
  19. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
  20. hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
  21. hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
  22. hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
  23. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
  24. hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
  25. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
  26. hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
  27. hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
  28. hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
  29. hpcflow/data/scripts/output_file_parser_basic.py +3 -0
  30. hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
  31. hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
  32. hpcflow/data/scripts/script_exit_test.py +5 -0
  33. hpcflow/data/template_components/environments.yaml +1 -1
  34. hpcflow/sdk/__init__.py +26 -15
  35. hpcflow/sdk/app.py +2192 -768
  36. hpcflow/sdk/cli.py +506 -296
  37. hpcflow/sdk/cli_common.py +105 -7
  38. hpcflow/sdk/config/__init__.py +1 -1
  39. hpcflow/sdk/config/callbacks.py +115 -43
  40. hpcflow/sdk/config/cli.py +126 -103
  41. hpcflow/sdk/config/config.py +674 -318
  42. hpcflow/sdk/config/config_file.py +131 -95
  43. hpcflow/sdk/config/errors.py +125 -84
  44. hpcflow/sdk/config/types.py +148 -0
  45. hpcflow/sdk/core/__init__.py +25 -1
  46. hpcflow/sdk/core/actions.py +1771 -1059
  47. hpcflow/sdk/core/app_aware.py +24 -0
  48. hpcflow/sdk/core/cache.py +139 -79
  49. hpcflow/sdk/core/command_files.py +263 -287
  50. hpcflow/sdk/core/commands.py +145 -112
  51. hpcflow/sdk/core/element.py +828 -535
  52. hpcflow/sdk/core/enums.py +192 -0
  53. hpcflow/sdk/core/environment.py +74 -93
  54. hpcflow/sdk/core/errors.py +455 -52
  55. hpcflow/sdk/core/execute.py +207 -0
  56. hpcflow/sdk/core/json_like.py +540 -272
  57. hpcflow/sdk/core/loop.py +751 -347
  58. hpcflow/sdk/core/loop_cache.py +164 -47
  59. hpcflow/sdk/core/object_list.py +370 -207
  60. hpcflow/sdk/core/parameters.py +1100 -627
  61. hpcflow/sdk/core/rule.py +59 -41
  62. hpcflow/sdk/core/run_dir_files.py +21 -37
  63. hpcflow/sdk/core/skip_reason.py +7 -0
  64. hpcflow/sdk/core/task.py +1649 -1339
  65. hpcflow/sdk/core/task_schema.py +308 -196
  66. hpcflow/sdk/core/test_utils.py +191 -114
  67. hpcflow/sdk/core/types.py +440 -0
  68. hpcflow/sdk/core/utils.py +485 -309
  69. hpcflow/sdk/core/validation.py +82 -9
  70. hpcflow/sdk/core/workflow.py +2544 -1178
  71. hpcflow/sdk/core/zarr_io.py +98 -137
  72. hpcflow/sdk/data/workflow_spec_schema.yaml +2 -0
  73. hpcflow/sdk/demo/cli.py +53 -33
  74. hpcflow/sdk/helper/cli.py +18 -15
  75. hpcflow/sdk/helper/helper.py +75 -63
  76. hpcflow/sdk/helper/watcher.py +61 -28
  77. hpcflow/sdk/log.py +122 -71
  78. hpcflow/sdk/persistence/__init__.py +8 -31
  79. hpcflow/sdk/persistence/base.py +1360 -606
  80. hpcflow/sdk/persistence/defaults.py +6 -0
  81. hpcflow/sdk/persistence/discovery.py +38 -0
  82. hpcflow/sdk/persistence/json.py +568 -188
  83. hpcflow/sdk/persistence/pending.py +382 -179
  84. hpcflow/sdk/persistence/store_resource.py +39 -23
  85. hpcflow/sdk/persistence/types.py +318 -0
  86. hpcflow/sdk/persistence/utils.py +14 -11
  87. hpcflow/sdk/persistence/zarr.py +1337 -433
  88. hpcflow/sdk/runtime.py +44 -41
  89. hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
  90. hpcflow/sdk/submission/jobscript.py +1651 -692
  91. hpcflow/sdk/submission/schedulers/__init__.py +167 -39
  92. hpcflow/sdk/submission/schedulers/direct.py +121 -81
  93. hpcflow/sdk/submission/schedulers/sge.py +170 -129
  94. hpcflow/sdk/submission/schedulers/slurm.py +291 -268
  95. hpcflow/sdk/submission/schedulers/utils.py +12 -2
  96. hpcflow/sdk/submission/shells/__init__.py +14 -15
  97. hpcflow/sdk/submission/shells/base.py +150 -29
  98. hpcflow/sdk/submission/shells/bash.py +283 -173
  99. hpcflow/sdk/submission/shells/os_version.py +31 -30
  100. hpcflow/sdk/submission/shells/powershell.py +228 -170
  101. hpcflow/sdk/submission/submission.py +1014 -335
  102. hpcflow/sdk/submission/types.py +140 -0
  103. hpcflow/sdk/typing.py +182 -12
  104. hpcflow/sdk/utils/arrays.py +71 -0
  105. hpcflow/sdk/utils/deferred_file.py +55 -0
  106. hpcflow/sdk/utils/hashing.py +16 -0
  107. hpcflow/sdk/utils/patches.py +12 -0
  108. hpcflow/sdk/utils/strings.py +33 -0
  109. hpcflow/tests/api/test_api.py +32 -0
  110. hpcflow/tests/conftest.py +27 -6
  111. hpcflow/tests/data/multi_path_sequences.yaml +29 -0
  112. hpcflow/tests/data/workflow_test_run_abort.yaml +34 -35
  113. hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
  114. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
  115. hpcflow/tests/scripts/test_input_file_generators.py +282 -0
  116. hpcflow/tests/scripts/test_main_scripts.py +866 -85
  117. hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
  118. hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
  119. hpcflow/tests/shells/wsl/test_wsl_submission.py +12 -4
  120. hpcflow/tests/unit/test_action.py +262 -75
  121. hpcflow/tests/unit/test_action_rule.py +9 -4
  122. hpcflow/tests/unit/test_app.py +33 -6
  123. hpcflow/tests/unit/test_cache.py +46 -0
  124. hpcflow/tests/unit/test_cli.py +134 -1
  125. hpcflow/tests/unit/test_command.py +71 -54
  126. hpcflow/tests/unit/test_config.py +142 -16
  127. hpcflow/tests/unit/test_config_file.py +21 -18
  128. hpcflow/tests/unit/test_element.py +58 -62
  129. hpcflow/tests/unit/test_element_iteration.py +50 -1
  130. hpcflow/tests/unit/test_element_set.py +29 -19
  131. hpcflow/tests/unit/test_group.py +4 -2
  132. hpcflow/tests/unit/test_input_source.py +116 -93
  133. hpcflow/tests/unit/test_input_value.py +29 -24
  134. hpcflow/tests/unit/test_jobscript_unit.py +757 -0
  135. hpcflow/tests/unit/test_json_like.py +44 -35
  136. hpcflow/tests/unit/test_loop.py +1396 -84
  137. hpcflow/tests/unit/test_meta_task.py +325 -0
  138. hpcflow/tests/unit/test_multi_path_sequences.py +229 -0
  139. hpcflow/tests/unit/test_object_list.py +17 -12
  140. hpcflow/tests/unit/test_parameter.py +29 -7
  141. hpcflow/tests/unit/test_persistence.py +237 -42
  142. hpcflow/tests/unit/test_resources.py +20 -18
  143. hpcflow/tests/unit/test_run.py +117 -6
  144. hpcflow/tests/unit/test_run_directories.py +29 -0
  145. hpcflow/tests/unit/test_runtime.py +2 -1
  146. hpcflow/tests/unit/test_schema_input.py +23 -15
  147. hpcflow/tests/unit/test_shell.py +23 -2
  148. hpcflow/tests/unit/test_slurm.py +8 -7
  149. hpcflow/tests/unit/test_submission.py +38 -89
  150. hpcflow/tests/unit/test_task.py +352 -247
  151. hpcflow/tests/unit/test_task_schema.py +33 -20
  152. hpcflow/tests/unit/test_utils.py +9 -11
  153. hpcflow/tests/unit/test_value_sequence.py +15 -12
  154. hpcflow/tests/unit/test_workflow.py +114 -83
  155. hpcflow/tests/unit/test_workflow_template.py +0 -1
  156. hpcflow/tests/unit/utils/test_arrays.py +40 -0
  157. hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
  158. hpcflow/tests/unit/utils/test_hashing.py +65 -0
  159. hpcflow/tests/unit/utils/test_patches.py +5 -0
  160. hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
  161. hpcflow/tests/workflows/__init__.py +0 -0
  162. hpcflow/tests/workflows/test_directory_structure.py +31 -0
  163. hpcflow/tests/workflows/test_jobscript.py +334 -1
  164. hpcflow/tests/workflows/test_run_status.py +198 -0
  165. hpcflow/tests/workflows/test_skip_downstream.py +696 -0
  166. hpcflow/tests/workflows/test_submission.py +140 -0
  167. hpcflow/tests/workflows/test_workflows.py +160 -15
  168. hpcflow/tests/workflows/test_zip.py +18 -0
  169. hpcflow/viz_demo.ipynb +6587 -3
  170. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/METADATA +8 -4
  171. hpcflow_new2-0.2.0a199.dist-info/RECORD +221 -0
  172. hpcflow/sdk/core/parallel.py +0 -21
  173. hpcflow_new2-0.2.0a189.dist-info/RECORD +0 -158
  174. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/LICENSE +0 -0
  175. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/WHEEL +0 -0
  176. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/entry_points.txt +0 -0
hpcflow/sdk/core/utils.py CHANGED
@@ -2,12 +2,17 @@
 Miscellaneous utilities.
 """
 
+from __future__ import annotations
+from collections import Counter
+from asyncio import events
+import contextvars
+import contextlib
 import copy
 import enum
-from functools import wraps
-import contextlib
+import functools
 import hashlib
 from itertools import accumulate, islice
+from importlib import resources
 import json
 import keyword
 import os
@@ -17,10 +22,11 @@ import re
 import socket
 import string
 import subprocess
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 import sys
-from typing import Dict, Optional, Tuple, Type, Union, List
-import fsspec
+import traceback
+from typing import Literal, cast, overload, TypeVar, TYPE_CHECKING
+import fsspec  # type: ignore
 import numpy as np
 
 from ruamel.yaml import YAML
@@ -28,27 +34,30 @@ from watchdog.utils.dirsnapshot import DirectorySnapshot
 
 from hpcflow.sdk.core.errors import (
     ContainerKeyError,
-    FromSpecMissingObjectError,
     InvalidIdentifier,
     MissingVariableSubstitutionError,
 )
 from hpcflow.sdk.log import TimeIt
-from hpcflow.sdk.typing import PathLike
-
-
-def load_config(func):
-    """API function decorator to ensure the configuration has been loaded, and load if not."""
-
-    @wraps(func)
-    def wrapper(self, *args, **kwargs):
-        if not self.is_config_loaded:
-            self.load_config()
-        return func(self, *args, **kwargs)
-
-    return wrapper
-
-
-def make_workflow_id():
+from hpcflow.sdk.utils.deferred_file import DeferredFileWriter
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
+    from contextlib import AbstractContextManager
+    from types import ModuleType
+    from typing import Any, IO, Iterator
+    from typing_extensions import TypeAlias
+    from numpy.typing import NDArray
+    from ..typing import PathLike
+
+T = TypeVar("T")
+T2 = TypeVar("T2")
+T3 = TypeVar("T3")
+TList: TypeAlias = "T | list[TList]"
+TD = TypeVar("TD", bound="Mapping[str, Any]")
+E = TypeVar("E", bound=enum.Enum)
+
+
+def make_workflow_id() -> str:
     """
     Generate a random ID for a workflow.
     """
@@ -57,14 +66,14 @@ def make_workflow_id():
     return "".join(random.choices(chars, k=length))
 
 
-def get_time_stamp():
+def get_time_stamp() -> str:
     """
     Get the current time in standard string form.
     """
     return datetime.now(timezone.utc).astimezone().strftime("%Y.%m.%d_%H:%M:%S_%z")
 
 
-def get_duplicate_items(lst):
+def get_duplicate_items(lst: Iterable[T]) -> list[T]:
     """Get a list of all items in an iterable that appear more than once, assuming items
     are hashable.
 
@@ -77,11 +86,10 @@ def get_duplicate_items(lst):
     []
 
     >>> get_duplicate_items([1, 2, 3, 3, 3, 2])
-    [2, 3, 2]
+    [2, 3]
 
     """
-    seen = []
-    return list(set(x for x in lst if x in seen or seen.append(x)))
+    return [x for x, y in Counter(lst).items() if y > 1]
 
 
 def check_valid_py_identifier(name: str) -> str:
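The `Counter` rewrite also fixes the function's semantics: each repeated item is now reported exactly once, in first-seen order, hence the doctest change from `[2, 3, 2]` to `[2, 3]`. A quick illustrative check of the new behaviour (input invented):

    >>> from collections import Counter
    >>> [x for x, n in Counter("abracadabra").items() if n > 1]
    ['a', 'b', 'r']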
@@ -105,29 +113,42 @@ def check_valid_py_identifier(name: str) -> str:
     - `Loop.name`
 
     """
-    exc = InvalidIdentifier(f"Invalid string for identifier: {name!r}")
     try:
         trial_name = name[1:].replace("_", "")  # "internal" underscores are allowed
     except TypeError:
-        raise exc
+        raise InvalidIdentifier(name) from None
+    except KeyError as e:
+        raise KeyError(f"unexpected name type {name}") from e
     if (
         not name
         or not (name[0].isalpha() and ((trial_name[1:] or "a").isalnum()))
         or keyword.iskeyword(name)
     ):
-        raise exc
+        raise InvalidIdentifier(name)
 
     return name
 
 
-def group_by_dict_key_values(lst, *keys):
+@overload
+def group_by_dict_key_values(  # type: ignore[overload-overlap]
+    lst: list[dict[T, T2]], key: T
+) -> list[list[dict[T, T2]]]:
+    ...
+
+
+@overload
+def group_by_dict_key_values(lst: list[TD], key: str) -> list[list[TD]]:
+    ...
+
+
+def group_by_dict_key_values(lst: list, key):
     """Group a list of dicts according to specified equivalent key-values.
 
     Parameters
     ----------
     lst : list of dict
         The list of dicts to group together.
-    keys : tuple
+    key : key value
         Dicts that have identical values for all of these keys will be grouped together
         into a sub-list.
 
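The grouping API changes here from variadic `*keys` to a single `key`, with typed overloads for type-checkers. A sketch of the resulting behaviour (data invented):

    >>> group_by_dict_key_values([{"b": 2}, {"a": 9, "b": 2}, {"b": 5}], "b")
    [[{'b': 2}, {'a': 9, 'b': 2}], [{'b': 5}]]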
@@ -146,10 +167,10 @@ def group_by_dict_key_values(lst, *keys):
     for lst_item in lst[1:]:
         for group_idx, group in enumerate(grouped):
             try:
-                is_vals_equal = all(lst_item[k] == group[0][k] for k in keys)
+                is_vals_equal = lst_item[key] == group[0][key]
 
             except KeyError:
-                # dicts that do not have all `keys` will be in their own group:
+                # dicts that do not have the `key` will be in their own group:
                 is_vals_equal = False
 
             if is_vals_equal:
@@ -162,7 +183,7 @@
     return grouped
 
 
-def swap_nested_dict_keys(dct, inner_key):
+def swap_nested_dict_keys(dct: dict[T, dict[T2, T3]], inner_key: T2):
     """Return a copy where top-level keys have been swapped with a second-level inner key.
 
     Examples:
@@ -181,16 +202,35 @@
     }
 
     """
-    out = {}
+    out: dict[T3, dict[T, dict[T2, T3]]] = {}
     for k, v in copy.deepcopy(dct or {}).items():
-        inner_val = v.pop(inner_key)
-        if inner_val not in out:
-            out[inner_val] = {}
-        out[inner_val][k] = v
+        out.setdefault(v.pop(inner_key), {})[k] = v
     return out
 
 
-def get_in_container(cont, path, cast_indices=False, allow_getattr=False):
+def _ensure_int(path_comp: Any, cur_data: Any, cast_indices: bool) -> int:
+    """
+    Helper for get_in_container() and set_in_container()
+    """
+    if isinstance(path_comp, int):
+        return path_comp
+    if not cast_indices:
+        raise TypeError(
+            f"Path component {path_comp!r} must be an integer index "
+            f"since data is a sequence: {cur_data!r}."
+        )
+    try:
+        return int(path_comp)
+    except (TypeError, ValueError) as e:
+        raise TypeError(
+            f"Path component {path_comp!r} must be an integer index "
+            f"since data is a sequence: {cur_data!r}."
+        ) from e
+
+
+def get_in_container(
+    cont, path: Sequence, cast_indices: bool = False, allow_getattr: bool = False
+):
     """
     Follow a path (sequence of indices of appropriate type) into a container to obtain
     a "leaf" value. Containers can be lists, tuples, dicts,
@@ -203,24 +243,12 @@ def get_in_container(cont, path, cast_indices=False, allow_getattr=False):
     )
     for idx, path_comp in enumerate(path):
         if isinstance(cur_data, (list, tuple)):
-            if not isinstance(path_comp, int):
-                msg = (
-                    f"Path component {path_comp!r} must be an integer index "
-                    f"since data is a sequence: {cur_data!r}."
-                )
-                if cast_indices:
-                    try:
-                        path_comp = int(path_comp)
-                    except TypeError:
-                        raise TypeError(msg)
-                else:
-                    raise TypeError(msg)
-            cur_data = cur_data[path_comp]
-        elif isinstance(cur_data, dict):
+            cur_data = cur_data[_ensure_int(path_comp, cur_data, cast_indices)]
+        elif isinstance(cur_data, dict) or hasattr(cur_data, "__getitem__"):
             try:
                 cur_data = cur_data[path_comp]
             except KeyError:
-                raise ContainerKeyError(path=path[: idx + 1])
+                raise ContainerKeyError(path=cast("list[str]", path[: idx + 1]))
         elif allow_getattr:
             try:
                 cur_data = getattr(cur_data, path_comp)
@@ -235,7 +263,9 @@
     return cur_data
 
 
-def set_in_container(cont, path, value, ensure_path=False, cast_indices=False):
+def set_in_container(
+    cont, path: Sequence, value, ensure_path=False, cast_indices=False
+) -> None:
     """
     Follow a path (sequence of indices of appropriate type) into a container to update
     a "leaf" value. Containers can be lists, tuples or dicts.
@@ -258,22 +288,11 @@ def set_in_container(cont, path, value, ensure_path=False, cast_indices=False):
     sub_data = get_in_container(cont, path[:-1], cast_indices=cast_indices)
     path_comp = path[-1]
     if isinstance(sub_data, (list, tuple)):
-        if not isinstance(path_comp, int):
-            msg = (
-                f"Path component {path_comp!r} must be an integer index "
-                f"since data is a sequence: {sub_data!r}."
-            )
-            if cast_indices:
-                try:
-                    path_comp = int(path_comp)
-                except ValueError:
-                    raise ValueError(msg)
-            else:
-                raise ValueError(msg)
+        path_comp = _ensure_int(path_comp, sub_data, cast_indices)
     sub_data[path_comp] = value
 
 
-def get_relative_path(path1, path2):
+def get_relative_path(path1: Sequence[T], path2: Sequence[T]) -> Sequence[T]:
     """Get relative path components between two paths.
 
     Parameters
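With the shared helper, both container functions follow the same casting path; note a failed cast in `set_in_container` now raises `TypeError` where it previously raised `ValueError`. A sketch of the casting path (data invented):

    >>> data = {"a": [10, 20, 30]}
    >>> set_in_container(data, ("a", "2"), 99, cast_indices=True)
    >>> data
    {'a': [10, 20, 99]}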
@@ -308,79 +327,39 @@
     """
 
     len_path2 = len(path2)
-    msg = f"{path1!r} is not in the subpath of {path2!r}."
-
-    if len(path1) < len_path2:
-        raise ValueError(msg)
-
-    for i, j in zip(path1[:len_path2], path2):
-        if i != j:
-            raise ValueError(msg)
+    if len(path1) < len_path2 or any(i != j for i, j in zip(path1[:len_path2], path2)):
+        raise ValueError(f"{path1!r} is not in the subpath of {path2!r}.")
 
     return path1[len_path2:]
 
 
-def search_dir_files_by_regex(pattern, group=0, directory=".") -> List[str]:
+def search_dir_files_by_regex(
+    pattern: str | re.Pattern[str], directory: str | os.PathLike = "."
+) -> list[str]:
     """Search recursively for files in a directory by a regex pattern and return matching
     file paths, relative to the given directory."""
-    vals = []
-    for i in Path(directory).rglob("*"):
-        match = re.search(pattern, i.name)
-        if match:
-            match_groups = match.groups()
-            if match_groups:
-                match = match_groups[group]
-            vals.append(str(i.relative_to(directory)))
-    return vals
+    dir_ = Path(directory)
+    return [
+        str(entry.relative_to(dir_))
+        for entry in dir_.rglob("*")
+        if re.search(pattern, entry.name)
+    ]
 
 
-class classproperty(object):
-    """
-    Simple class property decorator.
-    """
-
-    def __init__(self, f):
-        self.f = f
-
-    def __get__(self, obj, owner):
-        return self.f(owner)
-
-
-class PrettyPrinter(object):
+class PrettyPrinter:
     """
     A class that produces a nice readable version of itself with ``str()``.
     Intended to be subclassed.
     """
 
-    def __str__(self):
+    def __str__(self) -> str:
        lines = [self.__class__.__name__ + ":"]
        for key, val in vars(self).items():
-            lines += f"{key}: {val}".split("\n")
+            lines.extend(f"{key}: {val}".split("\n"))
        return "\n ".join(lines)
 
 
-class Singleton(type):
-    """
-    Metaclass that enforces that only one instance can exist of the classes to which it
-    is applied.
-    """
-
-    _instances = {}
-
-    def __call__(cls, *args, **kwargs):
-        if cls not in cls._instances:
-            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
-        elif args or kwargs:
-            # if existing instance, make the point that new arguments don't do anything!
-            raise ValueError(
-                f"{cls.__name__!r} is a singleton class and cannot be instantiated with new "
-                f"arguments. The positional arguments {args!r} and keyword-arguments "
-                f"{kwargs!r} have been ignored."
-            )
-        return cls._instances[cls]
-
-
-def capitalise_first_letter(chars):
+def capitalise_first_letter(chars: str) -> str:
     """
     Convert the first character of a string to upper case (if that makes sense).
     The rest of the string is unchanged.
@@ -388,30 +367,11 @@ def capitalise_first_letter(chars):
     return chars[0].upper() + chars[1:]
 
 
-def check_in_object_list(spec_name, spec_pos=1, obj_list_pos=2):
-    """Decorator factory for the various `from_spec` class methods that have attributes
-    that should be replaced by an object from an object list."""
-
-    def decorator(func):
-        @wraps(func)
-        def wrap(*args, **kwargs):
-            spec = args[spec_pos]
-            obj_list = args[obj_list_pos]
-            if spec[spec_name] not in obj_list:
-                cls_name = args[0].__name__
-                raise FromSpecMissingObjectError(
-                    f"A {spec_name!r} object required to instantiate the {cls_name!r} "
-                    f"object is missing."
-                )
-            return func(*args, **kwargs)
-
-        return wrap
-
-    return decorator
+_STRING_VARS_RE = re.compile(r"\<\<var:(.*?)(?:\[(.*)\])?\>\>")
 
 
 @TimeIt.decorator
-def substitute_string_vars(string, variables: Dict[str, str] = None):
+def substitute_string_vars(string: str, variables: dict[str, str]):
     """
     Scan ``string`` and substitute sequences like ``<<var:ABC>>`` with the value
     looked up in the supplied dictionary (with ``ABC`` as the key).
@@ -424,15 +384,14 @@ def substitute_string_vars(string, variables: Dict[str, str] = None):
     >>> substitute_string_vars("abc <<var:def>> ghi", {"def": "123"})
     "abc 123 def"
     """
-    variables = variables or {}
 
-    def var_repl(match_obj):
-        kwargs = {}
-        var_name, kwargs_str = match_obj.groups()
+    def var_repl(match_obj: re.Match[str]) -> str:
+        kwargs: dict[str, str] = {}
+        var_name: str = match_obj[1]
+        kwargs_str: str | None = match_obj[2]
        if kwargs_str:
-            kwargs_lst = kwargs_str.split(",")
-            for i in kwargs_lst:
-                k, v = i.strip().split("=")
+            for i in kwargs_str.split(","):
+                k, v = i.split("=")
                kwargs[k.strip()] = v.strip()
        try:
            out = str(variables[var_name])
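`variables` is now a required argument (the old `variables is not False` guard in the readers below becomes `is not None`). Assuming the `default` kwarg fallback referenced in the surrounding context, a missing variable with a bracketed default still substitutes cleanly; a hedged sketch:

    >>> substitute_string_vars("run_<<var:mode[default=fast]>>", {})
    'run_fast'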
@@ -444,65 +403,69 @@
                 f"variable {var_name!r}."
             )
         else:
-            raise MissingVariableSubstitutionError(
-                f"The variable {var_name!r} referenced in the string does not match "
-                f"any of the provided variables: {list(variables)!r}."
-            )
+            raise MissingVariableSubstitutionError(var_name, variables)
         return out
 
-    new_str = re.sub(
-        pattern=r"\<\<var:(.*?)(?:\[(.*)\])?\>\>",
+    return _STRING_VARS_RE.sub(
         repl=var_repl,
         string=string,
     )
-    return new_str
 
 
 @TimeIt.decorator
-def read_YAML_str(yaml_str, typ="safe", variables: Dict[str, str] = None):
+def read_YAML_str(
+    yaml_str: str, typ="safe", variables: dict[str, str] | None = None
+) -> Any:
     """Load a YAML string. This will produce basic objects."""
-    if variables is not False and "<<var:" in yaml_str:
+    if variables is not None and "<<var:" in yaml_str:
         yaml_str = substitute_string_vars(yaml_str, variables=variables)
     yaml = YAML(typ=typ)
     return yaml.load(yaml_str)
 
 
 @TimeIt.decorator
-def read_YAML_file(path: PathLike, typ="safe", variables: Dict[str, str] = None):
+def read_YAML_file(
+    path: PathLike, typ="safe", variables: dict[str, str] | None = None
+) -> Any:
     """Load a YAML file. This will produce basic objects."""
     with fsspec.open(path, "rt") as f:
-        yaml_str = f.read()
+        yaml_str: str = f.read()
     return read_YAML_str(yaml_str, typ=typ, variables=variables)
 
 
-def write_YAML_file(obj, path: PathLike, typ="safe"):
+def write_YAML_file(obj, path: str | Path, typ: str = "safe") -> None:
     """Write a basic object to a YAML file."""
     yaml = YAML(typ=typ)
     with Path(path).open("wt") as fp:
         yaml.dump(obj, fp)
 
 
-def read_JSON_string(json_str: str, variables: Dict[str, str] = None):
+def read_JSON_string(json_str: str, variables: dict[str, str] | None = None) -> Any:
     """Load a JSON string. This will produce basic objects."""
-    if variables is not False and "<<var:" in json_str:
+    if variables is not None and "<<var:" in json_str:
         json_str = substitute_string_vars(json_str, variables=variables)
     return json.loads(json_str)
 
 
-def read_JSON_file(path, variables: Dict[str, str] = None):
+def read_JSON_file(path, variables: dict[str, str] | None = None) -> Any:
     """Load a JSON file. This will produce basic objects."""
     with fsspec.open(path, "rt") as f:
-        json_str = f.read()
+        json_str: str = f.read()
     return read_JSON_string(json_str, variables=variables)
 
 
-def write_JSON_file(obj, path: PathLike):
+def write_JSON_file(obj, path: str | Path) -> None:
     """Write a basic object to a JSON file."""
     with Path(path).open("wt") as fp:
         json.dump(obj, fp)
 
 
-def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
+def get_item_repeat_index(
+    lst: Sequence[T],
+    *,
+    distinguish_singular: bool = False,
+    item_callable: Callable[[T], Hashable] | None = None,
+):
     """Get the repeat index for each item in a list.
 
     Parameters
@@ -510,10 +473,10 @@ def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
     lst : list
         Must contain hashable items, or hashable objects that are returned via `callable`
         called on each item.
-    distinguish_singular : bool, optional
+    distinguish_singular : bool
         If True, items that are not repeated will have a repeat index of 0, and items that
         are repeated will have repeat indices starting from 1.
-    item_callable : callable, optional
+    item_callable : callable
         If specified, comparisons are made on the output of this callable on each item.
 
     Returns
@@ -523,16 +486,16 @@
 
     """
 
-    idx = {}
-    for i_idx, item in enumerate(lst):
-        if item_callable:
-            item = item_callable(item)
-        if item not in idx:
-            idx[item] = []
-        idx[item] += [i_idx]
+    idx: dict[Any, list[int]] = {}
+    if item_callable:
+        for i_idx, item in enumerate(lst):
+            idx.setdefault(item_callable(item), []).append(i_idx)
+    else:
+        for i_idx, item in enumerate(lst):
+            idx.setdefault(item, []).append(i_idx)
 
-    rep_idx = [None] * len(lst)
-    for k, v in idx.items():
+    rep_idx = [0] * len(lst)
+    for v in idx.values():
         start = len(v) > 1 if distinguish_singular else 0
         for i_idx, i in enumerate(v, start):
             rep_idx[i] = i_idx
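`get_item_repeat_index`'s arguments are now keyword-only, the `item_callable` test is hoisted out of the loop, and unrepeated items default to `0` rather than `None`. Illustrative behaviour (data invented):

    >>> get_item_repeat_index(["a", "b", "a", "c"], distinguish_singular=True)
    [1, 0, 2, 0]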
@@ -540,7 +503,7 @@
     return rep_idx
 
 
-def get_process_stamp():
+def get_process_stamp() -> str:
     """
     Return a globally unique string identifying this process.
 
@@ -555,15 +518,17 @@
     )
 
 
-def remove_ansi_escape_sequences(string):
+_ANSI_ESCAPE_RE = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
+
+
+def remove_ansi_escape_sequences(string: str) -> str:
     """
     Strip ANSI terminal escape codes from a string.
     """
-    ansi_escape = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
-    return ansi_escape.sub("", string)
+    return _ANSI_ESCAPE_RE.sub("", string)
 
 
-def get_md5_hash(obj):
+def get_md5_hash(obj) -> str:
     """
     Compute the MD5 hash of an object.
     This is the hash of the JSON of the object (with sorted keys) as a hex string.
@@ -572,7 +537,9 @@
     return hashlib.md5(json_str.encode("utf-8")).hexdigest()
 
 
-def get_nested_indices(idx, size, nest_levels, raise_on_rollover=False):
+def get_nested_indices(
+    idx: int, size: int, nest_levels: int, raise_on_rollover: bool = False
+) -> list[int]:
     """Generate the set of nested indices of length `n` that correspond to a global
     `idx`.
 
@@ -618,34 +585,34 @@
     return [(idx // (size ** (nest_levels - (i + 1)))) % size for i in range(nest_levels)]
 
 
-def ensure_in(item, lst) -> int:
+def ensure_in(item: T, lst: list[T]) -> int:
     """Get the index of an item in a list and append the item if it is not in the
     list."""
     # TODO: add tests
     try:
-        idx = lst.index(item)
+        return lst.index(item)
     except ValueError:
         lst.append(item)
-        idx = len(lst) - 1
-    return idx
+        return len(lst) - 1
 
 
-def list_to_dict(lst, exclude=None):
+def list_to_dict(
+    lst: Sequence[Mapping[T, T2]], exclude: Iterable[T] | None = None
+) -> dict[T, list[T2]]:
     """
     Convert a list of dicts to a dict of lists.
     """
-    # TODD: test
-    exclude = exclude or []
-    dct = {k: [] for k in lst[0].keys() if k not in exclude}
-    for i in lst:
-        for k, v in i.items():
-            if k not in exclude:
+    # TODO: test
+    exc = frozenset(exclude or ())
+    dct: dict[T, list[T2]] = {k: [] for k in lst[0] if k not in exc}
+    for d in lst:
+        for k, v in d.items():
+            if k not in exc:
                 dct[k].append(v)
-
     return dct
 
 
-def bisect_slice(selection: slice, len_A: int):
+def bisect_slice(selection: slice, len_A: int) -> tuple[slice, slice]:
     """Given two sequences (the first of which of known length), get the two slices that
     are equivalent to a given slice if the two sequences were combined."""
 
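The list comprehension opening the hunk above is a mixed-radix digit expansion: for `size=s` and `nest_levels=n`, the global index decomposes as `idx = d_0*s**(n-1) + ... + d_(n-1)`. A worked example (numbers chosen for illustration):

    >>> get_nested_indices(idx=23, size=3, nest_levels=3)
    [2, 1, 2]
    >>> 2 * 3**2 + 1 * 3 + 2
    23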
@@ -660,32 +627,33 @@
         B_stop = B_start
     else:
         B_stop = selection.stop - len_A
-    B_idx = (B_start, B_stop, selection.step)
-    A_slice = slice(*A_idx)
-    B_slice = slice(*B_idx)
 
-    return A_slice, B_slice
+    return slice(*A_idx), slice(B_start, B_stop, selection.step)
 
 
-def replace_items(lst, start, end, repl):
+def replace_items(lst: list[T], start: int, end: int, repl: list[T]) -> list[T]:
     """Replaced a range of items in a list with items in another list."""
-    if end <= start:
+    # Convert to actual indices for our safety checks; handles end-relative addressing
+    real_start, real_end, _ = slice(start, end).indices(len(lst))
+    if real_end <= real_start:
         raise ValueError(
             f"`end` ({end}) must be greater than or equal to `start` ({start})."
         )
-    if start >= len(lst):
+    if real_start >= len(lst):
         raise ValueError(f"`start` ({start}) must be less than length ({len(lst)}).")
-    if end > len(lst):
+    if real_end > len(lst):
         raise ValueError(
             f"`end` ({end}) must be less than or equal to length ({len(lst)})."
         )
 
-    lst_a = lst[:start]
-    lst_b = lst[end:]
-    return lst_a + repl + lst_b
+    lst = list(lst)
+    lst[start:end] = repl
+    return lst
 
 
-def flatten(lst):
+def flatten(
+    lst: list[int] | list[list[int]] | list[list[list[int]]],
+) -> tuple[list[int], tuple[list[int], ...]]:
     """Flatten an arbitrarily (but of uniform depth) nested list and return shape
     information to enable un-flattening.
 
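`replace_items` now resolves negative `start`/`end` through `slice.indices` before validating, so end-relative addressing works, and it performs the replacement by slice assignment on a copy. A sketch (values invented):

    >>> replace_items([1, 2, 3, 4], 1, -1, [9])
    [1, 9, 4]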
@@ -700,116 +668,173 @@
 
     """
 
-    def _flatten(lst, _depth=0):
-        out = []
-        for i in lst:
-            if isinstance(i, list):
-                out += _flatten(i, _depth=_depth + 1)
-                all_lens[_depth].append(len(i))
+    def _flatten(
+        lst: list[int] | list[list[int]] | list[list[list[int]]], depth=0
+    ) -> list[int]:
+        out: list[int] = []
+        for item in lst:
+            if isinstance(item, list):
+                out.extend(_flatten(item, depth + 1))
+                all_lens[depth].append(len(item))
             else:
-                out.append(i)
+                out.append(item)
         return out
 
-    def _get_max_depth(lst):
-        lst = lst[:]
+    def _get_max_depth(lst: list[int] | list[list[int]] | list[list[list[int]]]) -> int:
+        val: Any = lst
         max_depth = 0
-        while isinstance(lst, list):
+        while isinstance(val, list):
             max_depth += 1
             try:
-                lst = lst[0]
+                val = val[0]
             except IndexError:
                 # empty list, assume this is max depth
                 break
         return max_depth
 
     max_depth = _get_max_depth(lst) - 1
-    all_lens = tuple([] for _ in range(max_depth))
+    all_lens: tuple[list[int], ...] = tuple([] for _ in range(max_depth))
 
     return _flatten(lst), all_lens
 
 
-def reshape(lst, lens):
+def reshape(lst: Sequence[T], lens: Sequence[Sequence[int]]) -> list[TList[T]]:
     """
     Reverse the destructuring of the :py:func:`flatten` function.
     """
 
-    def _reshape(lst, lens):
-        lens_acc = [0] + list(accumulate(lens))
-        lst_rs = [lst[lens_acc[idx] : lens_acc[idx + 1]] for idx in range(len(lens))]
-        return lst_rs
+    def _reshape(lst: list[T2], lens: Sequence[int]) -> list[list[T2]]:
+        lens_acc = [0, *accumulate(lens)]
+        return [lst[lens_acc[idx] : lens_acc[idx + 1]] for idx in range(len(lens))]
 
+    result: list[TList[T]] = list(lst)
     for lens_i in lens[::-1]:
-        lst = _reshape(lst, lens_i)
+        result = cast("list[TList[T]]", _reshape(result, lens_i))
 
-    return lst
+    return result
+
+
+@overload
+def remap(
+    lst: list[int], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+) -> list[T]:
+    ...
+
+
+@overload
+def remap(
+    lst: list[list[int]], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+) -> list[list[T]]:
+    ...
+
+
+@overload
+def remap(
+    lst: list[list[list[int]]], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+) -> list[list[list[T]]]:
+    ...
+
+
+def remap(lst, mapping_func):
+    """
+    Apply a mapping to a structure of lists with ints (typically indices) as leaves to
+    get a structure of lists with some objects as leaves.
+
+    Parameters
+    ----------
+    lst: list[int] | list[list[int]] | list[list[list[int]]]
+        The structure to remap.
+    mapping_func: Callable[[Sequence[int]], Sequence[T]]
+        The mapping function from sequences of ints to sequences of objects.
+
+    Returns
+    -------
+    list[T] | list[list[T]] | list[list[list[T]]]
+        Nested list structure in same form as input, with leaves remapped.
+    """
+    x, y = flatten(lst)
+    return reshape(mapping_func(x), y)
+
+
+_FSSPEC_URL_RE = re.compile(r"(?:[a-z0-9]+:{1,2})+\/\/")
 
 
 def is_fsspec_url(url: str) -> bool:
     """
     Test if a URL appears to be one that can be understood by fsspec.
     """
-    return bool(re.match(r"(?:[a-z0-9]+:{1,2})+\/\/", url))
+    return bool(_FSSPEC_URL_RE.match(url))
 
 
 class JSONLikeDirSnapShot(DirectorySnapshot):
     """
     Overridden DirectorySnapshot from watchdog to allow saving and loading from JSON.
-    """
 
-    def __init__(self, root_path=None, data=None):
-        """Create an empty snapshot or load from JSON-like data.
+    Parameters
+    ----------
+    root_path: str
+        Where to take the snapshot based at.
+    data: dict[str, list]
+        Serialised snapshot to reload from.
+        See :py:meth:`to_json_like`.
+    """
 
-        Parameters
-        ----------
-        root_path: str
-            Where to take the snapshot based at.
-        data: dict
-            Serialised snapshot to reload from.
-            See :py:meth:`to_json_like`.
+    def __init__(
+        self,
+        root_path: str | None = None,
+        data: dict[str, list] | None = None,
+        use_strings: bool = False,
+    ):
+        """
+        Create an empty snapshot or load from JSON-like data.
         """
 
         #: Where to take the snapshot based at.
         self.root_path = root_path
-        self._stat_info = {}
-        self._inode_to_path = {}
+        self._stat_info: dict[str, os.stat_result] = {}
+        self._inode_to_path: dict[tuple, str] = {}
 
         if data:
-            for k in list((data or {}).keys()):
+            assert root_path
+            for name, item in data.items():
                 # add root path
-                full_k = str(PurePath(root_path) / PurePath(k))
-                stat_dat, inode_key = data[k][:-2], data[k][-2:]
-                self._stat_info[full_k] = os.stat_result(stat_dat)
-                self._inode_to_path[tuple(inode_key)] = full_k
+                full_name = str(PurePath(root_path) / PurePath(name))
+                item = [int(i) for i in item] if use_strings else item
+                stat_dat, inode_key = item[:-2], item[-2:]
+                self._stat_info[full_name] = os.stat_result(stat_dat)
+                self._inode_to_path[tuple(inode_key)] = full_name
 
-    def take(self, *args, **kwargs):
+    def take(self, *args, **kwargs) -> None:
         """Take the snapshot."""
         super().__init__(*args, **kwargs)
 
-    def to_json_like(self):
+    def to_json_like(self, use_strings: bool = False) -> dict[str, Any]:
         """Export to a dict that is JSON-compatible and can be later reloaded.
 
         The last two integers in `data` for each path are the keys in
         `self._inode_to_path`.
 
         """
-
         # first key is the root path:
-        root_path = next(iter(self._stat_info.keys()))
+        root_path = next(iter(self._stat_info))
 
         # store efficiently:
         inode_invert = {v: k for k, v in self._inode_to_path.items()}
-        data = {}
-        for k, v in self._stat_info.items():
-            k_rel = str(PurePath(k).relative_to(root_path))
-            data[k_rel] = list(v) + list(inode_invert[k])
+        data: dict[str, list] = {
+            str(PurePath(k).relative_to(root_path)): [
+                str(i) if use_strings else i for i in [*v, *inode_invert[k]]
+            ]
+            for k, v in self._stat_info.items()
+        }
 
         return {
             "root_path": root_path,
             "data": data,
+            "use_strings": use_strings,
         }
 
 
-def open_file(filename):
+def open_file(filename: str | Path):
     """Open a file or directory using the default system application."""
     if sys.platform == "win32":
         os.startfile(filename)
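The new `remap` simply composes `flatten` and `reshape`, applying a mapping over the flat leaves. A round-trip sketch (data invented):

    >>> flatten([[0, 1], [2], [3, 4, 5]])
    ([0, 1, 2, 3, 4, 5], ([2, 1, 3],))
    >>> remap([[0, 1], [2]], lambda seq: [str(i) for i in seq])
    [['0', '1'], ['2']]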
@@ -818,59 +843,76 @@ def open_file(filename):
         subprocess.call([opener, filename])
 
 
-def get_enum_by_name_or_val(enum_cls: Type, key: Union[str, None]) -> enum.Enum:
+@overload
+def get_enum_by_name_or_val(enum_cls: type[E], key: None) -> None:
+    ...
+
+
+@overload
+def get_enum_by_name_or_val(enum_cls: type[E], key: str | int | float | E) -> E:
+    ...
+
+
+def get_enum_by_name_or_val(
+    enum_cls: type[E], key: str | int | float | E | None
+) -> E | None:
     """Retrieve an enum by name or value, assuming uppercase names and integer values."""
-    err = f"Unknown enum key or value {key!r} for class {enum_cls!r}"
     if key is None or isinstance(key, enum_cls):
         return key
     elif isinstance(key, (int, float)):
         return enum_cls(int(key))  # retrieve by value
     elif isinstance(key, str):
         try:
-            return getattr(enum_cls, key.upper())  # retrieve by name
+            return cast("E", getattr(enum_cls, key.upper()))  # retrieve by name
         except AttributeError:
-            raise ValueError(err)
-    else:
-        raise ValueError(err)
+            pass
+    raise ValueError(f"Unknown enum key or value {key!r} for class {enum_cls!r}")
+
+
+_PARAM_SPLIT_RE = re.compile(r"((?:\w|\.)+)(?:\[(\w+)\])?")
 
 
-def split_param_label(param_path: str) -> Tuple[Union[str, None]]:
+def split_param_label(param_path: str) -> tuple[str, str] | tuple[None, None]:
     """Split a parameter path into the path and the label, if present."""
-    pattern = r"((?:\w|\.)+)(?:\[(\w+)\])?"
-    match = re.match(pattern, param_path)
-    return match.group(1), match.group(2)
+    if match := _PARAM_SPLIT_RE.match(param_path):
+        return match[1], match[2]
+    else:
+        return None, None
 
 
-def process_string_nodes(data, str_processor):
+def process_string_nodes(data: T, str_processor: Callable[[str], str]) -> T:
     """Walk through a nested data structure and process string nodes using a provided
     callable."""
 
     if isinstance(data, dict):
-        for k, v in data.items():
-            data[k] = process_string_nodes(v, str_processor)
+        return cast(
+            "T", {k: process_string_nodes(v, str_processor) for k, v in data.items()}
+        )
 
-    elif isinstance(data, (list, tuple, set)):
-        _data = [process_string_nodes(i, str_processor) for i in data]
+    elif isinstance(data, (list, tuple, set, frozenset)):
+        _data = (process_string_nodes(i, str_processor) for i in data)
         if isinstance(data, tuple):
-            data = tuple(_data)
+            return cast("T", tuple(_data))
         elif isinstance(data, set):
-            data = set(_data)
+            return cast("T", set(_data))
+        elif isinstance(data, frozenset):
+            return cast("T", frozenset(_data))
         else:
-            data = _data
+            return cast("T", list(_data))
 
     elif isinstance(data, str):
-        data = str_processor(data)
+        return cast("T", str_processor(data))
 
     return data
 
 
 def linspace_rect(
-    start: List[float],
-    stop: List[float],
-    num: List[float],
-    include: Optional[List[str]] = None,
+    start: Sequence[float],
+    stop: Sequence[float],
+    num: Sequence[int],
+    include: Sequence[str] | None = None,
     **kwargs,
-):
+) -> NDArray:
     """Generate a linear space around a rectangle.
 
     Parameters
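The `get_enum_by_name_or_val` overloads let type-checkers track that `None` maps to `None` and anything else to an enum member, while `split_param_label` now returns `(None, None)` on a failed match instead of raising `AttributeError`. Typical splits (paths invented):

    >>> split_param_label("p1[one]")
    ('p1', 'one')
    >>> split_param_label("p1.sub_param")
    ('p1.sub_param', None)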
@@ -892,19 +934,18 @@
 
     """
 
-    if num[0] == 1 or num[1] == 1:
+    if num[0] <= 1 or num[1] <= 1:
         raise ValueError("Both values in `num` must be greater than 1.")
 
-    if not include:
-        include = ["top", "right", "bottom", "left"]
+    inc = set(include) if include else {"top", "right", "bottom", "left"}
 
     c0_range = np.linspace(start=start[0], stop=stop[0], num=num[0], **kwargs)
     c1_range_all = np.linspace(start=start[1], stop=stop[1], num=num[1], **kwargs)
 
     c1_range = c1_range_all
-    if "bottom" in include:
+    if "bottom" in inc:
         c1_range = c1_range[1:]
-    if "top" in include:
+    if "top" in inc:
         c1_range = c1_range[:-1]
 
     c0_range_c1_start = np.vstack([c0_range, np.repeat(start[1], num[0])])
@@ -914,20 +955,21 @@
     c1_range_c0_stop = np.vstack([np.repeat(c0_range[-1], len(c1_range)), c1_range])
 
     stacked = []
-    if "top" in include:
+    if "top" in inc:
         stacked.append(c0_range_c1_stop)
-    if "right" in include:
+    if "right" in inc:
         stacked.append(c1_range_c0_stop)
-    if "bottom" in include:
+    if "bottom" in inc:
         stacked.append(c0_range_c1_start)
-    if "left" in include:
+    if "left" in inc:
         stacked.append(c1_range_c0_start)
 
-    rect = np.hstack(stacked)
-    return rect
+    return np.hstack(stacked)
 
 
-def dict_values_process_flat(d, callable):
+def dict_values_process_flat(
+    d: Mapping[T, T2 | list[T2]], callable: Callable[[list[T2]], list[T3]]
+) -> Mapping[T, T3 | list[T3]]:
     """
     Return a copy of a dict, where the values are processed by a callable that is to
     be called only once, and where the values may be single items or lists of items.
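Replacing the `include` list with the `inc` set leaves the geometry unchanged, and the `<=` comparison additionally rejects `num` values of zero or below. For instance (values invented), the full perimeter of a 3×3 grid yields eight distinct points:

    >>> linspace_rect(start=[0, 0], stop=[2, 2], num=[3, 3]).shape
    (2, 8)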
@@ -939,32 +981,34 @@
     {'a': 1, 'b': [2, 3], 'c': 6}
 
     """
-    flat = []  # values of `d`, flattened
-    is_multi = []  # whether a list, and the number of items to process
+    flat: list[T2] = []  # values of `d`, flattened
+    is_multi: list[
+        tuple[bool, int]
+    ] = []  # whether a list, and the number of items to process
     for i in d.values():
-        try:
-            flat.extend(i)
+        if isinstance(i, list):
+            flat.extend(cast("list[T2]", i))
             is_multi.append((True, len(i)))
-        except TypeError:
-            flat.append(i)
+        else:
+            flat.append(cast("T2", i))
             is_multi.append((False, 1))
 
     processed = callable(flat)
 
-    out = {}
-    for idx_i, (m, k) in enumerate(zip(is_multi, d.keys())):
-
+    out: dict[T, T3 | list[T3]] = {}
+    for idx_i, (m, k) in enumerate(zip(is_multi, d)):
         start_idx = sum(i[1] for i in is_multi[:idx_i])
         end_idx = start_idx + m[1]
-        proc_idx_k = processed[slice(start_idx, end_idx)]
+        proc_idx_k = processed[start_idx:end_idx]
         if not m[0]:
-            proc_idx_k = proc_idx_k[0]
-        out[k] = proc_idx_k
+            out[k] = proc_idx_k[0]
+        else:
+            out[k] = proc_idx_k
 
     return out
 
 
-def nth_key(dct, n):
+def nth_key(dct: Iterable[T], n: int) -> T:
     """
     Given a dict in some order, get the n'th key of that dict.
     """
@@ -973,8 +1017,140 @@ def nth_key(dct, n):
973
1017
  return next(it)
974
1018
 
975
1019
 
976
- def nth_value(dct, n):
1020
+ def nth_value(dct: dict[Any, T], n: int) -> T:
977
1021
  """
978
1022
  Given a dict in some order, get the n'th value of that dict.
979
1023
  """
980
1024
  return dct[nth_key(dct, n)]
1025
+
1026
+
1027
+ def normalise_timestamp(timestamp: datetime) -> datetime:
1028
+ """
1029
+ Force a timestamp to have UTC as its timezone,
1030
+ then convert to use the local timezone.
1031
+ """
1032
+ return timestamp.replace(tzinfo=timezone.utc).astimezone()
1033
+
1034
+
1035
+ def parse_timestamp(timestamp: str | datetime, ts_fmt: str) -> datetime:
1036
+ """
1037
+ Standard timestamp parsing.
1038
+ Ensures that timestamps are internally all UTC.
1039
+ """
1040
+ return normalise_timestamp(
1041
+ timestamp
1042
+ if isinstance(timestamp, datetime)
1043
+ else datetime.strptime(timestamp, ts_fmt)
1044
+ )
1045
+
1046
+
1047
+ def current_timestamp() -> datetime:
1048
+ """
1049
+ Get a UTC timestamp for the current time
1050
+ """
1051
+ return datetime.now(timezone.utc)
1052
+
1053
+
1054
+ def timedelta_format(td: timedelta) -> str:
1055
+ """
1056
+ Convert time delta to string in standard form.
1057
+ """
1058
+ days, seconds = td.days, td.seconds
1059
+ hours = seconds // (60 * 60)
1060
+ seconds -= hours * (60 * 60)
1061
+ minutes = seconds // 60
1062
+ seconds -= minutes * 60
1063
+ return f"{days}-{hours:02}:{minutes:02}:{seconds:02}"
1064
+
1065
+
1066
+ _TD_RE = re.compile(r"(\d+)-(\d+):(\d+):(\d+)")
1067
+
1068
+
1069
+ def timedelta_parse(td_str: str) -> timedelta:
1070
+ """
1071
+ Parse a string in standard form as a time delta.
1072
+ """
1073
+ if not (m := _TD_RE.fullmatch(td_str)):
1074
+ raise ValueError("not a supported timedelta form")
1075
+ days, hours, mins, secs = map(int, m.groups())
1076
+ return timedelta(days=days, hours=hours, minutes=mins, seconds=secs)
1077
+
1078
+
1079
+ def open_text_resource(package: ModuleType | str, resource: str) -> IO[str]:
1080
+ """
1081
+ Open a file in a package.
1082
+ """
1083
+ try:
1084
+ return resources.files(package).joinpath(resource).open("r")
1085
+ except AttributeError:
1086
+ # < python 3.9; `resource.open_text` deprecated since 3.11
1087
+ return resources.open_text(package, resource)
1088
+
1089
+
1090
+ def get_file_context(
1091
+ package: ModuleType | str, src: str | None = None
1092
+ ) -> AbstractContextManager[Path]:
1093
+ """
1094
+ Find a file or directory in a package.
1095
+ """
1096
+ try:
1097
+ files = resources.files(package)
1098
+ return resources.as_file(files.joinpath(src) if src else files)
1099
+ # raises ModuleNotFoundError
1100
+ except AttributeError:
1101
+ # < python 3.9
1102
+ return resources.path(package, src or "")
1103
+
1104
+
1105
+ @contextlib.contextmanager
1106
+ def redirect_std_to_file(
1107
+ file,
1108
+ mode: Literal["w", "a"] = "a",
1109
+ ignore: Callable[[BaseException], Literal[True] | int] | None = None,
1110
+ ) -> Iterator[None]:
1111
+ """Temporarily redirect both stdout and stderr to a file, and if an exception is
1112
+ raised, catch it, print the traceback to that file, and exit.
1113
+
1114
+ File creation is deferred until an actual write is required.
1115
+
1116
+ Parameters
1117
+ ----------
1118
+ ignore
1119
+ Callable to test if a given exception should be ignored. If an exception is
1120
+ not ignored, its traceback will be printed to `file` and the program will
1121
+ exit with exit code 1. The callable should accept one parameter, the
1122
+ exception, and should return True if that exception should be ignored, or
1123
+ an integer representing the exit code to exit the program with if that
1124
+ exception should not be ignored. By default, no exceptions are ignored.
1125
+
1126
+ """
1127
+ ignore = ignore or (lambda _: 1)
1128
+ with DeferredFileWriter(file, mode=mode) as fp:
1129
+ with contextlib.redirect_stdout(fp):
1130
+ with contextlib.redirect_stderr(fp):
1131
+ try:
1132
+ yield
1133
+ except BaseException as exc:
1134
+ ignore_ret = ignore(exc)
1135
+ if ignore_ret is not True:
1136
+ traceback.print_exc()
1137
+ sys.exit(ignore_ret)
1138
+
1139
+
1140
+ async def to_thread(func, /, *args, **kwargs):
1141
+ """Copied from https://github.com/python/cpython/blob/4b4227b907a262446b9d276c274feda2590a4e6e/Lib/asyncio/threads.py
1142
+ to support Python 3.8, which does not have `asyncio.to_thread`.
1143
+
1144
+ Asynchronously run function *func* in a separate thread.
1145
+
1146
+ Any *args and **kwargs supplied for this function are directly passed
1147
+ to *func*. Also, the current :class:`contextvars.Context` is propagated,
1148
+ allowing context variables from the main thread to be accessed in the
1149
+ separate thread.
1150
+
1151
+ Return a coroutine that can be awaited to get the eventual result of *func*.
1152
+ """
1153
+ loop = events.get_running_loop()
1154
+ ctx = contextvars.copy_context()
1155
+ func_call = functools.partial(ctx.run, func, *args, **kwargs)
1156
+ return await loop.run_in_executor(None, func_call)