hpcflow 0.1.15__py3-none-any.whl → 0.2.0a271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (275)
  1. hpcflow/__init__.py +2 -11
  2. hpcflow/__pyinstaller/__init__.py +5 -0
  3. hpcflow/__pyinstaller/hook-hpcflow.py +40 -0
  4. hpcflow/_version.py +1 -1
  5. hpcflow/app.py +43 -0
  6. hpcflow/cli.py +2 -461
  7. hpcflow/data/demo_data_manifest/__init__.py +3 -0
  8. hpcflow/data/demo_data_manifest/demo_data_manifest.json +6 -0
  9. hpcflow/data/jinja_templates/test/test_template.txt +8 -0
  10. hpcflow/data/programs/hello_world/README.md +1 -0
  11. hpcflow/data/programs/hello_world/hello_world.c +87 -0
  12. hpcflow/data/programs/hello_world/linux/hello_world +0 -0
  13. hpcflow/data/programs/hello_world/macos/hello_world +0 -0
  14. hpcflow/data/programs/hello_world/win/hello_world.exe +0 -0
  15. hpcflow/data/scripts/__init__.py +1 -0
  16. hpcflow/data/scripts/bad_script.py +2 -0
  17. hpcflow/data/scripts/demo_task_1_generate_t1_infile_1.py +8 -0
  18. hpcflow/data/scripts/demo_task_1_generate_t1_infile_2.py +8 -0
  19. hpcflow/data/scripts/demo_task_1_parse_p3.py +7 -0
  20. hpcflow/data/scripts/do_nothing.py +2 -0
  21. hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
  22. hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
  23. hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
  24. hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
  25. hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
  26. hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
  27. hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
  28. hpcflow/data/scripts/generate_t1_file_01.py +7 -0
  29. hpcflow/data/scripts/import_future_script.py +7 -0
  30. hpcflow/data/scripts/input_file_generator_basic.py +3 -0
  31. hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
  32. hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
  33. hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
  34. hpcflow/data/scripts/main_script_test_direct_in_direct_out.py +6 -0
  35. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
  36. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
  37. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
  38. hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
  39. hpcflow/data/scripts/main_script_test_direct_in_direct_out_all_iters_test.py +15 -0
  40. hpcflow/data/scripts/main_script_test_direct_in_direct_out_env_spec.py +7 -0
  41. hpcflow/data/scripts/main_script_test_direct_in_direct_out_labels.py +8 -0
  42. hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
  43. hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
  44. hpcflow/data/scripts/main_script_test_direct_sub_param_in_direct_out.py +6 -0
  45. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +12 -0
  46. hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
  47. hpcflow/data/scripts/main_script_test_hdf5_in_obj_group.py +12 -0
  48. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +11 -0
  49. hpcflow/data/scripts/main_script_test_json_and_direct_in_json_out.py +14 -0
  50. hpcflow/data/scripts/main_script_test_json_in_json_and_direct_out.py +17 -0
  51. hpcflow/data/scripts/main_script_test_json_in_json_out.py +14 -0
  52. hpcflow/data/scripts/main_script_test_json_in_json_out_labels.py +16 -0
  53. hpcflow/data/scripts/main_script_test_json_in_obj.py +12 -0
  54. hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
  55. hpcflow/data/scripts/main_script_test_json_out_obj.py +10 -0
  56. hpcflow/data/scripts/main_script_test_json_sub_param_in_json_out_labels.py +16 -0
  57. hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
  58. hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
  59. hpcflow/data/scripts/output_file_parser_basic.py +3 -0
  60. hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
  61. hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
  62. hpcflow/data/scripts/parse_t1_file_01.py +4 -0
  63. hpcflow/data/scripts/script_exit_test.py +5 -0
  64. hpcflow/data/template_components/__init__.py +1 -0
  65. hpcflow/data/template_components/command_files.yaml +26 -0
  66. hpcflow/data/template_components/environments.yaml +13 -0
  67. hpcflow/data/template_components/parameters.yaml +14 -0
  68. hpcflow/data/template_components/task_schemas.yaml +139 -0
  69. hpcflow/data/workflows/workflow_1.yaml +5 -0
  70. hpcflow/examples.ipynb +1037 -0
  71. hpcflow/sdk/__init__.py +149 -0
  72. hpcflow/sdk/app.py +4266 -0
  73. hpcflow/sdk/cli.py +1479 -0
  74. hpcflow/sdk/cli_common.py +385 -0
  75. hpcflow/sdk/config/__init__.py +5 -0
  76. hpcflow/sdk/config/callbacks.py +246 -0
  77. hpcflow/sdk/config/cli.py +388 -0
  78. hpcflow/sdk/config/config.py +1410 -0
  79. hpcflow/sdk/config/config_file.py +501 -0
  80. hpcflow/sdk/config/errors.py +272 -0
  81. hpcflow/sdk/config/types.py +150 -0
  82. hpcflow/sdk/core/__init__.py +38 -0
  83. hpcflow/sdk/core/actions.py +3857 -0
  84. hpcflow/sdk/core/app_aware.py +25 -0
  85. hpcflow/sdk/core/cache.py +224 -0
  86. hpcflow/sdk/core/command_files.py +814 -0
  87. hpcflow/sdk/core/commands.py +424 -0
  88. hpcflow/sdk/core/element.py +2071 -0
  89. hpcflow/sdk/core/enums.py +221 -0
  90. hpcflow/sdk/core/environment.py +256 -0
  91. hpcflow/sdk/core/errors.py +1043 -0
  92. hpcflow/sdk/core/execute.py +207 -0
  93. hpcflow/sdk/core/json_like.py +809 -0
  94. hpcflow/sdk/core/loop.py +1320 -0
  95. hpcflow/sdk/core/loop_cache.py +282 -0
  96. hpcflow/sdk/core/object_list.py +933 -0
  97. hpcflow/sdk/core/parameters.py +3371 -0
  98. hpcflow/sdk/core/rule.py +196 -0
  99. hpcflow/sdk/core/run_dir_files.py +57 -0
  100. hpcflow/sdk/core/skip_reason.py +7 -0
  101. hpcflow/sdk/core/task.py +3792 -0
  102. hpcflow/sdk/core/task_schema.py +993 -0
  103. hpcflow/sdk/core/test_utils.py +538 -0
  104. hpcflow/sdk/core/types.py +447 -0
  105. hpcflow/sdk/core/utils.py +1207 -0
  106. hpcflow/sdk/core/validation.py +87 -0
  107. hpcflow/sdk/core/values.py +477 -0
  108. hpcflow/sdk/core/workflow.py +4820 -0
  109. hpcflow/sdk/core/zarr_io.py +206 -0
  110. hpcflow/sdk/data/__init__.py +13 -0
  111. hpcflow/sdk/data/config_file_schema.yaml +34 -0
  112. hpcflow/sdk/data/config_schema.yaml +260 -0
  113. hpcflow/sdk/data/environments_spec_schema.yaml +21 -0
  114. hpcflow/sdk/data/files_spec_schema.yaml +5 -0
  115. hpcflow/sdk/data/parameters_spec_schema.yaml +7 -0
  116. hpcflow/sdk/data/task_schema_spec_schema.yaml +3 -0
  117. hpcflow/sdk/data/workflow_spec_schema.yaml +22 -0
  118. hpcflow/sdk/demo/__init__.py +3 -0
  119. hpcflow/sdk/demo/cli.py +242 -0
  120. hpcflow/sdk/helper/__init__.py +3 -0
  121. hpcflow/sdk/helper/cli.py +137 -0
  122. hpcflow/sdk/helper/helper.py +300 -0
  123. hpcflow/sdk/helper/watcher.py +192 -0
  124. hpcflow/sdk/log.py +288 -0
  125. hpcflow/sdk/persistence/__init__.py +18 -0
  126. hpcflow/sdk/persistence/base.py +2817 -0
  127. hpcflow/sdk/persistence/defaults.py +6 -0
  128. hpcflow/sdk/persistence/discovery.py +39 -0
  129. hpcflow/sdk/persistence/json.py +954 -0
  130. hpcflow/sdk/persistence/pending.py +948 -0
  131. hpcflow/sdk/persistence/store_resource.py +203 -0
  132. hpcflow/sdk/persistence/types.py +309 -0
  133. hpcflow/sdk/persistence/utils.py +73 -0
  134. hpcflow/sdk/persistence/zarr.py +2388 -0
  135. hpcflow/sdk/runtime.py +320 -0
  136. hpcflow/sdk/submission/__init__.py +3 -0
  137. hpcflow/sdk/submission/enums.py +70 -0
  138. hpcflow/sdk/submission/jobscript.py +2379 -0
  139. hpcflow/sdk/submission/schedulers/__init__.py +281 -0
  140. hpcflow/sdk/submission/schedulers/direct.py +233 -0
  141. hpcflow/sdk/submission/schedulers/sge.py +376 -0
  142. hpcflow/sdk/submission/schedulers/slurm.py +598 -0
  143. hpcflow/sdk/submission/schedulers/utils.py +25 -0
  144. hpcflow/sdk/submission/shells/__init__.py +52 -0
  145. hpcflow/sdk/submission/shells/base.py +229 -0
  146. hpcflow/sdk/submission/shells/bash.py +504 -0
  147. hpcflow/sdk/submission/shells/os_version.py +115 -0
  148. hpcflow/sdk/submission/shells/powershell.py +352 -0
  149. hpcflow/sdk/submission/submission.py +1402 -0
  150. hpcflow/sdk/submission/types.py +140 -0
  151. hpcflow/sdk/typing.py +194 -0
  152. hpcflow/sdk/utils/arrays.py +69 -0
  153. hpcflow/sdk/utils/deferred_file.py +55 -0
  154. hpcflow/sdk/utils/hashing.py +16 -0
  155. hpcflow/sdk/utils/patches.py +31 -0
  156. hpcflow/sdk/utils/strings.py +69 -0
  157. hpcflow/tests/api/test_api.py +32 -0
  158. hpcflow/tests/conftest.py +123 -0
  159. hpcflow/tests/data/__init__.py +0 -0
  160. hpcflow/tests/data/benchmark_N_elements.yaml +6 -0
  161. hpcflow/tests/data/benchmark_script_runner.yaml +26 -0
  162. hpcflow/tests/data/multi_path_sequences.yaml +29 -0
  163. hpcflow/tests/data/workflow_1.json +10 -0
  164. hpcflow/tests/data/workflow_1.yaml +5 -0
  165. hpcflow/tests/data/workflow_1_slurm.yaml +8 -0
  166. hpcflow/tests/data/workflow_1_wsl.yaml +8 -0
  167. hpcflow/tests/data/workflow_test_run_abort.yaml +42 -0
  168. hpcflow/tests/jinja_templates/test_jinja_templates.py +161 -0
  169. hpcflow/tests/programs/test_programs.py +180 -0
  170. hpcflow/tests/schedulers/direct_linux/test_direct_linux_submission.py +12 -0
  171. hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
  172. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +14 -0
  173. hpcflow/tests/scripts/test_input_file_generators.py +282 -0
  174. hpcflow/tests/scripts/test_main_scripts.py +1361 -0
  175. hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
  176. hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
  177. hpcflow/tests/shells/wsl/test_wsl_submission.py +14 -0
  178. hpcflow/tests/unit/test_action.py +1066 -0
  179. hpcflow/tests/unit/test_action_rule.py +24 -0
  180. hpcflow/tests/unit/test_app.py +132 -0
  181. hpcflow/tests/unit/test_cache.py +46 -0
  182. hpcflow/tests/unit/test_cli.py +172 -0
  183. hpcflow/tests/unit/test_command.py +377 -0
  184. hpcflow/tests/unit/test_config.py +195 -0
  185. hpcflow/tests/unit/test_config_file.py +162 -0
  186. hpcflow/tests/unit/test_element.py +666 -0
  187. hpcflow/tests/unit/test_element_iteration.py +88 -0
  188. hpcflow/tests/unit/test_element_set.py +158 -0
  189. hpcflow/tests/unit/test_group.py +115 -0
  190. hpcflow/tests/unit/test_input_source.py +1479 -0
  191. hpcflow/tests/unit/test_input_value.py +398 -0
  192. hpcflow/tests/unit/test_jobscript_unit.py +757 -0
  193. hpcflow/tests/unit/test_json_like.py +1247 -0
  194. hpcflow/tests/unit/test_loop.py +2674 -0
  195. hpcflow/tests/unit/test_meta_task.py +325 -0
  196. hpcflow/tests/unit/test_multi_path_sequences.py +259 -0
  197. hpcflow/tests/unit/test_object_list.py +116 -0
  198. hpcflow/tests/unit/test_parameter.py +243 -0
  199. hpcflow/tests/unit/test_persistence.py +664 -0
  200. hpcflow/tests/unit/test_resources.py +243 -0
  201. hpcflow/tests/unit/test_run.py +286 -0
  202. hpcflow/tests/unit/test_run_directories.py +29 -0
  203. hpcflow/tests/unit/test_runtime.py +9 -0
  204. hpcflow/tests/unit/test_schema_input.py +372 -0
  205. hpcflow/tests/unit/test_shell.py +129 -0
  206. hpcflow/tests/unit/test_slurm.py +39 -0
  207. hpcflow/tests/unit/test_submission.py +502 -0
  208. hpcflow/tests/unit/test_task.py +2560 -0
  209. hpcflow/tests/unit/test_task_schema.py +182 -0
  210. hpcflow/tests/unit/test_utils.py +616 -0
  211. hpcflow/tests/unit/test_value_sequence.py +549 -0
  212. hpcflow/tests/unit/test_values.py +91 -0
  213. hpcflow/tests/unit/test_workflow.py +827 -0
  214. hpcflow/tests/unit/test_workflow_template.py +186 -0
  215. hpcflow/tests/unit/utils/test_arrays.py +40 -0
  216. hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
  217. hpcflow/tests/unit/utils/test_hashing.py +65 -0
  218. hpcflow/tests/unit/utils/test_patches.py +5 -0
  219. hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
  220. hpcflow/tests/unit/utils/test_strings.py +97 -0
  221. hpcflow/tests/workflows/__init__.py +0 -0
  222. hpcflow/tests/workflows/test_directory_structure.py +31 -0
  223. hpcflow/tests/workflows/test_jobscript.py +355 -0
  224. hpcflow/tests/workflows/test_run_status.py +198 -0
  225. hpcflow/tests/workflows/test_skip_downstream.py +696 -0
  226. hpcflow/tests/workflows/test_submission.py +140 -0
  227. hpcflow/tests/workflows/test_workflows.py +564 -0
  228. hpcflow/tests/workflows/test_zip.py +18 -0
  229. hpcflow/viz_demo.ipynb +6794 -0
  230. hpcflow-0.2.0a271.dist-info/LICENSE +375 -0
  231. hpcflow-0.2.0a271.dist-info/METADATA +65 -0
  232. hpcflow-0.2.0a271.dist-info/RECORD +237 -0
  233. {hpcflow-0.1.15.dist-info → hpcflow-0.2.0a271.dist-info}/WHEEL +4 -5
  234. hpcflow-0.2.0a271.dist-info/entry_points.txt +6 -0
  235. hpcflow/api.py +0 -490
  236. hpcflow/archive/archive.py +0 -307
  237. hpcflow/archive/cloud/cloud.py +0 -45
  238. hpcflow/archive/cloud/errors.py +0 -9
  239. hpcflow/archive/cloud/providers/dropbox.py +0 -427
  240. hpcflow/archive/errors.py +0 -5
  241. hpcflow/base_db.py +0 -4
  242. hpcflow/config.py +0 -233
  243. hpcflow/copytree.py +0 -66
  244. hpcflow/data/examples/_config.yml +0 -14
  245. hpcflow/data/examples/damask/demo/1.run.yml +0 -4
  246. hpcflow/data/examples/damask/demo/2.process.yml +0 -29
  247. hpcflow/data/examples/damask/demo/geom.geom +0 -2052
  248. hpcflow/data/examples/damask/demo/load.load +0 -1
  249. hpcflow/data/examples/damask/demo/material.config +0 -185
  250. hpcflow/data/examples/damask/inputs/geom.geom +0 -2052
  251. hpcflow/data/examples/damask/inputs/load.load +0 -1
  252. hpcflow/data/examples/damask/inputs/material.config +0 -185
  253. hpcflow/data/examples/damask/profiles/_variable_lookup.yml +0 -21
  254. hpcflow/data/examples/damask/profiles/damask.yml +0 -4
  255. hpcflow/data/examples/damask/profiles/damask_process.yml +0 -8
  256. hpcflow/data/examples/damask/profiles/damask_run.yml +0 -5
  257. hpcflow/data/examples/damask/profiles/default.yml +0 -6
  258. hpcflow/data/examples/thinking.yml +0 -177
  259. hpcflow/errors.py +0 -2
  260. hpcflow/init_db.py +0 -37
  261. hpcflow/models.py +0 -2595
  262. hpcflow/nesting.py +0 -9
  263. hpcflow/profiles.py +0 -455
  264. hpcflow/project.py +0 -81
  265. hpcflow/scheduler.py +0 -322
  266. hpcflow/utils.py +0 -103
  267. hpcflow/validation.py +0 -166
  268. hpcflow/variables.py +0 -543
  269. hpcflow-0.1.15.dist-info/METADATA +0 -168
  270. hpcflow-0.1.15.dist-info/RECORD +0 -45
  271. hpcflow-0.1.15.dist-info/entry_points.txt +0 -8
  272. hpcflow-0.1.15.dist-info/top_level.txt +0 -1
  273. /hpcflow/{archive → data/jinja_templates}/__init__.py +0 -0
  274. /hpcflow/{archive/cloud → data/programs}/__init__.py +0 -0
  275. /hpcflow/{archive/cloud/providers → data/workflows}/__init__.py +0 -0
hpcflow/sdk/core/utils.py (added)
@@ -0,0 +1,1207 @@
+ """
+ Miscellaneous utilities.
+ """
+
+ from __future__ import annotations
+ from collections import Counter
+ from asyncio import events
+ import contextvars
+ import contextlib
+ import copy
+ import enum
+ import functools
+ import hashlib
+ from itertools import accumulate, islice
+ from importlib import resources
+ import json
+ import keyword
+ import os
+ from pathlib import Path, PurePath
+ import random
+ import re
+ import socket
+ import string
+ import subprocess
+ from datetime import datetime, timedelta, timezone
+ import sys
+ import traceback
+ from typing import Literal, cast, overload, TypeVar, TYPE_CHECKING
+ import fsspec  # type: ignore
+ import numpy as np
+
+ from ruamel.yaml import YAML
+ from ruamel.yaml.error import MarkedYAMLError
+ from watchdog.utils.dirsnapshot import DirectorySnapshot
+
+ from hpcflow.sdk.core.errors import (
+     ContainerKeyError,
+     InvalidIdentifier,
+     MissingVariableSubstitutionError,
+     YAMLError,
+ )
+ from hpcflow.sdk.log import TimeIt
+ from hpcflow.sdk.utils.deferred_file import DeferredFileWriter
+
+ if TYPE_CHECKING:
+     from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence
+     from contextlib import AbstractContextManager
+     from types import ModuleType
+     from typing import Any, IO, Iterator
+     from typing_extensions import TypeAlias
+     from numpy.typing import NDArray
+     from ..typing import PathLike
+
+ T = TypeVar("T")
+ T2 = TypeVar("T2")
+ T3 = TypeVar("T3")
+ TList: TypeAlias = "T | list[TList]"
+ TD = TypeVar("TD", bound="Mapping[str, Any]")
+ E = TypeVar("E", bound=enum.Enum)
+
+
+ def make_workflow_id() -> str:
+     """
+     Generate a random ID for a workflow.
+     """
+     length = 12
+     chars = string.ascii_letters + "0123456789"
+     return "".join(random.choices(chars, k=length))
+
+
+ def get_time_stamp() -> str:
+     """
+     Get the current time in standard string form.
+     """
+     return datetime.now(timezone.utc).astimezone().strftime("%Y.%m.%d_%H:%M:%S_%z")
+
+
+ def get_duplicate_items(lst: Iterable[T]) -> list[T]:
+     """Get a list of all items in an iterable that appear more than once, assuming
+     items are hashable.
+
+     Examples
+     --------
+     >>> get_duplicate_items([1, 1, 2, 3])
+     [1]
+
+     >>> get_duplicate_items([1, 2, 3])
+     []
+
+     >>> get_duplicate_items([1, 2, 3, 3, 3, 2])
+     [2, 3]
+
+     """
+     return [x for x, y in Counter(lst).items() if y > 1]
+
+
+ def check_valid_py_identifier(name: str) -> str:
+     """Check a string is (roughly) a valid Python variable identifier and return it.
+
+     The rules are:
+
+     1. `name` must not be empty
+     2. `name` must not be a Python keyword
+     3. `name` must begin with an alphabetic character, and all remaining characters
+        must be alphanumeric (internal underscores are also allowed).
+
+     Notes
+     -----
+     The following attributes are passed through this function on object
+     initialisation:
+
+     - `ElementGroup.name`
+     - `Executable.label`
+     - `Parameter.typ`
+     - `TaskObjective.name`
+     - `TaskSchema.method`
+     - `TaskSchema.implementation`
+     - `Loop.name`
+
+     """
+     try:
+         trial_name = name[1:].replace("_", "")  # "internal" underscores are allowed
+     except TypeError:
+         raise InvalidIdentifier(name) from None
+     except KeyError as e:
+         raise KeyError(f"unexpected name type {name}") from e
+     if (
+         not name
+         or not (name[0].isalpha() and ((trial_name or "a").isalnum()))
+         or keyword.iskeyword(name)
+     ):
+         raise InvalidIdentifier(name)
+
+     return name
+
+
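For illustration, a minimal sketch of the accept/reject behaviour above (not part of the diff; importing from the module this file provides):

```python
from hpcflow.sdk.core.utils import check_valid_py_identifier
from hpcflow.sdk.core.errors import InvalidIdentifier

check_valid_py_identifier("p1_stress")   # OK: alphabetic start, internal underscores allowed
try:
    check_valid_py_identifier("1param")  # rejected: starts with a digit
except InvalidIdentifier:
    pass
try:
    check_valid_py_identifier("class")   # rejected: Python keyword
except InvalidIdentifier:
    pass
```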
+ @overload
+ def group_by_dict_key_values(  # type: ignore[overload-overlap]
+     lst: list[dict[T, T2]], key: T
+ ) -> list[list[dict[T, T2]]]: ...
+
+
+ @overload
+ def group_by_dict_key_values(lst: list[TD], key: str) -> list[list[TD]]: ...
+
+
+ def group_by_dict_key_values(lst: list, key):
+     """Group a list of dicts according to equivalent values of a specified key.
+
+     Parameters
+     ----------
+     lst : list of dict
+         The list of dicts to group together.
+     key : key value
+         Dicts that have identical values for this key will be grouped together into a
+         sub-list.
+
+     Returns
+     -------
+     grouped : list of list of dict
+
+     Examples
+     --------
+     >>> group_by_dict_key_values([{'a': 1}, {'a': 2}, {'a': 1}], 'a')
+     [[{'a': 1}, {'a': 1}], [{'a': 2}]]
+
+     """
+     grouped = [[lst[0]]]
+     for lst_item in lst[1:]:
+         for group_idx, group in enumerate(grouped):
+             try:
+                 is_vals_equal = lst_item[key] == group[0][key]
+             except KeyError:
+                 # dicts that do not have the `key` will be in their own group:
+                 is_vals_equal = False
+
+             if is_vals_equal:
+                 grouped[group_idx].append(lst_item)
+                 break
+
+         if not is_vals_equal:
+             grouped.append([lst_item])
+
+     return grouped
+
+
+ def swap_nested_dict_keys(dct: dict[T, dict[T2, T3]], inner_key: T2):
+     """Return a copy where top-level keys have been swapped with a second-level inner
+     key.
+
+     Examples
+     --------
+     >>> swap_nested_dict_keys(
+     ...     dct={
+     ...         'p1': {'format': 'direct', 'all_iterations': True},
+     ...         'p2': {'format': 'json'},
+     ...         'p3': {'format': 'direct'},
+     ...     },
+     ...     inner_key="format",
+     ... )
+     {'direct': {'p1': {'all_iterations': True}, 'p3': {}}, 'json': {'p2': {}}}
+
+     """
+     out: dict[T3, dict[T, dict[T2, T3]]] = {}
+     for k, v in copy.deepcopy(dct or {}).items():
+         out.setdefault(v.pop(inner_key), {})[k] = v
+     return out
+
+
+ def _ensure_int(path_comp: Any, cur_data: Any, cast_indices: bool) -> int:
+     """
+     Helper for `get_in_container()` and `set_in_container()`.
+     """
+     if isinstance(path_comp, int):
+         return path_comp
+     if not cast_indices:
+         raise TypeError(
+             f"Path component {path_comp!r} must be an integer index "
+             f"since data is a sequence: {cur_data!r}."
+         )
+     try:
+         return int(path_comp)
+     except (TypeError, ValueError) as e:
+         raise TypeError(
+             f"Path component {path_comp!r} must be an integer index "
+             f"since data is a sequence: {cur_data!r}."
+         ) from e
+
+
+ def get_in_container(
+     cont, path: Sequence, cast_indices: bool = False, allow_getattr: bool = False
+ ):
+     """
+     Follow a path (a sequence of indices of appropriate type) into a container to
+     obtain a "leaf" value. Containers can be lists, tuples, dicts, or any class (via
+     `getattr()`) if ``allow_getattr`` is True.
+     """
+     cur_data = cont
+     err_msg = (
+         "Data at path {path_comps!r} is not a sequence, but is of type "
+         "{cur_data_type!r} and so sub-data cannot be extracted."
+     )
+     for idx, path_comp in enumerate(path):
+         if isinstance(cur_data, (list, tuple)):
+             cur_data = cur_data[_ensure_int(path_comp, cur_data, cast_indices)]
+         elif isinstance(cur_data, dict) or hasattr(cur_data, "__getitem__"):
+             try:
+                 cur_data = cur_data[path_comp]
+             except KeyError:
+                 raise ContainerKeyError(path=cast("list[str]", path[: idx + 1]))
+         elif allow_getattr:
+             try:
+                 cur_data = getattr(cur_data, path_comp)
+             except AttributeError:
+                 raise ValueError(
+                     err_msg.format(cur_data_type=type(cur_data), path_comps=path[:idx])
+                 )
+         else:
+             raise ValueError(
+                 err_msg.format(cur_data_type=type(cur_data), path_comps=path[:idx])
+             )
+     return cur_data
+
+
+ def set_in_container(
+     cont, path: Sequence, value, ensure_path=False, cast_indices=False
+ ) -> None:
+     """
+     Follow a path (a sequence of indices of appropriate type) into a container to
+     update a "leaf" value. Containers can be lists, tuples or dicts.
+     The "branch" holding the leaf to update must be modifiable.
+     """
+     if ensure_path:
+         num_path = len(path)
+         for idx in range(1, num_path):
+             try:
+                 get_in_container(cont, path[:idx], cast_indices=cast_indices)
+             except (KeyError, ValueError):
+                 set_in_container(
+                     cont=cont,
+                     path=path[:idx],
+                     value={},
+                     ensure_path=False,
+                     cast_indices=cast_indices,
+                 )
+
+     sub_data = get_in_container(cont, path[:-1], cast_indices=cast_indices)
+     path_comp = path[-1]
+     if isinstance(sub_data, (list, tuple)):
+         path_comp = _ensure_int(path_comp, sub_data, cast_indices)
+     sub_data[path_comp] = value
+
+
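A usage sketch of the two container helpers (illustrative data):

```python
from hpcflow.sdk.core.utils import get_in_container, set_in_container

data = {"inputs": {"p1": [10, 20, 30]}}

# String indices can be cast to ints for sequence access:
assert get_in_container(data, ("inputs", "p1", "2"), cast_indices=True) == 30

# `ensure_path=True` creates intermediate dicts as needed:
set_in_container(data, ("outputs", "p2"), value=99, ensure_path=True)
assert data["outputs"]["p2"] == 99
```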
+ def get_relative_path(path1: Sequence[T], path2: Sequence[T]) -> Sequence[T]:
+     """Get relative path components between two paths.
+
+     Parameters
+     ----------
+     path1 : tuple of (str or int or float) of length N
+     path2 : tuple of (str or int or float) of length less than or equal to N
+
+     Returns
+     -------
+     relative_path : tuple of (str or int or float)
+         The path components in `path1` that are not in `path2`.
+
+     Raises
+     ------
+     ValueError
+         If the two paths do not share a common ancestor of path components, or if
+         `path2` is longer than `path1`.
+
+     Notes
+     -----
+     This function behaves like a simplified
+     `PurePath(*path1).relative_to(PurePath(*path2))` from the `pathlib` module, but
+     where path components can include non-strings.
+
+     Examples
+     --------
+     >>> get_relative_path(('A', 'B', 'C'), ('A',))
+     ('B', 'C')
+
+     >>> get_relative_path(('A', 'B'), ('A', 'B'))
+     ()
+
+     """
+     len_path2 = len(path2)
+     if len(path1) < len_path2 or any(i != j for i, j in zip(path1[:len_path2], path2)):
+         raise ValueError(f"{path1!r} is not in the subpath of {path2!r}.")
+
+     return path1[len_path2:]
+
+
+ def search_dir_files_by_regex(
+     pattern: str | re.Pattern[str], directory: str | os.PathLike = "."
+ ) -> list[str]:
+     """Search recursively for files in a directory by a regex pattern and return
+     matching file paths, relative to the given directory."""
+     dir_ = Path(directory)
+     return [
+         str(entry.relative_to(dir_))
+         for entry in dir_.rglob("*")
+         if re.search(pattern, entry.name)
+     ]
+
+
+ class PrettyPrinter:
+     """
+     A class that produces a nice readable version of itself with ``str()``.
+     Intended to be subclassed.
+     """
+
+     def __str__(self) -> str:
+         lines = [self.__class__.__name__ + ":"]
+         for key, val in vars(self).items():
+             lines.extend(f"{key}: {val}".split("\n"))
+         return "\n ".join(lines)
+
+
+ _STRING_VARS_RE = re.compile(r"\<\<var:(.*?)(?:\[(.*)\])?\>\>")
+
+
+ @TimeIt.decorator
+ def substitute_string_vars(string: str, variables: dict[str, str]) -> str:
+     """
+     Scan ``string`` and substitute sequences like ``<<var:ABC>>`` with the value
+     looked up in the supplied dictionary (with ``ABC`` as the key).
+
+     Default values for the substitution can be supplied like:
+     ``<<var:ABC[default=XYZ]>>``.
+
+     Examples
+     --------
+     >>> substitute_string_vars("abc <<var:def>> ghi", {"def": "123"})
+     'abc 123 ghi'
+     """
+
+     def var_repl(match_obj: re.Match[str]) -> str:
+         kwargs: dict[str, str] = {}
+         var_name: str = match_obj[1]
+         kwargs_str: str | None = match_obj[2]
+         if kwargs_str:
+             for i in kwargs_str.split(","):
+                 k, v = i.split("=")
+                 kwargs[k.strip()] = v.strip()
+         try:
+             out = str(variables[var_name])
+         except KeyError:
+             if "default" in kwargs:
+                 out = kwargs["default"]
+                 print(
+                     f"Using default value ({out!r}) for workflow template string "
+                     f"variable {var_name!r}."
+                 )
+             else:
+                 raise MissingVariableSubstitutionError(var_name, variables)
+         return out
+
+     return _STRING_VARS_RE.sub(
+         repl=var_repl,
+         string=string,
+     )
+
+
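A sketch of the default-value syntax described above (illustrative variable names):

```python
from hpcflow.sdk.core.utils import substitute_string_vars

s = "run with <<var:N>> cores on <<var:queue[default=serial]>>"
print(substitute_string_vars(s, variables={"N": "4"}))
# -> "run with 4 cores on serial" (and a note that the default was used)
```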
+ @TimeIt.decorator
+ def read_YAML_str(
+     yaml_str: str,
+     typ="safe",
+     variables: dict[str, str] | Literal[False] | None = None,
+     source: str | None = None,
+ ) -> Any:
+     """Load a YAML string. This will produce basic objects.
+
+     Parameters
+     ----------
+     yaml_str:
+         The YAML string to parse.
+     typ:
+         Load type passed to the YAML library.
+     variables:
+         String variables to substitute in `yaml_str`. Substitutions will be attempted
+         if the string appears to contain variable references (like "<<var:name>>").
+         If set to `False`, no substitutions will occur.
+     source:
+         Used to document the source of the YAML string if raising a parsing error.
+         Typically, this should be a string that starts with "from ...", e.g.
+         "from the file path '/path/to/bad/file'".
+     """
+     if variables is not False and "<<var:" in yaml_str:
+         yaml_str = substitute_string_vars(yaml_str, variables=variables or {})
+     yaml = YAML(typ=typ)
+     try:
+         return yaml.load(yaml_str)
+     except MarkedYAMLError as err:  # includes `ScannerError` and `ParserError`
+         source_str = f"{source} " if source else ""
+         raise YAMLError(
+             f"The YAML string {source_str}is not formatted correctly."
+         ) from err
+
+
+ @TimeIt.decorator
+ def read_YAML_file(
+     path: PathLike, typ="safe", variables: dict[str, str] | Literal[False] | None = None
+ ) -> Any:
+     """Load a YAML file. This will produce basic objects.
+
+     Parameters
+     ----------
+     path:
+         Path to the YAML file to parse.
+     typ:
+         Load type passed to the YAML library.
+     variables:
+         String variables to substitute in the file given by `path`. Substitutions
+         will be attempted if the file appears to contain variable references (like
+         "<<var:name>>"). If set to `False`, no substitutions will occur.
+     """
+     with fsspec.open(path, "rt") as f:
+         yaml_str: str = f.read()
+     return read_YAML_str(yaml_str, typ=typ, variables=variables, source=f"from {path!r}")
+
+
+ def write_YAML_file(obj, path: str | Path, typ: str = "safe") -> None:
+     """Write a basic object to a YAML file."""
+     yaml = YAML(typ=typ)
+     with Path(path).open("wt") as fp:
+         yaml.dump(obj, fp)
+
+
+ def read_JSON_string(
+     json_str: str, variables: dict[str, str] | Literal[False] | None = None
+ ) -> Any:
+     """Load a JSON string. This will produce basic objects.
+
+     Parameters
+     ----------
+     json_str:
+         The JSON string to parse.
+     variables:
+         String variables to substitute in `json_str`. Substitutions will be attempted
+         if the string appears to contain variable references (like "<<var:name>>").
+         If set to `False`, no substitutions will occur.
+     """
+     if variables is not False and "<<var:" in json_str:
+         json_str = substitute_string_vars(json_str, variables=variables or {})
+     return json.loads(json_str)
+
+
+ def read_JSON_file(path, variables: dict[str, str] | Literal[False] | None = None) -> Any:
+     """Load a JSON file. This will produce basic objects.
+
+     Parameters
+     ----------
+     path:
+         Path to the JSON file to parse.
+     variables:
+         String variables to substitute in the file given by `path`. Substitutions
+         will be attempted if the file appears to contain variable references (like
+         "<<var:name>>"). If set to `False`, no substitutions will occur.
+     """
+     with fsspec.open(path, "rt") as f:
+         json_str: str = f.read()
+     return read_JSON_string(json_str, variables=variables)
+
+
+ def write_JSON_file(obj, path: str | Path) -> None:
+     """Write a basic object to a JSON file."""
+     with Path(path).open("wt") as fp:
+         json.dump(obj, fp)
+
+
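A round-trip sketch of the read/write helpers (the file name is illustrative):

```python
from hpcflow.sdk.core.utils import read_YAML_str, write_YAML_file, read_YAML_file

doc = read_YAML_str("n_cores: <<var:N>>", variables={"N": "8"})
assert doc == {"n_cores": 8}

write_YAML_file({"a": [1, 2]}, "example.yaml")
assert read_YAML_file("example.yaml") == {"a": [1, 2]}
```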
+ def get_item_repeat_index(
+     lst: Sequence[T],
+     *,
+     distinguish_singular: bool = False,
+     item_callable: Callable[[T], Hashable] | None = None,
+ ):
+     """Get the repeat index for each item in a list.
+
+     Parameters
+     ----------
+     lst : list
+         Must contain hashable items, or items for which `item_callable` returns a
+         hashable object.
+     distinguish_singular : bool
+         If True, items that are not repeated will have a repeat index of 0, and items
+         that are repeated will have repeat indices starting from 1.
+     item_callable : callable
+         If specified, comparisons are made on the output of this callable on each
+         item.
+
+     Returns
+     -------
+     repeat_idx : list of int
+         Repeat indices of each item (see `distinguish_singular` for details).
+
+     """
+     idx: dict[Any, list[int]] = {}
+     if item_callable:
+         for i_idx, item in enumerate(lst):
+             idx.setdefault(item_callable(item), []).append(i_idx)
+     else:
+         for i_idx, item in enumerate(lst):
+             idx.setdefault(item, []).append(i_idx)
+
+     rep_idx = [0] * len(lst)
+     for v in idx.values():
+         start = len(v) > 1 if distinguish_singular else 0
+         for i_idx, i in enumerate(v, start):
+             rep_idx[i] = i_idx
+
+     return rep_idx
+
+
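Illustrative behaviour of `get_item_repeat_index` (note `start = len(v) > 1` relies on `True == 1`):

```python
from hpcflow.sdk.core.utils import get_item_repeat_index

assert get_item_repeat_index(["a", "b", "a", "a"]) == [0, 0, 1, 2]
# With `distinguish_singular`, repeated items count from 1 and singletons stay 0:
assert get_item_repeat_index(
    ["a", "b", "a", "a"], distinguish_singular=True
) == [1, 0, 2, 3]
```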
+ def get_process_stamp() -> str:
+     """
+     Return a globally unique string identifying this process.
+
+     Note
+     ----
+     This should only be called once per process.
+     """
+     return "{} {} {}".format(
+         datetime.now(),
+         socket.gethostname(),
+         os.getpid(),
+     )
+
+
+ _ANSI_ESCAPE_RE = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
+
+
+ def remove_ansi_escape_sequences(string: str) -> str:
+     """
+     Strip ANSI terminal escape codes from a string.
+     """
+     return _ANSI_ESCAPE_RE.sub("", string)
+
+
+ def get_md5_hash(obj) -> str:
+     """
+     Compute the MD5 hash of an object.
+     This is the hash of the JSON of the object (with sorted keys) as a hex string.
+     """
+     json_str = json.dumps(obj, sort_keys=True)
+     return hashlib.md5(json_str.encode("utf-8")).hexdigest()
+
+
+ def get_nested_indices(
+     idx: int, size: int, nest_levels: int, raise_on_rollover: bool = False
+ ) -> list[int]:
+     """Generate the set of nested indices of length `nest_levels` that correspond to
+     a global `idx`.
+
+     Examples
+     --------
+     >>> for i in range(4**2): print(get_nested_indices(i, nest_levels=2, size=4))
+     [0, 0]
+     [0, 1]
+     [0, 2]
+     [0, 3]
+     [1, 0]
+     [1, 1]
+     [1, 2]
+     [1, 3]
+     [2, 0]
+     [2, 1]
+     [2, 2]
+     [2, 3]
+     [3, 0]
+     [3, 1]
+     [3, 2]
+     [3, 3]
+
+     >>> for i in range(4**3): print(get_nested_indices(i, nest_levels=3, size=4))
+     [0, 0, 0]
+     [0, 0, 1]
+     [0, 0, 2]
+     [0, 0, 3]
+     [0, 1, 0]
+     ...
+     [3, 2, 3]
+     [3, 3, 0]
+     [3, 3, 1]
+     [3, 3, 2]
+     [3, 3, 3]
+     """
+     if raise_on_rollover and idx >= size**nest_levels:
+         raise ValueError(
+             f"`idx` ({idx}) is greater than or equal to `size**nest_levels` "
+             f"({size**nest_levels})."
+         )
+
+     return [(idx // (size ** (nest_levels - (i + 1)))) % size for i in range(nest_levels)]
+
+
+ def ensure_in(item: T, lst: list[T]) -> int:
+     """Get the index of an item in a list and append the item if it is not in the
+     list."""
+     # TODO: add tests
+     try:
+         return lst.index(item)
+     except ValueError:
+         lst.append(item)
+         return len(lst) - 1
+
+
+ def list_to_dict(
+     lst: Sequence[Mapping[T, T2]], exclude: Iterable[T] | None = None
+ ) -> dict[T, list[T2]]:
+     """
+     Convert a list of dicts to a dict of lists.
+     """
+     # TODO: test
+     exc = frozenset(exclude or ())
+     dct: dict[T, list[T2]] = {k: [] for k in lst[0] if k not in exc}
+     for d in lst:
+         for k, v in d.items():
+             if k not in exc:
+                 dct[k].append(v)
+     return dct
+
+
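A quick sketch of `list_to_dict` (illustrative data):

```python
from hpcflow.sdk.core.utils import list_to_dict

rows = [{"x": 1, "y": 2}, {"x": 3, "y": 4}]
assert list_to_dict(rows) == {"x": [1, 3], "y": [2, 4]}
assert list_to_dict(rows, exclude=["y"]) == {"x": [1, 3]}
```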
+ def bisect_slice(selection: slice, len_A: int) -> tuple[slice, slice]:
+     """Given two sequences (the first of which of known length), get the two slices
+     that are equivalent to a given slice if the two sequences were combined."""
+     if selection.start < 0 or selection.stop < 0 or selection.step < 0:
+         raise NotImplementedError("Can't do negative slices yet.")
+
+     A_idx = selection.indices(len_A)
+     B_start = selection.start - len_A
+     if len_A != 0 and B_start < 0:
+         B_start = B_start % selection.step
+     if len_A > selection.stop:
+         B_stop = B_start
+     else:
+         B_stop = selection.stop - len_A
+
+     return slice(*A_idx), slice(B_start, B_stop, selection.step)
+
+
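A sketch showing how `bisect_slice` partitions a slice across two sequences:

```python
from hpcflow.sdk.core.utils import bisect_slice

A, B = [0, 1, 2, 3], [4, 5, 6, 7]
sel = slice(1, 7, 2)
sA, sB = bisect_slice(sel, len(A))
assert A[sA] + B[sB] == (A + B)[sel]  # [1, 3, 5]
```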
+ def replace_items(lst: list[T], start: int, end: int, repl: list[T]) -> list[T]:
+     """Replace a range of items in a list with items from another list."""
+     # Convert to actual indices for our safety checks; handles end-relative addressing
+     real_start, real_end, _ = slice(start, end).indices(len(lst))
+     if real_end <= real_start:
+         raise ValueError(f"`end` ({end}) must be greater than `start` ({start}).")
+     if real_start >= len(lst):
+         raise ValueError(f"`start` ({start}) must be less than length ({len(lst)}).")
+     if real_end > len(lst):
+         raise ValueError(
+             f"`end` ({end}) must be less than or equal to length ({len(lst)})."
+         )
+
+     lst = list(lst)
+     lst[start:end] = repl
+     return lst
+
+
+ def flatten(
+     lst: list[int] | list[list[int]] | list[list[list[int]]],
+ ) -> tuple[list[int], tuple[list[int], ...]]:
+     """Flatten an arbitrarily (but uniformly) nested list and return shape
+     information to enable un-flattening.
+
+     Un-flattening can be performed with the :py:func:`reshape` function.
+
+     Parameters
+     ----------
+     lst
+         List to be flattened. Each element must contain all lists or otherwise all
+         items that are considered to be at the "bottom" of the nested structure
+         (e.g. integers). For example, `[[1, 2], [3]]` is permitted and flattens to
+         `[1, 2, 3]`, but `[[1, 2], 3]` is not permitted because the first element is
+         a list, but the second is not.
+
+     """
+
+     def _flatten(
+         lst: list[int] | list[list[int]] | list[list[list[int]]], depth=0
+     ) -> list[int]:
+         out: list[int] = []
+         for item in lst:
+             if isinstance(item, list):
+                 out.extend(_flatten(item, depth + 1))
+                 all_lens[depth].append(len(item))
+             else:
+                 out.append(item)
+         return out
+
+     def _get_max_depth(lst: list[int] | list[list[int]] | list[list[list[int]]]) -> int:
+         val: Any = lst
+         max_depth = 0
+         while isinstance(val, list):
+             max_depth += 1
+             try:
+                 val = val[0]
+             except IndexError:
+                 # empty list; assume this is the maximum depth
+                 break
+         return max_depth
+
+     max_depth = _get_max_depth(lst) - 1
+     all_lens: tuple[list[int], ...] = tuple([] for _ in range(max_depth))
+
+     return _flatten(lst), all_lens
+
+
+ def reshape(lst: Sequence[T], lens: Sequence[Sequence[int]]) -> list[TList[T]]:
+     """
+     Reverse the destructuring of the :py:func:`flatten` function.
+     """
+
+     def _reshape(lst: list[T2], lens: Sequence[int]) -> list[list[T2]]:
+         lens_acc = [0, *accumulate(lens)]
+         return [lst[lens_acc[idx] : lens_acc[idx + 1]] for idx in range(len(lens))]
+
+     result: list[TList[T]] = list(lst)
+     for lens_i in lens[::-1]:
+         result = cast("list[TList[T]]", _reshape(result, lens_i))
+
+     return result
+
+
+ @overload
+ def remap(
+     lst: list[int], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+ ) -> list[T]: ...
+
+
+ @overload
+ def remap(
+     lst: list[list[int]], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+ ) -> list[list[T]]: ...
+
+
+ @overload
+ def remap(
+     lst: list[list[list[int]]], mapping_func: Callable[[Sequence[int]], Sequence[T]]
+ ) -> list[list[list[T]]]: ...
+
+
+ def remap(lst, mapping_func):
+     """
+     Apply a mapping to a structure of lists with ints (typically indices) as leaves
+     to get a structure of lists with some objects as leaves.
+
+     Parameters
+     ----------
+     lst: list[int] | list[list[int]] | list[list[list[int]]]
+         The structure to remap.
+     mapping_func: Callable[[Sequence[int]], Sequence[T]]
+         The mapping function from sequences of ints to sequences of objects.
+
+     Returns
+     -------
+     list[T] | list[list[T]] | list[list[list[T]]]
+         Nested list structure in the same form as the input, with leaves remapped.
+     """
+     flat, lens = flatten(lst)
+     return reshape(mapping_func(flat), lens)
+
+
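A round-trip sketch of `flatten`, `reshape` and `remap`:

```python
from hpcflow.sdk.core.utils import flatten, reshape, remap

nested = [[0, 1], [2]]
flat, lens = flatten(nested)          # ([0, 1, 2], ([2, 1],))
assert reshape(flat, lens) == nested

# remap applies a function to all leaves in one flat pass:
assert remap(nested, lambda xs: [x * 10 for x in xs]) == [[0, 10], [20]]
```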
+ _FSSPEC_URL_RE = re.compile(r"(?:[a-z0-9]+:{1,2})+\/\/")
+
+
+ def is_fsspec_url(url: str) -> bool:
+     """
+     Test if a URL appears to be one that can be understood by fsspec.
+     """
+     return bool(_FSSPEC_URL_RE.match(url))
+
+
+ class JSONLikeDirSnapShot(DirectorySnapshot):
+     """
+     Overridden `DirectorySnapshot` from watchdog to allow saving and loading from
+     JSON.
+
+     Parameters
+     ----------
+     root_path: str
+         Where to take the snapshot based at.
+     data: dict[str, list]
+         Serialised snapshot to reload from.
+         See :py:meth:`to_json_like`.
+     """
+
+     def __init__(
+         self,
+         root_path: str | None = None,
+         data: dict[str, list] | None = None,
+         use_strings: bool = False,
+     ):
+         """
+         Create an empty snapshot or load from JSON-like data.
+         """
+
+         #: Where to take the snapshot based at.
+         self.root_path = root_path
+         self._stat_info: dict[bytes | str, os.stat_result] = {}
+         self._inode_to_path: dict[tuple[int, int], bytes | str] = {}
+
+         if data:
+             assert root_path
+             for name, item in data.items():
+                 # add the root path:
+                 full_name = str(PurePath(root_path) / PurePath(name))
+                 item = [int(i) for i in item] if use_strings else item
+                 stat_dat, inode_key = item[:-2], item[-2:]
+                 self._stat_info[full_name] = os.stat_result(stat_dat)
+                 self._inode_to_path[tuple(inode_key)] = full_name
+
+     def take(self, *args, **kwargs) -> None:
+         """Take the snapshot."""
+         super().__init__(*args, **kwargs)
+
+     def to_json_like(self, use_strings: bool = False) -> dict[str, Any]:
+         """Export to a dict that is JSON-compatible and can be later reloaded.
+
+         The last two integers in `data` for each path are the keys in
+         `self._inode_to_path`.
+
+         """
+         # the first key is the root path:
+         root_path = next(iter(self._stat_info))
+
+         # store efficiently:
+         inode_invert = {v: k for k, v in self._inode_to_path.items()}
+         data: dict[str, list] = {
+             str(PurePath(cast("str", k)).relative_to(cast("str", root_path))): [
+                 str(i) if use_strings else i for i in [*v, *inode_invert[k]]
+             ]
+             for k, v in self._stat_info.items()
+         }
+
+         return {
+             "root_path": root_path,
+             "data": data,
+             "use_strings": use_strings,
+         }
+
+
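A usage sketch of the snapshot class (the scanned path is illustrative); `take()` defers the parent-class scan so an instance can instead be constructed empty and re-populated from JSON:

```python
from hpcflow.sdk.core.utils import JSONLikeDirSnapShot

snap = JSONLikeDirSnapShot()
snap.take(".")               # scan the current directory
js = snap.to_json_like()     # JSON-compatible dict

# Reload later without re-scanning the file system:
snap2 = JSONLikeDirSnapShot(**js)
```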
+ def open_file(filename: str | Path):
+     """Open a file or directory using the default system application."""
+     if sys.platform == "win32":
+         os.startfile(filename)
+     else:
+         opener = "open" if sys.platform == "darwin" else "xdg-open"
+         subprocess.call([opener, filename])
+
+
+ @overload
+ def get_enum_by_name_or_val(enum_cls: type[E], key: None) -> None: ...
+
+
+ @overload
+ def get_enum_by_name_or_val(enum_cls: type[E], key: str | int | float | E) -> E: ...
+
+
+ def get_enum_by_name_or_val(
+     enum_cls: type[E], key: str | int | float | E | None
+ ) -> E | None:
+     """Retrieve an enum by name or value, assuming uppercase names and integer
+     values."""
+     if key is None or isinstance(key, enum_cls):
+         return key
+     elif isinstance(key, (int, float)):
+         return enum_cls(int(key))  # retrieve by value
+     elif isinstance(key, str):
+         try:
+             return cast("E", getattr(enum_cls, key.upper()))  # retrieve by name
+         except AttributeError:
+             pass
+     raise ValueError(f"Unknown enum key or value {key!r} for class {enum_cls!r}")
+
+
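A sketch of `get_enum_by_name_or_val` with a made-up enum:

```python
import enum
from hpcflow.sdk.core.utils import get_enum_by_name_or_val

class Colour(enum.Enum):
    RED = 0
    GREEN = 1

assert get_enum_by_name_or_val(Colour, "red") is Colour.RED  # case-insensitive name
assert get_enum_by_name_or_val(Colour, 1) is Colour.GREEN    # by value
assert get_enum_by_name_or_val(Colour, Colour.RED) is Colour.RED
assert get_enum_by_name_or_val(Colour, None) is None
```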
+ _PARAM_SPLIT_RE = re.compile(r"^((?:\w+\.)*)(\w+)(?:\[(\w+)\])?((?:\.\w+)*)$")
+
+
+ def split_param_label(
+     param_path: str,
+ ) -> tuple[str, str] | tuple[str, None] | tuple[None, None]:
+     """Split a parameter path into the path and the label, if present."""
+     m = _PARAM_SPLIT_RE.match(param_path)
+     if not m:
+         return (None, None)
+
+     clean_path = m.group(1) + m.group(2) + m.group(4)
+     bracket_value = m.group(3)
+     return (clean_path, bracket_value)
+
+
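Illustrative splits:

```python
from hpcflow.sdk.core.utils import split_param_label

assert split_param_label("inputs.p1[one].a") == ("inputs.p1.a", "one")
assert split_param_label("inputs.p1") == ("inputs.p1", None)
```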
+ def process_string_nodes(data: T, str_processor: Callable[[str], str]) -> T:
+     """Walk through a nested data structure and process string nodes using a
+     provided callable."""
+
+     if isinstance(data, dict):
+         return cast(
+             "T", {k: process_string_nodes(v, str_processor) for k, v in data.items()}
+         )
+
+     elif isinstance(data, (list, tuple, set, frozenset)):
+         _data = (process_string_nodes(i, str_processor) for i in data)
+         if isinstance(data, tuple):
+             return cast("T", tuple(_data))
+         elif isinstance(data, set):
+             return cast("T", set(_data))
+         elif isinstance(data, frozenset):
+             return cast("T", frozenset(_data))
+         else:
+             return cast("T", list(_data))
+
+     elif isinstance(data, str):
+         return cast("T", str_processor(data))
+
+     return data
+
+
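For instance, upper-casing every string leaf while preserving container types:

```python
from hpcflow.sdk.core.utils import process_string_nodes

data = {"a": ["x", ("y", 1)], "b": {"z"}}
assert process_string_nodes(data, str.upper) == {"a": ["X", ("Y", 1)], "b": {"Z"}}
```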
+ def linspace_rect(
+     start: Sequence[float],
+     stop: Sequence[float],
+     num: Sequence[int],
+     include: Sequence[str] | None = None,
+     **kwargs,
+ ) -> NDArray:
+     """Generate a linear space around a rectangle.
+
+     Parameters
+     ----------
+     start
+         Two start values; one for each dimension of the rectangle.
+     stop
+         Two stop values; one for each dimension of the rectangle.
+     num
+         Two number values; one for each dimension of the rectangle.
+     include
+         If specified, include only the specified edges. Choose from "top", "right",
+         "bottom", "left".
+
+     Returns
+     -------
+     rect
+         Coordinates of the rectangle perimeter.
+
+     """
+     if num[0] <= 1 or num[1] <= 1:
+         raise ValueError("Both values in `num` must be greater than 1.")
+
+     inc = set(include) if include else {"top", "right", "bottom", "left"}
+
+     c0_range = np.linspace(start=start[0], stop=stop[0], num=num[0], **kwargs)
+     c1_range_all = np.linspace(start=start[1], stop=stop[1], num=num[1], **kwargs)
+
+     # trim the vertical edges where a horizontal edge is included, so that corner
+     # points are not duplicated:
+     c1_range = c1_range_all
+     if "bottom" in inc:
+         c1_range = c1_range[1:]
+     if "top" in inc:
+         c1_range = c1_range[:-1]
+
+     c0_range_c1_start = np.vstack([c0_range, np.repeat(start[1], num[0])])
+     c0_range_c1_stop = np.vstack([c0_range, np.repeat(c1_range_all[-1], num[0])])
+
+     c1_range_c0_start = np.vstack([np.repeat(start[0], len(c1_range)), c1_range])
+     c1_range_c0_stop = np.vstack([np.repeat(c0_range[-1], len(c1_range)), c1_range])
+
+     stacked = []
+     if "top" in inc:
+         stacked.append(c0_range_c1_stop)
+     if "right" in inc:
+         stacked.append(c1_range_c0_stop)
+     if "bottom" in inc:
+         stacked.append(c0_range_c1_start)
+     if "left" in inc:
+         stacked.append(c1_range_c0_start)
+
+     return np.hstack(stacked)
+
+
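A quick sketch of the perimeter layout:

```python
from hpcflow.sdk.core.utils import linspace_rect

# 3 x 3 grid perimeter: 8 points, since corners are not duplicated
rect = linspace_rect(start=[0, 0], stop=[2, 2], num=[3, 3])
assert rect.shape == (2, 8)
```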
+ def dict_values_process_flat(
+     d: Mapping[T, T2 | list[T2]], callable: Callable[[list[T2]], list[T3]]
+ ) -> Mapping[T, T3 | list[T3]]:
+     """
+     Return a copy of a dict, where the values are processed by a callable that is to
+     be called only once, and where the values may be single items or lists of items.
+
+     Examples
+     --------
+     >>> d = {'a': 0, 'b': [1, 2], 'c': 5}
+     >>> dict_values_process_flat(d, callable=lambda x: [i + 1 for i in x])
+     {'a': 1, 'b': [2, 3], 'c': 6}
+
+     """
+     flat: list[T2] = []  # values of `d`, flattened
+     # for each value: whether it is a list, and the number of items to process:
+     is_multi: list[tuple[bool, int]] = []
+     for i in d.values():
+         if isinstance(i, list):
+             flat.extend(cast("list[T2]", i))
+             is_multi.append((True, len(i)))
+         else:
+             flat.append(cast("T2", i))
+             is_multi.append((False, 1))
+
+     processed = callable(flat)
+
+     out: dict[T, T3 | list[T3]] = {}
+     for idx_i, (m, k) in enumerate(zip(is_multi, d)):
+         start_idx = sum(i[1] for i in is_multi[:idx_i])
+         end_idx = start_idx + m[1]
+         proc_idx_k = processed[start_idx:end_idx]
+         if not m[0]:
+             out[k] = proc_idx_k[0]
+         else:
+             out[k] = proc_idx_k
+
+     return out
+
+
+ def nth_key(dct: Iterable[T], n: int) -> T:
+     """
+     Given a dict in some order, get the n'th key of that dict.
+     """
+     it = iter(dct)
+     # advance the iterator by `n` positions without materialising the keys:
+     next(islice(it, n, n), None)
+     return next(it)
+
+
+ def nth_value(dct: dict[Any, T], n: int) -> T:
+     """
+     Given a dict in some order, get the n'th value of that dict.
+     """
+     return dct[nth_key(dct, n)]
+
+
+ def normalise_timestamp(timestamp: datetime) -> datetime:
+     """
+     Force a timestamp to have UTC as its timezone,
+     then convert to use the local timezone.
+     """
+     return timestamp.replace(tzinfo=timezone.utc).astimezone()
+
+
+ def parse_timestamp(timestamp: str | datetime, ts_fmt: str) -> datetime:
+     """
+     Standard timestamp parsing.
+     Ensures that timestamps are internally all UTC.
+     """
+     return normalise_timestamp(
+         timestamp
+         if isinstance(timestamp, datetime)
+         else datetime.strptime(timestamp, ts_fmt)
+     )
+
+
+ def current_timestamp() -> datetime:
+     """
+     Get a UTC timestamp for the current time.
+     """
+     return datetime.now(timezone.utc)
+
+
+ def timedelta_format(td: timedelta) -> str:
+     """
+     Convert a time delta to a string in the standard form `D-HH:MM:SS`.
+     """
+     days, seconds = td.days, td.seconds
+     hours = seconds // (60 * 60)
+     seconds -= hours * (60 * 60)
+     minutes = seconds // 60
+     seconds -= minutes * 60
+     return f"{days}-{hours:02}:{minutes:02}:{seconds:02}"
+
+
+ _TD_RE = re.compile(r"(\d+)-(\d+):(\d+):(\d+)")
+
+
+ def timedelta_parse(td_str: str) -> timedelta:
+     """
+     Parse a string in the standard form `D-HH:MM:SS` as a time delta.
+     """
+     if not (m := _TD_RE.fullmatch(td_str)):
+         raise ValueError("not a supported timedelta form")
+     days, hours, mins, secs = map(int, m.groups())
+     return timedelta(days=days, hours=hours, minutes=mins, seconds=secs)
+
+
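A round-trip sketch:

```python
from datetime import timedelta
from hpcflow.sdk.core.utils import timedelta_format, timedelta_parse

td = timedelta(days=1, hours=2, minutes=3, seconds=4)
assert timedelta_format(td) == "1-02:03:04"
assert timedelta_parse("1-02:03:04") == td
```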
+ def open_text_resource(package: ModuleType | str, resource: str) -> IO[str]:
+     """
+     Open a text file in a package.
+     """
+     try:
+         return resources.files(package).joinpath(resource).open("r")
+     except AttributeError:
+         # Python < 3.9; `resources.open_text` is deprecated since 3.11
+         return resources.open_text(package, resource)
+
+
+ def get_file_context(
+     package: ModuleType | str, src: str | None = None
+ ) -> AbstractContextManager[Path]:
+     """
+     Find a file or directory in a package. May raise `ModuleNotFoundError`.
+     """
+     try:
+         files = resources.files(package)
+         return resources.as_file(files.joinpath(src) if src else files)
+     except AttributeError:
+         # Python < 3.9
+         return resources.path(package, src or "")
+
+
+ @contextlib.contextmanager
+ def redirect_std_to_file(
+     file,
+     mode: Literal["w", "a"] = "a",
+     ignore: Callable[[BaseException], Literal[True] | int] | None = None,
+ ) -> Iterator[None]:
+     """Temporarily redirect both stdout and stderr to a file, and if an exception is
+     raised, catch it, print the traceback to that file, and exit.
+
+     File creation is deferred until an actual write is required.
+
+     Parameters
+     ----------
+     ignore
+         Callable to test if a given exception should be ignored. If an exception is
+         not ignored, its traceback will be printed to `file` and the program will
+         exit. The callable should accept one parameter, the exception, and should
+         return True if that exception should be ignored, or an integer representing
+         the exit code to exit the program with if that exception should not be
+         ignored. By default, no exceptions are ignored and the exit code is 1.
+
+     """
+     ignore = ignore or (lambda _: 1)
+     with DeferredFileWriter(file, mode=mode) as fp:
+         with contextlib.redirect_stdout(fp):
+             with contextlib.redirect_stderr(fp):
+                 try:
+                     yield
+                 except BaseException as exc:
+                     ignore_ret = ignore(exc)
+                     if ignore_ret is not True:
+                         traceback.print_exc()
+                         sys.exit(ignore_ret)
+
+
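A usage sketch (the log file name and `ignore` policy are illustrative):

```python
from hpcflow.sdk.core.utils import redirect_std_to_file

# ignore KeyboardInterrupt; exit with code 2 for anything else:
with redirect_std_to_file("run.log", ignore=lambda e: isinstance(e, KeyboardInterrupt) or 2):
    print("this goes to run.log, not the terminal")
```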
+ async def to_thread(func, /, *args, **kwargs):
+     """Copied from https://github.com/python/cpython/blob/4b4227b907a262446b9d276c274feda2590a4e6e/Lib/asyncio/threads.py
+     to support Python 3.8, which does not have `asyncio.to_thread`.
+
+     Asynchronously run function *func* in a separate thread.
+
+     Any *args and **kwargs supplied for this function are directly passed
+     to *func*. Also, the current :class:`contextvars.Context` is propagated,
+     allowing context variables from the main thread to be accessed in the
+     separate thread.
+
+     Return a coroutine that can be awaited to get the eventual result of *func*.
+     """
+     loop = events.get_running_loop()
+     ctx = contextvars.copy_context()
+     func_call = functools.partial(ctx.run, func, *args, **kwargs)
+     return await loop.run_in_executor(None, func_call)
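Finally, a usage sketch of `to_thread` (illustrative workload):

```python
import asyncio
from hpcflow.sdk.core.utils import to_thread

async def main():
    # run a blocking call without blocking the event loop:
    result = await to_thread(sum, range(10**6))
    print(result)

asyncio.run(main())
```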