hpcflow-new2 0.2.0a189__py3-none-any.whl → 0.2.0a190__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. hpcflow/__pyinstaller/hook-hpcflow.py +8 -6
  2. hpcflow/_version.py +1 -1
  3. hpcflow/app.py +1 -0
  4. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
  5. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
  6. hpcflow/sdk/__init__.py +21 -15
  7. hpcflow/sdk/app.py +2133 -770
  8. hpcflow/sdk/cli.py +281 -250
  9. hpcflow/sdk/cli_common.py +6 -2
  10. hpcflow/sdk/config/__init__.py +1 -1
  11. hpcflow/sdk/config/callbacks.py +77 -42
  12. hpcflow/sdk/config/cli.py +126 -103
  13. hpcflow/sdk/config/config.py +578 -311
  14. hpcflow/sdk/config/config_file.py +131 -95
  15. hpcflow/sdk/config/errors.py +112 -85
  16. hpcflow/sdk/config/types.py +145 -0
  17. hpcflow/sdk/core/actions.py +1054 -994
  18. hpcflow/sdk/core/app_aware.py +24 -0
  19. hpcflow/sdk/core/cache.py +81 -63
  20. hpcflow/sdk/core/command_files.py +275 -185
  21. hpcflow/sdk/core/commands.py +111 -107
  22. hpcflow/sdk/core/element.py +724 -503
  23. hpcflow/sdk/core/enums.py +192 -0
  24. hpcflow/sdk/core/environment.py +74 -93
  25. hpcflow/sdk/core/errors.py +398 -51
  26. hpcflow/sdk/core/json_like.py +540 -272
  27. hpcflow/sdk/core/loop.py +380 -334
  28. hpcflow/sdk/core/loop_cache.py +160 -43
  29. hpcflow/sdk/core/object_list.py +370 -207
  30. hpcflow/sdk/core/parameters.py +728 -600
  31. hpcflow/sdk/core/rule.py +59 -41
  32. hpcflow/sdk/core/run_dir_files.py +33 -22
  33. hpcflow/sdk/core/task.py +1546 -1325
  34. hpcflow/sdk/core/task_schema.py +240 -196
  35. hpcflow/sdk/core/test_utils.py +126 -88
  36. hpcflow/sdk/core/types.py +387 -0
  37. hpcflow/sdk/core/utils.py +410 -305
  38. hpcflow/sdk/core/validation.py +82 -9
  39. hpcflow/sdk/core/workflow.py +1192 -1028
  40. hpcflow/sdk/core/zarr_io.py +98 -137
  41. hpcflow/sdk/demo/cli.py +46 -33
  42. hpcflow/sdk/helper/cli.py +18 -16
  43. hpcflow/sdk/helper/helper.py +75 -63
  44. hpcflow/sdk/helper/watcher.py +61 -28
  45. hpcflow/sdk/log.py +83 -59
  46. hpcflow/sdk/persistence/__init__.py +8 -31
  47. hpcflow/sdk/persistence/base.py +988 -586
  48. hpcflow/sdk/persistence/defaults.py +6 -0
  49. hpcflow/sdk/persistence/discovery.py +38 -0
  50. hpcflow/sdk/persistence/json.py +408 -153
  51. hpcflow/sdk/persistence/pending.py +158 -123
  52. hpcflow/sdk/persistence/store_resource.py +37 -22
  53. hpcflow/sdk/persistence/types.py +307 -0
  54. hpcflow/sdk/persistence/utils.py +14 -11
  55. hpcflow/sdk/persistence/zarr.py +477 -420
  56. hpcflow/sdk/runtime.py +44 -41
  57. hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
  58. hpcflow/sdk/submission/jobscript.py +444 -404
  59. hpcflow/sdk/submission/schedulers/__init__.py +133 -40
  60. hpcflow/sdk/submission/schedulers/direct.py +97 -71
  61. hpcflow/sdk/submission/schedulers/sge.py +132 -126
  62. hpcflow/sdk/submission/schedulers/slurm.py +263 -268
  63. hpcflow/sdk/submission/schedulers/utils.py +7 -2
  64. hpcflow/sdk/submission/shells/__init__.py +14 -15
  65. hpcflow/sdk/submission/shells/base.py +102 -29
  66. hpcflow/sdk/submission/shells/bash.py +72 -55
  67. hpcflow/sdk/submission/shells/os_version.py +31 -30
  68. hpcflow/sdk/submission/shells/powershell.py +37 -29
  69. hpcflow/sdk/submission/submission.py +203 -257
  70. hpcflow/sdk/submission/types.py +143 -0
  71. hpcflow/sdk/typing.py +163 -12
  72. hpcflow/tests/conftest.py +8 -6
  73. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
  74. hpcflow/tests/scripts/test_main_scripts.py +60 -30
  75. hpcflow/tests/shells/wsl/test_wsl_submission.py +6 -4
  76. hpcflow/tests/unit/test_action.py +86 -75
  77. hpcflow/tests/unit/test_action_rule.py +9 -4
  78. hpcflow/tests/unit/test_app.py +13 -6
  79. hpcflow/tests/unit/test_cli.py +1 -1
  80. hpcflow/tests/unit/test_command.py +71 -54
  81. hpcflow/tests/unit/test_config.py +20 -15
  82. hpcflow/tests/unit/test_config_file.py +21 -18
  83. hpcflow/tests/unit/test_element.py +58 -62
  84. hpcflow/tests/unit/test_element_iteration.py +3 -1
  85. hpcflow/tests/unit/test_element_set.py +29 -19
  86. hpcflow/tests/unit/test_group.py +4 -2
  87. hpcflow/tests/unit/test_input_source.py +116 -93
  88. hpcflow/tests/unit/test_input_value.py +29 -24
  89. hpcflow/tests/unit/test_json_like.py +44 -35
  90. hpcflow/tests/unit/test_loop.py +65 -58
  91. hpcflow/tests/unit/test_object_list.py +17 -12
  92. hpcflow/tests/unit/test_parameter.py +16 -7
  93. hpcflow/tests/unit/test_persistence.py +48 -35
  94. hpcflow/tests/unit/test_resources.py +20 -18
  95. hpcflow/tests/unit/test_run.py +8 -3
  96. hpcflow/tests/unit/test_runtime.py +2 -1
  97. hpcflow/tests/unit/test_schema_input.py +23 -15
  98. hpcflow/tests/unit/test_shell.py +3 -2
  99. hpcflow/tests/unit/test_slurm.py +8 -7
  100. hpcflow/tests/unit/test_submission.py +39 -19
  101. hpcflow/tests/unit/test_task.py +352 -247
  102. hpcflow/tests/unit/test_task_schema.py +33 -20
  103. hpcflow/tests/unit/test_utils.py +9 -11
  104. hpcflow/tests/unit/test_value_sequence.py +15 -12
  105. hpcflow/tests/unit/test_workflow.py +114 -83
  106. hpcflow/tests/unit/test_workflow_template.py +0 -1
  107. hpcflow/tests/workflows/test_jobscript.py +2 -1
  108. hpcflow/tests/workflows/test_workflows.py +18 -13
  109. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/METADATA +2 -1
  110. hpcflow_new2-0.2.0a190.dist-info/RECORD +165 -0
  111. hpcflow/sdk/core/parallel.py +0 -21
  112. hpcflow_new2-0.2.0a189.dist-info/RECORD +0 -158
  113. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/LICENSE +0 -0
  114. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/WHEEL +0 -0
  115. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/entry_points.txt +0 -0
hpcflow/sdk/core/utils.py CHANGED
@@ -2,12 +2,13 @@
  Miscellaneous utilities.
  """

+ from __future__ import annotations
+ from collections import Counter
  import copy
  import enum
- from functools import wraps
- import contextlib
  import hashlib
  from itertools import accumulate, islice
+ from importlib import resources
  import json
  import keyword
  import os
@@ -17,10 +18,10 @@ import re
  import socket
  import string
  import subprocess
- from datetime import datetime, timezone
+ from datetime import datetime, timedelta, timezone
  import sys
- from typing import Dict, Optional, Tuple, Type, Union, List
- import fsspec
+ from typing import cast, overload, TypeVar, TYPE_CHECKING
+ import fsspec  # type: ignore
  import numpy as np

  from ruamel.yaml import YAML
@@ -28,27 +29,29 @@ from watchdog.utils.dirsnapshot import DirectorySnapshot

  from hpcflow.sdk.core.errors import (
      ContainerKeyError,
-     FromSpecMissingObjectError,
      InvalidIdentifier,
      MissingVariableSubstitutionError,
  )
  from hpcflow.sdk.log import TimeIt
- from hpcflow.sdk.typing import PathLike

+ if TYPE_CHECKING:
+     from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
+     from contextlib import AbstractContextManager
+     from types import ModuleType
+     from typing import Any, IO
+     from typing_extensions import TypeAlias
+     from numpy.typing import NDArray
+     from ..typing import PathLike

- def load_config(func):
-     """API function decorator to ensure the configuration has been loaded, and load if not."""
+ T = TypeVar("T")
+ T2 = TypeVar("T2")
+ T3 = TypeVar("T3")
+ TList: TypeAlias = "T | list[TList]"
+ TD = TypeVar("TD", bound="Mapping[str, Any]")
+ E = TypeVar("E", bound=enum.Enum)

-     @wraps(func)
-     def wrapper(self, *args, **kwargs):
-         if not self.is_config_loaded:
-             self.load_config()
-         return func(self, *args, **kwargs)

-     return wrapper
-
-
- def make_workflow_id():
+ def make_workflow_id() -> str:
      """
      Generate a random ID for a workflow.
      """
@@ -57,14 +60,14 @@ def make_workflow_id():
      return "".join(random.choices(chars, k=length))


- def get_time_stamp():
+ def get_time_stamp() -> str:
      """
      Get the current time in standard string form.
      """
      return datetime.now(timezone.utc).astimezone().strftime("%Y.%m.%d_%H:%M:%S_%z")


- def get_duplicate_items(lst):
+ def get_duplicate_items(lst: Iterable[T]) -> list[T]:
      """Get a list of all items in an iterable that appear more than once, assuming items
      are hashable.

@@ -77,11 +80,10 @@ def get_duplicate_items(lst):
      []

      >>> get_duplicate_items([1, 2, 3, 3, 3, 2])
-     [2, 3, 2]
+     [2, 3]

      """
-     seen = []
-     return list(set(x for x in lst if x in seen or seen.append(x)))
+     return [x for x, y in Counter(lst).items() if y > 1]


  def check_valid_py_identifier(name: str) -> str:
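The rewritten `get_duplicate_items` leans on `collections.Counter`, which preserves first-insertion order on Python 3.7+, so each duplicate is now reported exactly once, in first-seen order — matching the updated doctest. A minimal standalone sketch of the same approach:

    from collections import Counter

    def get_duplicate_items(lst):
        # Counter preserves first-insertion order, so each duplicate is
        # reported once, in the order it was first encountered.
        return [x for x, n in Counter(lst).items() if n > 1]

    assert get_duplicate_items([1, 2, 3]) == []
    assert get_duplicate_items([1, 2, 3, 3, 3, 2]) == [2, 3]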
@@ -105,29 +107,40 @@ def check_valid_py_identifier(name: str) -> str:
      - `Loop.name`

      """
-     exc = InvalidIdentifier(f"Invalid string for identifier: {name!r}")
      try:
          trial_name = name[1:].replace("_", "")  # "internal" underscores are allowed
      except TypeError:
-         raise exc
+         raise InvalidIdentifier(name) from None
+     except KeyError as e:
+         raise KeyError(f"unexpected name type {name}") from e
      if (
          not name
          or not (name[0].isalpha() and ((trial_name[1:] or "a").isalnum()))
          or keyword.iskeyword(name)
      ):
-         raise exc
+         raise InvalidIdentifier(name)

      return name


- def group_by_dict_key_values(lst, *keys):
+ @overload
+ def group_by_dict_key_values(lst: list[dict[T, T2]], key: T) -> list[list[dict[T, T2]]]:
+     ...
+
+
+ @overload
+ def group_by_dict_key_values(lst: list[TD], key: str) -> list[list[TD]]:
+     ...
+
+
+ def group_by_dict_key_values(lst: list, key):
      """Group a list of dicts according to specified equivalent key-values.

      Parameters
      ----------
      lst : list of dict
          The list of dicts to group together.
-     keys : tuple
+     key : key value
          Dicts that have identical values for all of these keys will be grouped together
          into a sub-list.

@@ -146,10 +159,10 @@ def group_by_dict_key_values(lst, *keys):
      for lst_item in lst[1:]:
          for group_idx, group in enumerate(grouped):
              try:
-                 is_vals_equal = all(lst_item[k] == group[0][k] for k in keys)
+                 is_vals_equal = lst_item[key] == group[0][key]

              except KeyError:
-                 # dicts that do not have all `keys` will be in their own group:
+                 # dicts that do not have the `key` will be in their own group:
                  is_vals_equal = False

              if is_vals_equal:
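`group_by_dict_key_values` now groups on a single `key` rather than a `*keys` tuple. An illustrative call, assuming hpcflow is installed (the sample data is hypothetical):

    from hpcflow.sdk.core.utils import group_by_dict_key_values

    lst = [{"a": 1, "b": 2}, {"a": 1, "b": 3}, {"a": 2, "b": 2}]
    # Dicts with equal values for "a" are collected into the same sub-list:
    assert group_by_dict_key_values(lst, "a") == [
        [{"a": 1, "b": 2}, {"a": 1, "b": 3}],
        [{"a": 2, "b": 2}],
    ]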
@@ -162,7 +175,7 @@
      return grouped


- def swap_nested_dict_keys(dct, inner_key):
+ def swap_nested_dict_keys(dct: dict[T, dict[T2, T3]], inner_key: T2):
      """Return a copy where top-level keys have been swapped with a second-level inner key.

      Examples:
@@ -181,16 +194,35 @@ def swap_nested_dict_keys(dct, inner_key):
      }

      """
-     out = {}
+     out: dict[T3, dict[T, dict[T2, T3]]] = {}
      for k, v in copy.deepcopy(dct or {}).items():
-         inner_val = v.pop(inner_key)
-         if inner_val not in out:
-             out[inner_val] = {}
-         out[inner_val][k] = v
+         out.setdefault(v.pop(inner_key), {})[k] = v
      return out


- def get_in_container(cont, path, cast_indices=False, allow_getattr=False):
+ def _ensure_int(path_comp: Any, cur_data: Any, cast_indices: bool) -> int:
+     """
+     Helper for get_in_container() and set_in_container()
+     """
+     if isinstance(path_comp, int):
+         return path_comp
+     if not cast_indices:
+         raise TypeError(
+             f"Path component {path_comp!r} must be an integer index "
+             f"since data is a sequence: {cur_data!r}."
+         )
+     try:
+         return int(path_comp)
+     except (TypeError, ValueError) as e:
+         raise TypeError(
+             f"Path component {path_comp!r} must be an integer index "
+             f"since data is a sequence: {cur_data!r}."
+         ) from e
+
+
+ def get_in_container(
+     cont, path: Sequence, cast_indices: bool = False, allow_getattr: bool = False
+ ):
      """
      Follow a path (sequence of indices of appropriate type) into a container to obtain
      a "leaf" value. Containers can be lists, tuples, dicts,
@@ -203,24 +235,12 @@ def get_in_container(cont, path, cast_indices=False, allow_getattr=False):
      )
      for idx, path_comp in enumerate(path):
          if isinstance(cur_data, (list, tuple)):
-             if not isinstance(path_comp, int):
-                 msg = (
-                     f"Path component {path_comp!r} must be an integer index "
-                     f"since data is a sequence: {cur_data!r}."
-                 )
-                 if cast_indices:
-                     try:
-                         path_comp = int(path_comp)
-                     except TypeError:
-                         raise TypeError(msg)
-                 else:
-                     raise TypeError(msg)
-             cur_data = cur_data[path_comp]
-         elif isinstance(cur_data, dict):
+             cur_data = cur_data[_ensure_int(path_comp, cur_data, cast_indices)]
+         elif isinstance(cur_data, dict) or hasattr(cur_data, "__getitem__"):
              try:
                  cur_data = cur_data[path_comp]
              except KeyError:
-                 raise ContainerKeyError(path=path[: idx + 1])
+                 raise ContainerKeyError(path=cast("list[str]", path[: idx + 1]))
          elif allow_getattr:
              try:
                  cur_data = getattr(cur_data, path_comp)
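The new `_ensure_int` helper centralises index casting for both traversal functions. A behaviour sketch, assuming hpcflow is installed (the data is hypothetical):

    from hpcflow.sdk.core.utils import get_in_container

    data = {"a": [{"b": 10}, {"b": 20}]}
    # String indices into sequences are cast when cast_indices=True:
    assert get_in_container(data, ("a", "1", "b"), cast_indices=True) == 20
    # Without cast_indices, a non-int index into a sequence raises TypeError.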
@@ -235,7 +255,9 @@
      return cur_data


- def set_in_container(cont, path, value, ensure_path=False, cast_indices=False):
+ def set_in_container(
+     cont, path: Sequence, value, ensure_path=False, cast_indices=False
+ ) -> None:
      """
      Follow a path (sequence of indices of appropriate type) into a container to update
      a "leaf" value. Containers can be lists, tuples or dicts.
@@ -258,22 +280,11 @@ def set_in_container(cont, path, value, ensure_path=False, cast_indices=False):
      sub_data = get_in_container(cont, path[:-1], cast_indices=cast_indices)
      path_comp = path[-1]
      if isinstance(sub_data, (list, tuple)):
-         if not isinstance(path_comp, int):
-             msg = (
-                 f"Path component {path_comp!r} must be an integer index "
-                 f"since data is a sequence: {sub_data!r}."
-             )
-             if cast_indices:
-                 try:
-                     path_comp = int(path_comp)
-                 except ValueError:
-                     raise ValueError(msg)
-             else:
-                 raise ValueError(msg)
+         path_comp = _ensure_int(path_comp, sub_data, cast_indices)
      sub_data[path_comp] = value


- def get_relative_path(path1, path2):
+ def get_relative_path(path1: Sequence[T], path2: Sequence[T]) -> Sequence[T]:
      """Get relative path components between two paths.

      Parameters
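One behavioural consequence of sharing `_ensure_int`: `set_in_container` now raises `TypeError` (not `ValueError`) for non-castable sequence indices, matching `get_in_container`. A usage sketch, assuming hpcflow is installed:

    from hpcflow.sdk.core.utils import set_in_container

    data = {"a": [0, 1]}
    set_in_container(data, ["a", "1"], 99, cast_indices=True)
    assert data == {"a": [0, 99]}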
@@ -308,79 +319,39 @@ def get_relative_path(path1, path2):
      """

      len_path2 = len(path2)
-     msg = f"{path1!r} is not in the subpath of {path2!r}."
-
-     if len(path1) < len_path2:
-         raise ValueError(msg)
-
-     for i, j in zip(path1[:len_path2], path2):
-         if i != j:
-             raise ValueError(msg)
+     if len(path1) < len_path2 or any(i != j for i, j in zip(path1[:len_path2], path2)):
+         raise ValueError(f"{path1!r} is not in the subpath of {path2!r}.")

      return path1[len_path2:]


- def search_dir_files_by_regex(pattern, group=0, directory=".") -> List[str]:
+ def search_dir_files_by_regex(
+     pattern: str | re.Pattern[str], directory: str = "."
+ ) -> list[str]:
      """Search recursively for files in a directory by a regex pattern and return matching
      file paths, relative to the given directory."""
-     vals = []
-     for i in Path(directory).rglob("*"):
-         match = re.search(pattern, i.name)
-         if match:
-             match_groups = match.groups()
-             if match_groups:
-                 match = match_groups[group]
-             vals.append(str(i.relative_to(directory)))
-     return vals
-
-
- class classproperty(object):
-     """
-     Simple class property decorator.
-     """
-
-     def __init__(self, f):
-         self.f = f
+     dir_ = Path(directory)
+     return [
+         str(entry.relative_to(dir_))
+         for entry in dir_.rglob("*")
+         if re.search(pattern, entry.name)
+     ]

-     def __get__(self, obj, owner):
-         return self.f(owner)

-
-
- class PrettyPrinter(object):
+ class PrettyPrinter:
      """
      A class that produces a nice readable version of itself with ``str()``.
      Intended to be subclassed.
      """

-     def __str__(self):
+     def __str__(self) -> str:
          lines = [self.__class__.__name__ + ":"]
          for key, val in vars(self).items():
-             lines += f"{key}: {val}".split("\n")
+             lines.extend(f"{key}: {val}".split("\n"))
          return "\n ".join(lines)


- class Singleton(type):
-     """
-     Metaclass that enforces that only one instance can exist of the classes to which it
-     is applied.
-     """
-
-     _instances = {}
-
-     def __call__(cls, *args, **kwargs):
-         if cls not in cls._instances:
-             cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
-         elif args or kwargs:
-             # if existing instance, make the point that new arguments don't do anything!
-             raise ValueError(
-                 f"{cls.__name__!r} is a singleton class and cannot be instantiated with new "
-                 f"arguments. The positional arguments {args!r} and keyword-arguments "
-                 f"{kwargs!r} have been ignored."
-             )
-         return cls._instances[cls]
-
-
- def capitalise_first_letter(chars):
+ def capitalise_first_letter(chars: str) -> str:
      """
      Convert the first character of a string to upper case (if that makes sense).
      The rest of the string is unchanged.
@@ -388,30 +359,11 @@ def capitalise_first_letter(chars):
      return chars[0].upper() + chars[1:]


- def check_in_object_list(spec_name, spec_pos=1, obj_list_pos=2):
-     """Decorator factory for the various `from_spec` class methods that have attributes
-     that should be replaced by an object from an object list."""
-
-     def decorator(func):
-         @wraps(func)
-         def wrap(*args, **kwargs):
-             spec = args[spec_pos]
-             obj_list = args[obj_list_pos]
-             if spec[spec_name] not in obj_list:
-                 cls_name = args[0].__name__
-                 raise FromSpecMissingObjectError(
-                     f"A {spec_name!r} object required to instantiate the {cls_name!r} "
-                     f"object is missing."
-                 )
-             return func(*args, **kwargs)
-
-         return wrap
-
-     return decorator
+ _STRING_VARS_RE = re.compile(r"\<\<var:(.*?)(?:\[(.*)\])?\>\>")


  @TimeIt.decorator
- def substitute_string_vars(string, variables: Dict[str, str] = None):
+ def substitute_string_vars(string: str, variables: dict[str, str]):
      """
      Scan ``string`` and substitute sequences like ``<<var:ABC>>`` with the value
      looked up in the supplied dictionary (with ``ABC`` as the key).
@@ -424,15 +376,14 @@ def substitute_string_vars(string, variables: Dict[str, str] = None):
      >>> substitute_string_vars("abc <<var:def>> ghi", {"def": "123"})
      "abc 123 def"
      """
-     variables = variables or {}

-     def var_repl(match_obj):
-         kwargs = {}
-         var_name, kwargs_str = match_obj.groups()
+     def var_repl(match_obj: re.Match[str]) -> str:
+         kwargs: dict[str, str] = {}
+         var_name: str = match_obj[1]
+         kwargs_str: str | None = match_obj[2]
          if kwargs_str:
-             kwargs_lst = kwargs_str.split(",")
-             for i in kwargs_lst:
-                 k, v = i.strip().split("=")
+             for i in kwargs_str.split(","):
+                 k, v = i.split("=")
                  kwargs[k.strip()] = v.strip()
          try:
              out = str(variables[var_name])
@@ -444,65 +395,69 @@ def substitute_string_vars(string, variables: Dict[str, str] = None):
                  f"variable {var_name!r}."
              )
          else:
-             raise MissingVariableSubstitutionError(
-                 f"The variable {var_name!r} referenced in the string does not match "
-                 f"any of the provided variables: {list(variables)!r}."
-             )
+             raise MissingVariableSubstitutionError(var_name, variables)
          return out

-     new_str = re.sub(
-         pattern=r"\<\<var:(.*?)(?:\[(.*)\])?\>\>",
+     return _STRING_VARS_RE.sub(
          repl=var_repl,
          string=string,
      )
-     return new_str


  @TimeIt.decorator
- def read_YAML_str(yaml_str, typ="safe", variables: Dict[str, str] = None):
+ def read_YAML_str(
+     yaml_str: str, typ="safe", variables: dict[str, str] | None = None
+ ) -> Any:
      """Load a YAML string. This will produce basic objects."""
-     if variables is not False and "<<var:" in yaml_str:
+     if variables is not None and "<<var:" in yaml_str:
          yaml_str = substitute_string_vars(yaml_str, variables=variables)
      yaml = YAML(typ=typ)
      return yaml.load(yaml_str)


  @TimeIt.decorator
- def read_YAML_file(path: PathLike, typ="safe", variables: Dict[str, str] = None):
+ def read_YAML_file(
+     path: PathLike, typ="safe", variables: dict[str, str] | None = None
+ ) -> Any:
      """Load a YAML file. This will produce basic objects."""
      with fsspec.open(path, "rt") as f:
-         yaml_str = f.read()
+         yaml_str: str = f.read()
      return read_YAML_str(yaml_str, typ=typ, variables=variables)


- def write_YAML_file(obj, path: PathLike, typ="safe"):
+ def write_YAML_file(obj, path: str | Path, typ: str = "safe") -> None:
      """Write a basic object to a YAML file."""
      yaml = YAML(typ=typ)
      with Path(path).open("wt") as fp:
          yaml.dump(obj, fp)


- def read_JSON_string(json_str: str, variables: Dict[str, str] = None):
+ def read_JSON_string(json_str: str, variables: dict[str, str] | None = None) -> Any:
      """Load a JSON string. This will produce basic objects."""
-     if variables is not False and "<<var:" in json_str:
+     if variables is not None and "<<var:" in json_str:
          json_str = substitute_string_vars(json_str, variables=variables)
      return json.loads(json_str)


- def read_JSON_file(path, variables: Dict[str, str] = None):
+ def read_JSON_file(path, variables: dict[str, str] | None = None) -> Any:
      """Load a JSON file. This will produce basic objects."""
      with fsspec.open(path, "rt") as f:
-         json_str = f.read()
+         json_str: str = f.read()
      return read_JSON_string(json_str, variables=variables)


- def write_JSON_file(obj, path: PathLike):
+ def write_JSON_file(obj, path: str | Path) -> None:
      """Write a basic object to a JSON file."""
      with Path(path).open("wt") as fp:
          json.dump(obj, fp)


- def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
+ def get_item_repeat_index(
+     lst: Sequence[T],
+     *,
+     distinguish_singular: bool = False,
+     item_callable: Callable[[T], Hashable] | None = None,
+ ):
      """Get the repeat index for each item in a list.

      Parameters
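In this version `variables` is no longer optional in `substitute_string_vars`, and the readers gate substitution on `variables is not None`, so templates pass through untouched by default. A behaviour sketch, assuming hpcflow is installed (note the doctest's shown output above appears to be a typo: the trailing "ghi" sits outside the placeholder and is preserved):

    from hpcflow.sdk.core.utils import read_YAML_str, substitute_string_vars

    assert substitute_string_vars("abc <<var:def>> ghi", {"def": "123"}) == "abc 123 ghi"
    # With variables=None (the default), templates are left untouched:
    assert read_YAML_str("x: <<var:y>>") == {"x": "<<var:y>>"}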
@@ -510,10 +465,10 @@ def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
      lst : list
          Must contain hashable items, or hashable objects that are returned via `callable`
          called on each item.
-     distinguish_singular : bool, optional
+     distinguish_singular : bool
          If True, items that are not repeated will have a repeat index of 0, and items that
          are repeated will have repeat indices starting from 1.
-     item_callable : callable, optional
+     item_callable : callable
          If specified, comparisons are made on the output of this callable on each item.

      Returns
@@ -523,16 +478,16 @@

      """

-     idx = {}
-     for i_idx, item in enumerate(lst):
-         if item_callable:
-             item = item_callable(item)
-         if item not in idx:
-             idx[item] = []
-         idx[item] += [i_idx]
+     idx: dict[Any, list[int]] = {}
+     if item_callable:
+         for i_idx, item in enumerate(lst):
+             idx.setdefault(item_callable(item), []).append(i_idx)
+     else:
+         for i_idx, item in enumerate(lst):
+             idx.setdefault(item, []).append(i_idx)

-     rep_idx = [None] * len(lst)
-     for k, v in idx.items():
+     rep_idx = [0] * len(lst)
+     for v in idx.values():
          start = len(v) > 1 if distinguish_singular else 0
          for i_idx, i in enumerate(v, start):
              rep_idx[i] = i_idx
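The options are now keyword-only (the bare `*` in the new signature), and the result list is initialised with `0` rather than `None`. For illustration, assuming hpcflow is installed:

    from hpcflow.sdk.core.utils import get_item_repeat_index

    # Plain repeat indices, starting at 0 per distinct item:
    assert get_item_repeat_index(["a", "b", "a"]) == [0, 0, 1]
    # With distinguish_singular=True, non-repeated items stay at 0 and
    # repeated items are numbered from 1:
    assert get_item_repeat_index(["a", "b", "a"], distinguish_singular=True) == [1, 0, 2]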
@@ -540,7 +495,7 @@ def get_item_repeat_index(lst, distinguish_singular=False, item_callable=None):
      return rep_idx


- def get_process_stamp():
+ def get_process_stamp() -> str:
      """
      Return a globally unique string identifying this process.

@@ -555,15 +510,17 @@ def get_process_stamp():
      )


- def remove_ansi_escape_sequences(string):
+ _ANSI_ESCAPE_RE = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
+
+
+ def remove_ansi_escape_sequences(string: str) -> str:
      """
      Strip ANSI terminal escape codes from a string.
      """
-     ansi_escape = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
-     return ansi_escape.sub("", string)
+     return _ANSI_ESCAPE_RE.sub("", string)


- def get_md5_hash(obj):
+ def get_md5_hash(obj) -> str:
      """
      Compute the MD5 hash of an object.
      This is the hash of the JSON of the object (with sorted keys) as a hex string.
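Hoisting the ANSI pattern to a module-level `_ANSI_ESCAPE_RE` avoids recompiling on every call (`re` caches compiled patterns, but the constant skips the cache lookup entirely). Behaviour is unchanged, as this stdlib-only sketch shows:

    import re

    _ANSI_ESCAPE_RE = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")

    # "\x1b[31m" (set red) and "\x1b[0m" (reset) are both stripped:
    assert _ANSI_ESCAPE_RE.sub("", "\x1b[31mred\x1b[0m") == "red"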
@@ -572,7 +529,9 @@ def get_md5_hash(obj):
      return hashlib.md5(json_str.encode("utf-8")).hexdigest()


- def get_nested_indices(idx, size, nest_levels, raise_on_rollover=False):
+ def get_nested_indices(
+     idx: int, size: int, nest_levels: int, raise_on_rollover: bool = False
+ ) -> list[int]:
      """Generate the set of nested indices of length `n` that correspond to a global
      `idx`.

@@ -618,34 +577,34 @@ def get_nested_indices(idx, size, nest_levels, raise_on_rollover=False):
      return [(idx // (size ** (nest_levels - (i + 1)))) % size for i in range(nest_levels)]


- def ensure_in(item, lst) -> int:
+ def ensure_in(item: T, lst: list[T]) -> int:
      """Get the index of an item in a list and append the item if it is not in the
      list."""
      # TODO: add tests
      try:
-         idx = lst.index(item)
+         return lst.index(item)
      except ValueError:
          lst.append(item)
-         idx = len(lst) - 1
-     return idx
+         return len(lst) - 1


- def list_to_dict(lst, exclude=None):
+ def list_to_dict(
+     lst: Sequence[Mapping[T, T2]], exclude: Iterable[T] | None = None
+ ) -> dict[T, list[T2]]:
      """
      Convert a list of dicts to a dict of lists.
      """
-     # TODD: test
-     exclude = exclude or []
-     dct = {k: [] for k in lst[0].keys() if k not in exclude}
-     for i in lst:
-         for k, v in i.items():
-             if k not in exclude:
+     # TODO: test
+     exc = frozenset(exclude or ())
+     dct: dict[T, list[T2]] = {k: [] for k in lst[0] if k not in exc}
+     for d in lst:
+         for k, v in d.items():
+             if k not in exc:
                  dct[k].append(v)
-
      return dct


- def bisect_slice(selection: slice, len_A: int):
+ def bisect_slice(selection: slice, len_A: int) -> tuple[slice, slice]:
      """Given two sequences (the first of which of known length), get the two slices that
      are equivalent to a given slice if the two sequences were combined."""

@@ -660,32 +619,33 @@ def bisect_slice(selection: slice, len_A: int):
          B_stop = B_start
      else:
          B_stop = selection.stop - len_A
-     B_idx = (B_start, B_stop, selection.step)
-     A_slice = slice(*A_idx)
-     B_slice = slice(*B_idx)

-     return A_slice, B_slice
+     return slice(*A_idx), slice(B_start, B_stop, selection.step)


- def replace_items(lst, start, end, repl):
+ def replace_items(lst: list[T], start: int, end: int, repl: list[T]) -> list[T]:
      """Replaced a range of items in a list with items in another list."""
-     if end <= start:
+     # Convert to actual indices for our safety checks; handles end-relative addressing
+     real_start, real_end, _ = slice(start, end).indices(len(lst))
+     if real_end <= real_start:
          raise ValueError(
              f"`end` ({end}) must be greater than or equal to `start` ({start})."
          )
-     if start >= len(lst):
+     if real_start >= len(lst):
          raise ValueError(f"`start` ({start}) must be less than length ({len(lst)}).")
-     if end > len(lst):
+     if real_end > len(lst):
          raise ValueError(
              f"`end` ({end}) must be less than or equal to length ({len(lst)})."
          )

-     lst_a = lst[:start]
-     lst_b = lst[end:]
-     return lst_a + repl + lst_b
+     lst = list(lst)
+     lst[start:end] = repl
+     return lst


- def flatten(lst):
+ def flatten(
+     lst: list[int] | list[list[int]] | list[list[list[int]]],
+ ) -> tuple[list[int], tuple[list[int], ...]]:
      """Flatten an arbitrarily (but of uniform depth) nested list and return shape
      information to enable un-flattening.

@@ -700,108 +660,156 @@ def flatten(lst):

      """

-     def _flatten(lst, _depth=0):
-         out = []
-         for i in lst:
-             if isinstance(i, list):
-                 out += _flatten(i, _depth=_depth + 1)
-                 all_lens[_depth].append(len(i))
+     def _flatten(
+         lst: list[int] | list[list[int]] | list[list[list[int]]], depth=0
+     ) -> list[int]:
+         out: list[int] = []
+         for item in lst:
+             if isinstance(item, list):
+                 out.extend(_flatten(item, depth + 1))
+                 all_lens[depth].append(len(item))
              else:
-                 out.append(i)
+                 out.append(item)
          return out

-     def _get_max_depth(lst):
-         lst = lst[:]
+     def _get_max_depth(lst: list[int] | list[list[int]] | list[list[list[int]]]) -> int:
+         val: Any = lst
          max_depth = 0
-         while isinstance(lst, list):
+         while isinstance(val, list):
              max_depth += 1
              try:
-                 lst = lst[0]
+                 val = val[0]
              except IndexError:
                  # empty list, assume this is max depth
                  break
          return max_depth

      max_depth = _get_max_depth(lst) - 1
-     all_lens = tuple([] for _ in range(max_depth))
+     all_lens: tuple[list[int], ...] = tuple([] for _ in range(max_depth))

      return _flatten(lst), all_lens


- def reshape(lst, lens):
+ def reshape(lst: Sequence[T], lens: Sequence[Sequence[int]]) -> list[TList[T]]:
      """
      Reverse the destructuring of the :py:func:`flatten` function.
      """

-     def _reshape(lst, lens):
-         lens_acc = [0] + list(accumulate(lens))
-         lst_rs = [lst[lens_acc[idx] : lens_acc[idx + 1]] for idx in range(len(lens))]
-         return lst_rs
+     def _reshape(lst: list[T2], lens: Sequence[int]) -> list[list[T2]]:
+         lens_acc = [0, *accumulate(lens)]
+         return [lst[lens_acc[idx] : lens_acc[idx + 1]] for idx in range(len(lens))]

+     result: list[TList[T]] = list(lst)
      for lens_i in lens[::-1]:
-         lst = _reshape(lst, lens_i)
+         result = cast("list[TList[T]]", _reshape(result, lens_i))

-     return lst
+     return result
+
+
+ @overload
+ def remap(
+     lst: list[int], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+ ) -> list[T]:
+     ...
+
+
+ @overload
+ def remap(
+     lst: list[list[int]], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+ ) -> list[list[T]]:
+     ...
+
+
+ @overload
+ def remap(
+     lst: list[list[list[int]]], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+ ) -> list[list[list[T]]]:
+     ...
+
+
+ def remap(lst, mapping_func):
+     """
+     Apply a mapping to a structure of lists with ints (typically indices) as leaves to
+     get a structure of lists with some objects as leaves.
+
+     Parameters
+     ----------
+     lst: list[int] | list[list[int]] | list[list[list[int]]]
+         The structure to remap.
+     mapping_func: Callable[[Sequence[int]], Sequence[T]]
+         The mapping function from sequences of ints to sequences of objects.
+
+     Returns
+     -------
+     list[T] | list[list[T]] | list[list[list[T]]]
+         Nested list structure in same form as input, with leaves remapped.
+     """
+     x, y = flatten(lst)
+     return reshape(mapping_func(x), y)
+
+
+ _FSSPEC_URL_RE = re.compile(r"(?:[a-z0-9]+:{1,2})+\/\/")


  def is_fsspec_url(url: str) -> bool:
      """
      Test if a URL appears to be one that can be understood by fsspec.
      """
-     return bool(re.match(r"(?:[a-z0-9]+:{1,2})+\/\/", url))
+     return bool(_FSSPEC_URL_RE.match(url))


  class JSONLikeDirSnapShot(DirectorySnapshot):
      """
      Overridden DirectorySnapshot from watchdog to allow saving and loading from JSON.
-     """

-     def __init__(self, root_path=None, data=None):
-         """Create an empty snapshot or load from JSON-like data.
+     Parameters
+     ----------
+     root_path: str
+         Where to take the snapshot based at.
+     data: dict[str, list]
+         Serialised snapshot to reload from.
+         See :py:meth:`to_json_like`.
+     """

-         Parameters
-         ----------
-         root_path: str
-             Where to take the snapshot based at.
-         data: dict
-             Serialised snapshot to reload from.
-             See :py:meth:`to_json_like`.
+     def __init__(self, root_path: str | None = None, data: dict[str, list] | None = None):
+         """
+         Create an empty snapshot or load from JSON-like data.
          """

          #: Where to take the snapshot based at.
          self.root_path = root_path
-         self._stat_info = {}
-         self._inode_to_path = {}
+         self._stat_info: dict[str, os.stat_result] = {}
+         self._inode_to_path: dict[tuple, str] = {}

          if data:
-             for k in list((data or {}).keys()):
+             assert root_path
+             for name, item in data.items():
                  # add root path
-                 full_k = str(PurePath(root_path) / PurePath(k))
-                 stat_dat, inode_key = data[k][:-2], data[k][-2:]
-                 self._stat_info[full_k] = os.stat_result(stat_dat)
-                 self._inode_to_path[tuple(inode_key)] = full_k
+                 full_name = str(PurePath(root_path) / PurePath(name))
+                 stat_dat, inode_key = item[:-2], item[-2:]
+                 self._stat_info[full_name] = os.stat_result(stat_dat)
+                 self._inode_to_path[tuple(inode_key)] = full_name

-     def take(self, *args, **kwargs):
+     def take(self, *args, **kwargs) -> None:
          """Take the snapshot."""
          super().__init__(*args, **kwargs)

-     def to_json_like(self):
+     def to_json_like(self) -> dict[str, Any]:
          """Export to a dict that is JSON-compatible and can be later reloaded.

          The last two integers in `data` for each path are the keys in
          `self._inode_to_path`.

          """
-
          # first key is the root path:
-         root_path = next(iter(self._stat_info.keys()))
+         root_path = next(iter(self._stat_info))

          # store efficiently:
          inode_invert = {v: k for k, v in self._inode_to_path.items()}
-         data = {}
-         for k, v in self._stat_info.items():
-             k_rel = str(PurePath(k).relative_to(root_path))
-             data[k_rel] = list(v) + list(inode_invert[k])
+         data: dict[str, list] = {
+             str(PurePath(k).relative_to(root_path)): [*v, *inode_invert[k]]
+             for k, v in self._stat_info.items()
+         }

          return {
              "root_path": root_path,
@@ -809,7 +817,7 @@ class JSONLikeDirSnapShot(DirectorySnapshot):
      }


- def open_file(filename):
+ def open_file(filename: str | Path):
      """Open a file or directory using the default system application."""
      if sys.platform == "win32":
          os.startfile(filename)
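`JSONLikeDirSnapShot` separates construction from snapshotting, so an instance can be rebuilt from its own serialised form. A round-trip sketch, assuming hpcflow is installed and that the serialised dict's keys mirror the constructor parameters (as the visible hunk suggests):

    from hpcflow.sdk.core.utils import JSONLikeDirSnapShot

    snap = JSONLikeDirSnapShot()
    snap.take(".")  # delegates to DirectorySnapshot.__init__
    serialised = snap.to_json_like()  # {"root_path": ..., "data": {...}}
    restored = JSONLikeDirSnapShot(**serialised)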
@@ -818,59 +826,76 @@ def open_file(filename):
          subprocess.call([opener, filename])


- def get_enum_by_name_or_val(enum_cls: Type, key: Union[str, None]) -> enum.Enum:
+ @overload
+ def get_enum_by_name_or_val(enum_cls: type[E], key: None) -> None:
+     ...
+
+
+ @overload
+ def get_enum_by_name_or_val(enum_cls: type[E], key: str | int | float | E) -> E:
+     ...
+
+
+ def get_enum_by_name_or_val(
+     enum_cls: type[E], key: str | int | float | E | None
+ ) -> E | None:
      """Retrieve an enum by name or value, assuming uppercase names and integer values."""
-     err = f"Unknown enum key or value {key!r} for class {enum_cls!r}"
      if key is None or isinstance(key, enum_cls):
          return key
      elif isinstance(key, (int, float)):
          return enum_cls(int(key))  # retrieve by value
      elif isinstance(key, str):
          try:
-             return getattr(enum_cls, key.upper())  # retrieve by name
+             return cast("E", getattr(enum_cls, key.upper()))  # retrieve by name
          except AttributeError:
-             raise ValueError(err)
-     else:
-         raise ValueError(err)
+             pass
+     raise ValueError(f"Unknown enum key or value {key!r} for class {enum_cls!r}")
+

+ _PARAM_SPLIT_RE = re.compile(r"((?:\w|\.)+)(?:\[(\w+)\])?")

- def split_param_label(param_path: str) -> Tuple[Union[str, None]]:
+
+ def split_param_label(param_path: str) -> tuple[str, str] | tuple[None, None]:
      """Split a parameter path into the path and the label, if present."""
-     pattern = r"((?:\w|\.)+)(?:\[(\w+)\])?"
-     match = re.match(pattern, param_path)
-     return match.group(1), match.group(2)
+     if match := _PARAM_SPLIT_RE.match(param_path):
+         return match[1], match[2]
+     else:
+         return None, None


- def process_string_nodes(data, str_processor):
+ def process_string_nodes(data: T, str_processor: Callable[[str], str]) -> T:
      """Walk through a nested data structure and process string nodes using a provided
      callable."""

      if isinstance(data, dict):
-         for k, v in data.items():
-             data[k] = process_string_nodes(v, str_processor)
+         return cast(
+             "T", {k: process_string_nodes(v, str_processor) for k, v in data.items()}
+         )

-     elif isinstance(data, (list, tuple, set)):
-         _data = [process_string_nodes(i, str_processor) for i in data]
+     elif isinstance(data, (list, tuple, set, frozenset)):
+         _data = (process_string_nodes(i, str_processor) for i in data)
          if isinstance(data, tuple):
-             data = tuple(_data)
+             return cast("T", tuple(_data))
          elif isinstance(data, set):
-             data = set(_data)
+             return cast("T", set(_data))
+         elif isinstance(data, frozenset):
+             return cast("T", frozenset(_data))
          else:
-             data = _data
+             return cast("T", list(_data))

      elif isinstance(data, str):
-         data = str_processor(data)
+         return cast("T", str_processor(data))

      return data


  def linspace_rect(
-     start: List[float],
-     stop: List[float],
-     num: List[float],
-     include: Optional[List[str]] = None,
+     start: Sequence[float],
+     stop: Sequence[float],
+     num: Sequence[int],
+     include: Sequence[str] | None = None,
      **kwargs,
- ):
+ ) -> NDArray:
      """Generate a linear space around a rectangle.

      Parameters
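The `@overload` pair on `get_enum_by_name_or_val` encodes that `None` maps to `None` while any other key yields a member, so type checkers no longer see an `Optional` result for non-None keys. For illustration, assuming hpcflow is installed (the enum is hypothetical):

    import enum

    from hpcflow.sdk.core.utils import get_enum_by_name_or_val

    class Colour(enum.Enum):
        RED = 1
        BLUE = 2

    assert get_enum_by_name_or_val(Colour, "red") is Colour.RED  # by (uppercased) name
    assert get_enum_by_name_or_val(Colour, 2) is Colour.BLUE     # by integer value
    assert get_enum_by_name_or_val(Colour, None) is None         # passed through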
@@ -892,19 +917,18 @@

      """

-     if num[0] == 1 or num[1] == 1:
+     if num[0] <= 1 or num[1] <= 1:
          raise ValueError("Both values in `num` must be greater than 1.")

-     if not include:
-         include = ["top", "right", "bottom", "left"]
+     inc = set(include) if include else {"top", "right", "bottom", "left"}

      c0_range = np.linspace(start=start[0], stop=stop[0], num=num[0], **kwargs)
      c1_range_all = np.linspace(start=start[1], stop=stop[1], num=num[1], **kwargs)

      c1_range = c1_range_all
-     if "bottom" in include:
+     if "bottom" in inc:
          c1_range = c1_range[1:]
-     if "top" in include:
+     if "top" in inc:
          c1_range = c1_range[:-1]

      c0_range_c1_start = np.vstack([c0_range, np.repeat(start[1], num[0])])
@@ -914,20 +938,21 @@
      c1_range_c0_stop = np.vstack([np.repeat(c0_range[-1], len(c1_range)), c1_range])

      stacked = []
-     if "top" in include:
+     if "top" in inc:
          stacked.append(c0_range_c1_stop)
-     if "right" in include:
+     if "right" in inc:
          stacked.append(c1_range_c0_stop)
-     if "bottom" in include:
+     if "bottom" in inc:
          stacked.append(c0_range_c1_start)
-     if "left" in include:
+     if "left" in inc:
          stacked.append(c1_range_c0_start)

-     rect = np.hstack(stacked)
-     return rect
+     return np.hstack(stacked)


- def dict_values_process_flat(d, callable):
+ def dict_values_process_flat(
+     d: Mapping[T, T2 | list[T2]], callable: Callable[[list[T2]], list[T3]]
+ ) -> Mapping[T, T3 | list[T3]]:
      """
      Return a copy of a dict, where the values are processed by a callable that is to
      be called only once, and where the values may be single items or lists of items.
@@ -939,32 +964,34 @@
      {'a': 1, 'b': [2, 3], 'c': 6}

      """
-     flat = []  # values of `d`, flattened
-     is_multi = []  # whether a list, and the number of items to process
+     flat: list[T2] = []  # values of `d`, flattened
+     is_multi: list[
+         tuple[bool, int]
+     ] = []  # whether a list, and the number of items to process
      for i in d.values():
-         try:
-             flat.extend(i)
+         if isinstance(i, list):
+             flat.extend(cast("list[T2]", i))
              is_multi.append((True, len(i)))
-         except TypeError:
-             flat.append(i)
+         else:
+             flat.append(cast("T2", i))
              is_multi.append((False, 1))

      processed = callable(flat)

-     out = {}
-     for idx_i, (m, k) in enumerate(zip(is_multi, d.keys())):
-
+     out: dict[T, T3 | list[T3]] = {}
+     for idx_i, (m, k) in enumerate(zip(is_multi, d)):
          start_idx = sum(i[1] for i in is_multi[:idx_i])
          end_idx = start_idx + m[1]
-         proc_idx_k = processed[slice(start_idx, end_idx)]
+         proc_idx_k = processed[start_idx:end_idx]
          if not m[0]:
-             proc_idx_k = proc_idx_k[0]
-         out[k] = proc_idx_k
+             out[k] = proc_idx_k[0]
+         else:
+             out[k] = proc_idx_k

      return out


- def nth_key(dct, n):
+ def nth_key(dct: Iterable[T], n: int) -> T:
      """
      Given a dict in some order, get the n'th key of that dict.
      """
@@ -973,8 +1000,86 @@ def nth_key(dct, n):
      return next(it)


- def nth_value(dct, n):
+ def nth_value(dct: dict[Any, T], n: int) -> T:
      """
      Given a dict in some order, get the n'th value of that dict.
      """
      return dct[nth_key(dct, n)]
+
+
+ def normalise_timestamp(timestamp: datetime) -> datetime:
+     """
+     Force a timestamp to have UTC as its timezone,
+     then convert to use the local timezone.
+     """
+     return timestamp.replace(tzinfo=timezone.utc).astimezone()
+
+
+ def parse_timestamp(timestamp: str | datetime, ts_fmt: str) -> datetime:
+     """
+     Standard timestamp parsing.
+     Ensures that timestamps are internally all UTC.
+     """
+     return normalise_timestamp(
+         timestamp
+         if isinstance(timestamp, datetime)
+         else datetime.strptime(timestamp, ts_fmt)
+     )
+
+
+ def current_timestamp() -> datetime:
+     """
+     Get a UTC timestamp for the current time
+     """
+     return datetime.now(timezone.utc)
+
+
+ def timedelta_format(td: timedelta) -> str:
+     """
+     Convert time delta to string in standard form.
+     """
+     days, seconds = td.days, td.seconds
+     hours = seconds // (60 * 60)
+     seconds -= hours * (60 * 60)
+     minutes = seconds // 60
+     seconds -= minutes * 60
+     return f"{days}-{hours:02}:{minutes:02}:{seconds:02}"
+
+
+ _TD_RE = re.compile(r"(\d+)-(\d+):(\d+):(\d+)")
+
+
+ def timedelta_parse(td_str: str) -> timedelta:
+     """
+     Parse a string in standard form as a time delta.
+     """
+     if not (m := _TD_RE.fullmatch(td_str)):
+         raise ValueError("not a supported timedelta form")
+     days, hours, mins, secs = map(int, m.groups())
+     return timedelta(days=days, hours=hours, minutes=mins, seconds=secs)
+
+
+ def open_text_resource(package: ModuleType | str, resource: str) -> IO[str]:
+     """
+     Open a file in a package.
+     """
+     try:
+         return resources.files(package).joinpath(resource).open("r")
+     except AttributeError:
+         # < python 3.9; `resource.open_text` deprecated since 3.11
+         return resources.open_text(package, resource)
+
+
+ def get_file_context(
+     package: ModuleType | str, src: str | None = None
+ ) -> AbstractContextManager[Path]:
+     """
+     Find a file or directory in a package.
+     """
+     try:
+         files = resources.files(package)
+         return resources.as_file(files.joinpath(src) if src else files)
+     # raises ModuleNotFoundError
+     except AttributeError:
+         # < python 3.9
+         return resources.path(package, src or "")
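The new time helpers fix timedeltas in a `D-HH:MM:SS` layout (resembling the day-hour:minute:second form used by schedulers such as SLURM), and `timedelta_parse` is the exact inverse of `timedelta_format`. A round-trip sketch, assuming hpcflow is installed:

    from datetime import timedelta

    from hpcflow.sdk.core.utils import timedelta_format, timedelta_parse

    td = timedelta(days=1, hours=2, minutes=3, seconds=4)
    assert timedelta_format(td) == "1-02:03:04"
    assert timedelta_parse("1-02:03:04") == td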