hpcflow 0.1.15__py3-none-any.whl → 0.2.0a271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (275) hide show
  1. hpcflow/__init__.py +2 -11
  2. hpcflow/__pyinstaller/__init__.py +5 -0
  3. hpcflow/__pyinstaller/hook-hpcflow.py +40 -0
  4. hpcflow/_version.py +1 -1
  5. hpcflow/app.py +43 -0
  6. hpcflow/cli.py +2 -461
  7. hpcflow/data/demo_data_manifest/__init__.py +3 -0
  8. hpcflow/data/demo_data_manifest/demo_data_manifest.json +6 -0
  9. hpcflow/data/jinja_templates/test/test_template.txt +8 -0
  10. hpcflow/data/programs/hello_world/README.md +1 -0
  11. hpcflow/data/programs/hello_world/hello_world.c +87 -0
  12. hpcflow/data/programs/hello_world/linux/hello_world +0 -0
  13. hpcflow/data/programs/hello_world/macos/hello_world +0 -0
  14. hpcflow/data/programs/hello_world/win/hello_world.exe +0 -0
  15. hpcflow/data/scripts/__init__.py +1 -0
  16. hpcflow/data/scripts/bad_script.py +2 -0
  17. hpcflow/data/scripts/demo_task_1_generate_t1_infile_1.py +8 -0
  18. hpcflow/data/scripts/demo_task_1_generate_t1_infile_2.py +8 -0
  19. hpcflow/data/scripts/demo_task_1_parse_p3.py +7 -0
  20. hpcflow/data/scripts/do_nothing.py +2 -0
  21. hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
  22. hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
  23. hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
  24. hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
  25. hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
  26. hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
  27. hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
  28. hpcflow/data/scripts/generate_t1_file_01.py +7 -0
  29. hpcflow/data/scripts/import_future_script.py +7 -0
  30. hpcflow/data/scripts/input_file_generator_basic.py +3 -0
  31. hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
  32. hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
  33. hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
  34. hpcflow/data/scripts/main_script_test_direct_in_direct_out.py +6 -0
  35. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
  36. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
  37. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
  38. hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
  39. hpcflow/data/scripts/main_script_test_direct_in_direct_out_all_iters_test.py +15 -0
  40. hpcflow/data/scripts/main_script_test_direct_in_direct_out_env_spec.py +7 -0
  41. hpcflow/data/scripts/main_script_test_direct_in_direct_out_labels.py +8 -0
  42. hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
  43. hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
  44. hpcflow/data/scripts/main_script_test_direct_sub_param_in_direct_out.py +6 -0
  45. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +12 -0
  46. hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
  47. hpcflow/data/scripts/main_script_test_hdf5_in_obj_group.py +12 -0
  48. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +11 -0
  49. hpcflow/data/scripts/main_script_test_json_and_direct_in_json_out.py +14 -0
  50. hpcflow/data/scripts/main_script_test_json_in_json_and_direct_out.py +17 -0
  51. hpcflow/data/scripts/main_script_test_json_in_json_out.py +14 -0
  52. hpcflow/data/scripts/main_script_test_json_in_json_out_labels.py +16 -0
  53. hpcflow/data/scripts/main_script_test_json_in_obj.py +12 -0
  54. hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
  55. hpcflow/data/scripts/main_script_test_json_out_obj.py +10 -0
  56. hpcflow/data/scripts/main_script_test_json_sub_param_in_json_out_labels.py +16 -0
  57. hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
  58. hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
  59. hpcflow/data/scripts/output_file_parser_basic.py +3 -0
  60. hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
  61. hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
  62. hpcflow/data/scripts/parse_t1_file_01.py +4 -0
  63. hpcflow/data/scripts/script_exit_test.py +5 -0
  64. hpcflow/data/template_components/__init__.py +1 -0
  65. hpcflow/data/template_components/command_files.yaml +26 -0
  66. hpcflow/data/template_components/environments.yaml +13 -0
  67. hpcflow/data/template_components/parameters.yaml +14 -0
  68. hpcflow/data/template_components/task_schemas.yaml +139 -0
  69. hpcflow/data/workflows/workflow_1.yaml +5 -0
  70. hpcflow/examples.ipynb +1037 -0
  71. hpcflow/sdk/__init__.py +149 -0
  72. hpcflow/sdk/app.py +4266 -0
  73. hpcflow/sdk/cli.py +1479 -0
  74. hpcflow/sdk/cli_common.py +385 -0
  75. hpcflow/sdk/config/__init__.py +5 -0
  76. hpcflow/sdk/config/callbacks.py +246 -0
  77. hpcflow/sdk/config/cli.py +388 -0
  78. hpcflow/sdk/config/config.py +1410 -0
  79. hpcflow/sdk/config/config_file.py +501 -0
  80. hpcflow/sdk/config/errors.py +272 -0
  81. hpcflow/sdk/config/types.py +150 -0
  82. hpcflow/sdk/core/__init__.py +38 -0
  83. hpcflow/sdk/core/actions.py +3857 -0
  84. hpcflow/sdk/core/app_aware.py +25 -0
  85. hpcflow/sdk/core/cache.py +224 -0
  86. hpcflow/sdk/core/command_files.py +814 -0
  87. hpcflow/sdk/core/commands.py +424 -0
  88. hpcflow/sdk/core/element.py +2071 -0
  89. hpcflow/sdk/core/enums.py +221 -0
  90. hpcflow/sdk/core/environment.py +256 -0
  91. hpcflow/sdk/core/errors.py +1043 -0
  92. hpcflow/sdk/core/execute.py +207 -0
  93. hpcflow/sdk/core/json_like.py +809 -0
  94. hpcflow/sdk/core/loop.py +1320 -0
  95. hpcflow/sdk/core/loop_cache.py +282 -0
  96. hpcflow/sdk/core/object_list.py +933 -0
  97. hpcflow/sdk/core/parameters.py +3371 -0
  98. hpcflow/sdk/core/rule.py +196 -0
  99. hpcflow/sdk/core/run_dir_files.py +57 -0
  100. hpcflow/sdk/core/skip_reason.py +7 -0
  101. hpcflow/sdk/core/task.py +3792 -0
  102. hpcflow/sdk/core/task_schema.py +993 -0
  103. hpcflow/sdk/core/test_utils.py +538 -0
  104. hpcflow/sdk/core/types.py +447 -0
  105. hpcflow/sdk/core/utils.py +1207 -0
  106. hpcflow/sdk/core/validation.py +87 -0
  107. hpcflow/sdk/core/values.py +477 -0
  108. hpcflow/sdk/core/workflow.py +4820 -0
  109. hpcflow/sdk/core/zarr_io.py +206 -0
  110. hpcflow/sdk/data/__init__.py +13 -0
  111. hpcflow/sdk/data/config_file_schema.yaml +34 -0
  112. hpcflow/sdk/data/config_schema.yaml +260 -0
  113. hpcflow/sdk/data/environments_spec_schema.yaml +21 -0
  114. hpcflow/sdk/data/files_spec_schema.yaml +5 -0
  115. hpcflow/sdk/data/parameters_spec_schema.yaml +7 -0
  116. hpcflow/sdk/data/task_schema_spec_schema.yaml +3 -0
  117. hpcflow/sdk/data/workflow_spec_schema.yaml +22 -0
  118. hpcflow/sdk/demo/__init__.py +3 -0
  119. hpcflow/sdk/demo/cli.py +242 -0
  120. hpcflow/sdk/helper/__init__.py +3 -0
  121. hpcflow/sdk/helper/cli.py +137 -0
  122. hpcflow/sdk/helper/helper.py +300 -0
  123. hpcflow/sdk/helper/watcher.py +192 -0
  124. hpcflow/sdk/log.py +288 -0
  125. hpcflow/sdk/persistence/__init__.py +18 -0
  126. hpcflow/sdk/persistence/base.py +2817 -0
  127. hpcflow/sdk/persistence/defaults.py +6 -0
  128. hpcflow/sdk/persistence/discovery.py +39 -0
  129. hpcflow/sdk/persistence/json.py +954 -0
  130. hpcflow/sdk/persistence/pending.py +948 -0
  131. hpcflow/sdk/persistence/store_resource.py +203 -0
  132. hpcflow/sdk/persistence/types.py +309 -0
  133. hpcflow/sdk/persistence/utils.py +73 -0
  134. hpcflow/sdk/persistence/zarr.py +2388 -0
  135. hpcflow/sdk/runtime.py +320 -0
  136. hpcflow/sdk/submission/__init__.py +3 -0
  137. hpcflow/sdk/submission/enums.py +70 -0
  138. hpcflow/sdk/submission/jobscript.py +2379 -0
  139. hpcflow/sdk/submission/schedulers/__init__.py +281 -0
  140. hpcflow/sdk/submission/schedulers/direct.py +233 -0
  141. hpcflow/sdk/submission/schedulers/sge.py +376 -0
  142. hpcflow/sdk/submission/schedulers/slurm.py +598 -0
  143. hpcflow/sdk/submission/schedulers/utils.py +25 -0
  144. hpcflow/sdk/submission/shells/__init__.py +52 -0
  145. hpcflow/sdk/submission/shells/base.py +229 -0
  146. hpcflow/sdk/submission/shells/bash.py +504 -0
  147. hpcflow/sdk/submission/shells/os_version.py +115 -0
  148. hpcflow/sdk/submission/shells/powershell.py +352 -0
  149. hpcflow/sdk/submission/submission.py +1402 -0
  150. hpcflow/sdk/submission/types.py +140 -0
  151. hpcflow/sdk/typing.py +194 -0
  152. hpcflow/sdk/utils/arrays.py +69 -0
  153. hpcflow/sdk/utils/deferred_file.py +55 -0
  154. hpcflow/sdk/utils/hashing.py +16 -0
  155. hpcflow/sdk/utils/patches.py +31 -0
  156. hpcflow/sdk/utils/strings.py +69 -0
  157. hpcflow/tests/api/test_api.py +32 -0
  158. hpcflow/tests/conftest.py +123 -0
  159. hpcflow/tests/data/__init__.py +0 -0
  160. hpcflow/tests/data/benchmark_N_elements.yaml +6 -0
  161. hpcflow/tests/data/benchmark_script_runner.yaml +26 -0
  162. hpcflow/tests/data/multi_path_sequences.yaml +29 -0
  163. hpcflow/tests/data/workflow_1.json +10 -0
  164. hpcflow/tests/data/workflow_1.yaml +5 -0
  165. hpcflow/tests/data/workflow_1_slurm.yaml +8 -0
  166. hpcflow/tests/data/workflow_1_wsl.yaml +8 -0
  167. hpcflow/tests/data/workflow_test_run_abort.yaml +42 -0
  168. hpcflow/tests/jinja_templates/test_jinja_templates.py +161 -0
  169. hpcflow/tests/programs/test_programs.py +180 -0
  170. hpcflow/tests/schedulers/direct_linux/test_direct_linux_submission.py +12 -0
  171. hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
  172. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +14 -0
  173. hpcflow/tests/scripts/test_input_file_generators.py +282 -0
  174. hpcflow/tests/scripts/test_main_scripts.py +1361 -0
  175. hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
  176. hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
  177. hpcflow/tests/shells/wsl/test_wsl_submission.py +14 -0
  178. hpcflow/tests/unit/test_action.py +1066 -0
  179. hpcflow/tests/unit/test_action_rule.py +24 -0
  180. hpcflow/tests/unit/test_app.py +132 -0
  181. hpcflow/tests/unit/test_cache.py +46 -0
  182. hpcflow/tests/unit/test_cli.py +172 -0
  183. hpcflow/tests/unit/test_command.py +377 -0
  184. hpcflow/tests/unit/test_config.py +195 -0
  185. hpcflow/tests/unit/test_config_file.py +162 -0
  186. hpcflow/tests/unit/test_element.py +666 -0
  187. hpcflow/tests/unit/test_element_iteration.py +88 -0
  188. hpcflow/tests/unit/test_element_set.py +158 -0
  189. hpcflow/tests/unit/test_group.py +115 -0
  190. hpcflow/tests/unit/test_input_source.py +1479 -0
  191. hpcflow/tests/unit/test_input_value.py +398 -0
  192. hpcflow/tests/unit/test_jobscript_unit.py +757 -0
  193. hpcflow/tests/unit/test_json_like.py +1247 -0
  194. hpcflow/tests/unit/test_loop.py +2674 -0
  195. hpcflow/tests/unit/test_meta_task.py +325 -0
  196. hpcflow/tests/unit/test_multi_path_sequences.py +259 -0
  197. hpcflow/tests/unit/test_object_list.py +116 -0
  198. hpcflow/tests/unit/test_parameter.py +243 -0
  199. hpcflow/tests/unit/test_persistence.py +664 -0
  200. hpcflow/tests/unit/test_resources.py +243 -0
  201. hpcflow/tests/unit/test_run.py +286 -0
  202. hpcflow/tests/unit/test_run_directories.py +29 -0
  203. hpcflow/tests/unit/test_runtime.py +9 -0
  204. hpcflow/tests/unit/test_schema_input.py +372 -0
  205. hpcflow/tests/unit/test_shell.py +129 -0
  206. hpcflow/tests/unit/test_slurm.py +39 -0
  207. hpcflow/tests/unit/test_submission.py +502 -0
  208. hpcflow/tests/unit/test_task.py +2560 -0
  209. hpcflow/tests/unit/test_task_schema.py +182 -0
  210. hpcflow/tests/unit/test_utils.py +616 -0
  211. hpcflow/tests/unit/test_value_sequence.py +549 -0
  212. hpcflow/tests/unit/test_values.py +91 -0
  213. hpcflow/tests/unit/test_workflow.py +827 -0
  214. hpcflow/tests/unit/test_workflow_template.py +186 -0
  215. hpcflow/tests/unit/utils/test_arrays.py +40 -0
  216. hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
  217. hpcflow/tests/unit/utils/test_hashing.py +65 -0
  218. hpcflow/tests/unit/utils/test_patches.py +5 -0
  219. hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
  220. hpcflow/tests/unit/utils/test_strings.py +97 -0
  221. hpcflow/tests/workflows/__init__.py +0 -0
  222. hpcflow/tests/workflows/test_directory_structure.py +31 -0
  223. hpcflow/tests/workflows/test_jobscript.py +355 -0
  224. hpcflow/tests/workflows/test_run_status.py +198 -0
  225. hpcflow/tests/workflows/test_skip_downstream.py +696 -0
  226. hpcflow/tests/workflows/test_submission.py +140 -0
  227. hpcflow/tests/workflows/test_workflows.py +564 -0
  228. hpcflow/tests/workflows/test_zip.py +18 -0
  229. hpcflow/viz_demo.ipynb +6794 -0
  230. hpcflow-0.2.0a271.dist-info/LICENSE +375 -0
  231. hpcflow-0.2.0a271.dist-info/METADATA +65 -0
  232. hpcflow-0.2.0a271.dist-info/RECORD +237 -0
  233. {hpcflow-0.1.15.dist-info → hpcflow-0.2.0a271.dist-info}/WHEEL +4 -5
  234. hpcflow-0.2.0a271.dist-info/entry_points.txt +6 -0
  235. hpcflow/api.py +0 -490
  236. hpcflow/archive/archive.py +0 -307
  237. hpcflow/archive/cloud/cloud.py +0 -45
  238. hpcflow/archive/cloud/errors.py +0 -9
  239. hpcflow/archive/cloud/providers/dropbox.py +0 -427
  240. hpcflow/archive/errors.py +0 -5
  241. hpcflow/base_db.py +0 -4
  242. hpcflow/config.py +0 -233
  243. hpcflow/copytree.py +0 -66
  244. hpcflow/data/examples/_config.yml +0 -14
  245. hpcflow/data/examples/damask/demo/1.run.yml +0 -4
  246. hpcflow/data/examples/damask/demo/2.process.yml +0 -29
  247. hpcflow/data/examples/damask/demo/geom.geom +0 -2052
  248. hpcflow/data/examples/damask/demo/load.load +0 -1
  249. hpcflow/data/examples/damask/demo/material.config +0 -185
  250. hpcflow/data/examples/damask/inputs/geom.geom +0 -2052
  251. hpcflow/data/examples/damask/inputs/load.load +0 -1
  252. hpcflow/data/examples/damask/inputs/material.config +0 -185
  253. hpcflow/data/examples/damask/profiles/_variable_lookup.yml +0 -21
  254. hpcflow/data/examples/damask/profiles/damask.yml +0 -4
  255. hpcflow/data/examples/damask/profiles/damask_process.yml +0 -8
  256. hpcflow/data/examples/damask/profiles/damask_run.yml +0 -5
  257. hpcflow/data/examples/damask/profiles/default.yml +0 -6
  258. hpcflow/data/examples/thinking.yml +0 -177
  259. hpcflow/errors.py +0 -2
  260. hpcflow/init_db.py +0 -37
  261. hpcflow/models.py +0 -2595
  262. hpcflow/nesting.py +0 -9
  263. hpcflow/profiles.py +0 -455
  264. hpcflow/project.py +0 -81
  265. hpcflow/scheduler.py +0 -322
  266. hpcflow/utils.py +0 -103
  267. hpcflow/validation.py +0 -166
  268. hpcflow/variables.py +0 -543
  269. hpcflow-0.1.15.dist-info/METADATA +0 -168
  270. hpcflow-0.1.15.dist-info/RECORD +0 -45
  271. hpcflow-0.1.15.dist-info/entry_points.txt +0 -8
  272. hpcflow-0.1.15.dist-info/top_level.txt +0 -1
  273. /hpcflow/{archive → data/jinja_templates}/__init__.py +0 -0
  274. /hpcflow/{archive/cloud → data/programs}/__init__.py +0 -0
  275. /hpcflow/{archive/cloud/providers → data/workflows}/__init__.py +0 -0
@@ -0,0 +1,598 @@
1
+ """
2
+ An interface to SLURM.
3
+ """
4
+
5
+ from __future__ import annotations
6
+ import subprocess
7
+ import time
8
+ from typing import cast, TYPE_CHECKING
9
+ from typing_extensions import override
10
+ from hpcflow.sdk.typing import hydrate
11
+ from hpcflow.sdk.core.enums import ParallelMode
12
+ from hpcflow.sdk.core.errors import (
13
+ IncompatibleParallelModeError,
14
+ IncompatibleSLURMArgumentsError,
15
+ IncompatibleSLURMPartitionError,
16
+ UnknownSLURMPartitionError,
17
+ )
18
+ from hpcflow.sdk.log import TimeIt
19
+ from hpcflow.sdk.submission.enums import JobscriptElementState
20
+ from hpcflow.sdk.submission.schedulers import QueuedScheduler
21
+ from hpcflow.sdk.submission.schedulers.utils import run_cmd
22
+
23
+ if TYPE_CHECKING:
24
+ from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
25
+ from typing import Any, ClassVar
26
+ from ...config.types import SchedulerConfigDescriptor, SLURMPartitionsDescriptor
27
+ from ...core.element import ElementResources
28
+ from ..jobscript import Jobscript
29
+ from ..types import VersionInfo
30
+ from ..shells.base import Shell
31
+
32
+
33
+ @hydrate
34
+ class SlurmPosix(QueuedScheduler):
35
+ """
36
+ A scheduler that uses SLURM.
37
+
38
+ Keyword Args
39
+ ------------
40
+ directives: dict
41
+ Scheduler directives. Each item is written verbatim in the jobscript as a
42
+ scheduler directive, and is not processed in any way. If a value is `None`, the
43
+ key is considered a flag-like directive. If a value is a list, multiple directives
44
+ will be printed to the jobscript with the same key, but different values.
45
+
46
+ Notes
47
+ -----
48
+ - runs in current working directory by default [2]
49
+
50
+ Todo
51
+ ----
52
+ - consider getting memory usage like: https://stackoverflow.com/a/44143229/5042280
53
+
54
+ References
55
+ ----------
56
+ [1] https://manpages.org/sbatch
57
+ [2] https://ri.itservices.manchester.ac.uk/csf4/batch/sge-to-slurm/
58
+
59
+ """
60
+
61
+ #: Default submission command.
62
+ DEFAULT_SUBMIT_CMD: ClassVar[str] = "sbatch"
63
+ #: Default command to show the queue state.
64
+ DEFAULT_SHOW_CMD: ClassVar[Sequence[str]] = ("squeue", "--me")
65
+ #: Default cancel command.
66
+ DEFAULT_DEL_CMD: ClassVar[str] = "scancel"
67
+ #: Default job control directive prefix.
68
+ DEFAULT_JS_CMD: ClassVar[str] = "#SBATCH"
69
+ #: Default prefix to enable array processing.
70
+ DEFAULT_ARRAY_SWITCH: ClassVar[str] = "--array"
71
+ #: Default shell variable with array ID.
72
+ DEFAULT_ARRAY_ITEM_VAR: ClassVar[str] = "SLURM_ARRAY_TASK_ID"
73
+ #: Number of times to try when querying the state.
74
+ NUM_STATE_QUERY_TRIES: ClassVar[int] = 5
75
+ #: Delay (in seconds) between attempts to query the state.
76
+ INTER_STATE_QUERY_DELAY: ClassVar[float] = 0.5
77
+
78
+ #: Maps scheduler state codes to :py:class:`JobscriptElementState` values.
79
+ state_lookup: ClassVar[Mapping[str, JobscriptElementState]] = {
80
+ "PENDING": JobscriptElementState.pending,
81
+ "RUNNING": JobscriptElementState.running,
82
+ "COMPLETING": JobscriptElementState.running,
83
+ "CANCELLED": JobscriptElementState.cancelled,
84
+ "COMPLETED": JobscriptElementState.finished,
85
+ "FAILED": JobscriptElementState.errored,
86
+ "OUT_OF_MEMORY": JobscriptElementState.errored,
87
+ "TIMEOUT": JobscriptElementState.errored,
88
+ }
89
+
90
@classmethod
@override
@TimeIt.decorator
def process_resources(
    cls, resources: ElementResources, scheduler_config: SchedulerConfigDescriptor
) -> None:
    """Perform scheduler-specific processing to the element resources.

    The scheduler-agnostic resource values (`num_cores`, `num_cores_per_node`,
    `num_nodes`, `num_threads`) are reconciled with their SLURM-specific
    counterparts (`SLURM_num_tasks`, `SLURM_num_tasks_per_node`,
    `SLURM_num_nodes`, `SLURM_num_cpus_per_task`), and then a partition is
    validated (if the user named one) or selected (if not).

    Parameters
    ----------
    resources
        The element resources to validate and complete.
    scheduler_config
        Scheduler configuration; only its optional `"partitions"` mapping is
        read here.

    Raises
    ------
    IncompatibleParallelModeError
        If the shared parallel mode is combined with more than one node.
    IncompatibleSLURMArgumentsError
        If a SLURM-specific value conflicts with the parallel mode or with its
        scheduler-agnostic counterpart.
    ValueError
        If the shared mode has no thread/core count at all, or if the
        distributed mode is given `num_threads`.
    NotImplementedError
        If the hybrid parallel mode is requested.
    UnknownSLURMPartitionError
        If the user-specified partition is not in the configured partitions.
    IncompatibleSLURMPartitionError
        If the user-specified partition does not support the request.

    Note
    ----
    This mutates `resources`.
    """
    if resources.is_parallel:
        if resources.parallel_mode is None:
            # set default parallel mode:
            resources.parallel_mode = ParallelMode.DISTRIBUTED

        if resources.parallel_mode is ParallelMode.SHARED:
            # shared mode must fit on one node: reject multi-node requests made
            # via either the generic or the SLURM-specific attribute
            if (resources.num_nodes and resources.num_nodes > 1) or (
                resources.SLURM_num_nodes and resources.SLURM_num_nodes > 1
            ):
                raise IncompatibleParallelModeError(resources.parallel_mode)
            # consider `num_cores` and `num_threads` synonyms in this case:
            if resources.SLURM_num_tasks and resources.SLURM_num_tasks != 1:
                raise IncompatibleSLURMArgumentsError(
                    f"For the {resources.parallel_mode.name.lower()} parallel mode, "
                    f"`SLURM_num_tasks` must be set to 1 (to ensure all requested "
                    f"cores reside on the same node)."
                )
            resources.SLURM_num_tasks = 1

            if resources.SLURM_num_cpus_per_task == 1:
                raise IncompatibleSLURMArgumentsError(
                    f"For the {resources.parallel_mode.name.lower()} parallel mode, "
                    f"if `SLURM_num_cpus_per_task` is set, it must be set to the "
                    f"number of threads/cores to use, and so must be greater than 1, "
                    f"but {resources.SLURM_num_cpus_per_task!r} was specified."
                )
            # fall back to `num_cores` when `num_threads` was not given:
            resources.num_threads = resources.num_threads or resources.num_cores
            if not resources.num_threads and not resources.SLURM_num_cpus_per_task:
                raise ValueError(
                    f"For the {resources.parallel_mode.name.lower()} parallel "
                    f"mode, specify `num_threads` (or its synonym for this "
                    f"parallel mode: `num_cores`), or the SLURM-specific "
                    f"parameter `SLURM_num_cpus_per_task`."
                )
            elif (resources.num_threads and resources.SLURM_num_cpus_per_task) and (
                resources.num_threads != resources.SLURM_num_cpus_per_task
            ):
                raise IncompatibleSLURMArgumentsError(
                    f"Incompatible parameters for `num_cores`/`num_threads` "
                    f"({resources.num_threads}) and `SLURM_num_cpus_per_task` "
                    f"({resources.SLURM_num_cpus_per_task}) for the "
                    f"{resources.parallel_mode.name.lower()} parallel mode."
                )
            # NOTE(review): if only `SLURM_num_cpus_per_task` was given,
            # `num_threads` is still unset here, so this assignment overwrites
            # the SLURM-specific value with it -- confirm this is intended.
            resources.SLURM_num_cpus_per_task = resources.num_threads

        elif resources.parallel_mode is ParallelMode.DISTRIBUTED:
            if resources.num_threads:
                raise ValueError(
                    f"For the {resources.parallel_mode.name.lower()} parallel "
                    f"mode, specifying `num_threads` is not permitted."
                )
            # This is a single if/elif chain: the per-node values are only
            # examined when the total-task branches above them did not apply.
            if (
                resources.SLURM_num_tasks
                and resources.num_cores
                and resources.SLURM_num_tasks != resources.num_cores
            ):
                raise IncompatibleSLURMArgumentsError(
                    f"Incompatible parameters for `num_cores` ({resources.num_cores})"
                    f" and `SLURM_num_tasks` ({resources.SLURM_num_tasks}) for the "
                    f"{resources.parallel_mode.name.lower()} parallel mode."
                )
            elif not resources.SLURM_num_tasks and resources.num_cores:
                resources.SLURM_num_tasks = resources.num_cores
            elif (
                resources.SLURM_num_tasks_per_node
                and resources.num_cores_per_node
                and resources.SLURM_num_tasks_per_node != resources.num_cores_per_node
            ):
                raise IncompatibleSLURMArgumentsError(
                    f"Incompatible parameters for `num_cores_per_node` "
                    f"({resources.num_cores_per_node}) and `SLURM_num_tasks_per_node`"
                    f" ({resources.SLURM_num_tasks_per_node}) for the "
                    f"{resources.parallel_mode.name.lower()} parallel mode."
                )
            elif (
                not resources.SLURM_num_tasks_per_node
                and resources.num_cores_per_node
            ):
                resources.SLURM_num_tasks_per_node = resources.num_cores_per_node

            # node counts are reconciled independently of the chain above:
            if (
                resources.SLURM_num_nodes
                and resources.num_nodes
                and resources.SLURM_num_nodes != resources.num_nodes
            ):
                raise IncompatibleSLURMArgumentsError(
                    f"Incompatible parameters for `num_nodes` ({resources.num_nodes})"
                    f" and `SLURM_num_nodes` ({resources.SLURM_num_nodes}) for the "
                    f"{resources.parallel_mode.name.lower()} parallel mode."
                )
            elif not resources.SLURM_num_nodes and resources.num_nodes:
                resources.SLURM_num_nodes = resources.num_nodes

        elif resources.parallel_mode is ParallelMode.HYBRID:
            raise NotImplementedError("hybrid parallel mode not yet supported.")

    else:
        # serial job: SLURM-specific values must not imply parallelism
        if resources.SLURM_is_parallel:
            raise IncompatibleSLURMArgumentsError(
                "Some specified SLURM-specific arguments (which indicate a parallel "
                "job) conflict with the scheduler-agnostic arguments (which "
                "indicate a serial job)."
            )
        # fill in the serial defaults: one task, one node, one CPU per task
        if not resources.SLURM_num_tasks:
            resources.SLURM_num_tasks = 1

        if resources.SLURM_num_tasks_per_node:
            resources.SLURM_num_tasks_per_node = None

        if not resources.SLURM_num_nodes:
            resources.SLURM_num_nodes = 1

        if not resources.SLURM_num_cpus_per_task:
            resources.SLURM_num_cpus_per_task = 1

    # effective request, preferring the scheduler-agnostic values:
    num_cores = resources.num_cores or resources.SLURM_num_tasks
    num_cores_per_node = (
        resources.num_cores_per_node or resources.SLURM_num_tasks_per_node
    )
    num_nodes = resources.num_nodes or resources.SLURM_num_nodes
    para_mode = resources.parallel_mode

    # select matching partition if possible:
    all_parts = scheduler_config.get("partitions", {})
    if resources.SLURM_partition is not None:
        # check user-specified partition is valid and compatible with requested
        # cores/nodes:
        try:
            part = all_parts[resources.SLURM_partition]
        except KeyError:
            raise UnknownSLURMPartitionError(resources.SLURM_partition, all_parts)
        # TODO: when we support ParallelMode.HYBRID, these checks will have to
        # consider the total number of cores requested per node
        # (num_cores_per_node * num_threads)?
        part_num_cores = part.get("num_cores", ())
        part_num_cores_per_node = part.get("num_cores_per_node", ())
        part_num_nodes = part.get("num_nodes", ())
        part_para_modes = part.get("parallel_modes", ())
        # a count only disqualifies the partition when it is both requested
        # and declared by the partition, and the two do not agree:
        if cls.__is_present_unsupported(num_cores, part_num_cores):
            raise IncompatibleSLURMPartitionError(
                resources.SLURM_partition, "number of cores", num_cores
            )
        if cls.__is_present_unsupported(num_cores_per_node, part_num_cores_per_node):
            raise IncompatibleSLURMPartitionError(
                resources.SLURM_partition,
                "number of cores per node",
                num_cores_per_node,
            )
        if cls.__is_present_unsupported(num_nodes, part_num_nodes):
            raise IncompatibleSLURMPartitionError(
                resources.SLURM_partition, "number of nodes", num_nodes
            )
        # unlike the counts, a requested parallel mode must be listed
        # explicitly: a partition with no `parallel_modes` entry is rejected
        if para_mode and para_mode.name.lower() not in part_para_modes:
            raise IncompatibleSLURMPartitionError(
                resources.SLURM_partition, "parallel mode", para_mode
            )
    else:
        # find the first compatible partition if one exists:
        # TODO: bug here? not finding correct partition?
        for part_name, part_info in all_parts.items():
            if cls.__partition_matches(
                num_cores, num_cores_per_node, num_nodes, para_mode, part_info
            ):
                resources.SLURM_partition = str(part_name)
                break
267
+
268
@classmethod
def __is_present_unsupported(
    cls, num_req: int | None, part_have: Sequence[int] | None
) -> bool:
    """
    Report a definite mismatch between a request and a partition.

    Parameters
    ----------
    num_req
        The requested count, or a falsy value if nothing was requested.
    part_have
        The counts the partition declares, or a falsy value if it declares
        none.

    Returns
    -------
    bool
        `True` only when both sides carry information and
        `is_num_cores_supported` rejects the combination.
    """
    # nothing to compare against: cannot be a mismatch
    if not num_req or not part_have:
        return False
    return not cls.is_num_cores_supported(num_req, part_have)
278
+
279
@classmethod
def __is_present_supported(
    cls, num_req: int | None, part_have: Sequence[int] | None
) -> bool:
    """
    Report a definite match between a request and a partition.

    Parameters
    ----------
    num_req
        The requested count, or a falsy value if nothing was requested.
    part_have
        The counts the partition declares, or a falsy value if it declares
        none.

    Returns
    -------
    bool
        `True` only when both sides carry information and
        `is_num_cores_supported` accepts the combination. Note that missing
        information on either side yields `False`.
    """
    # missing information on either side never counts as a positive match
    if not num_req or not part_have:
        return False
    return bool(cls.is_num_cores_supported(num_req, part_have))
289
+
290
@classmethod
def __partition_matches(
    cls,
    num_cores: int | None,
    num_cores_per_node: int | None,
    num_nodes: int | None,
    para_mode: ParallelMode | None,
    part_info: SLURMPartitionsDescriptor,
) -> bool:
    """
    Check whether a partition is compatible with the requested resources.

    The rules mirror the validation applied to a user-specified partition in
    `process_resources`, so that automatic selection never picks a partition
    that explicit selection would reject.

    Parameters
    ----------
    num_cores
        Requested total number of cores, if any.
    num_cores_per_node
        Requested number of cores per node, if any.
    num_nodes
        Requested number of nodes, if any.
    para_mode
        Requested parallel mode, or `None` for a serial job.
    part_info
        The partition's configuration; the optional keys `num_cores`,
        `num_cores_per_node`, `num_nodes` and `parallel_modes` are read.

    Returns
    -------
    bool
        `False` if any count is both requested and declared by the partition
        but not supported, or if a parallel mode is requested that the
        partition does not list; `True` otherwise.
    """
    requested_vs_declared = (
        (num_cores, part_info.get("num_cores", [])),
        (num_cores_per_node, part_info.get("num_cores_per_node", [])),
        (num_nodes, part_info.get("num_nodes", [])),
    )
    # A count that is not requested, or that the partition does not constrain,
    # must not disqualify the partition; only a definite mismatch does.
    # (Previously `not __is_present_supported(...)` was used here, which
    # rejected every partition with any unspecified count.)
    for num_req, part_have in requested_vs_declared:
        if cls.__is_present_unsupported(num_req, part_have):
            return False

    # The parallel mode check was previously unreachable, sitting after an
    # unconditional `return True`. As in the explicit-partition validation, a
    # requested mode must be listed by the partition.
    if para_mode and para_mode.name.lower() not in part_info.get(
        "parallel_modes", []
    ):
        return False
    return True
323
+
324
def __format_core_request_lines(self, resources: ElementResources) -> Iterator[str]:
    """
    Generate the directive lines that request a partition, nodes, tasks and
    CPUs. A line is produced only for values that are set (truthy).
    """
    # TODO: option for --exclusive (alongside --nodes)?
    switch_values = (
        ("--partition", resources.SLURM_partition),
        ("--nodes", resources.SLURM_num_nodes),
        ("--ntasks", resources.SLURM_num_tasks),
        ("--ntasks-per-node", resources.SLURM_num_tasks_per_node),
        ("--cpus-per-task", resources.SLURM_num_cpus_per_task),
    )
    for switch, value in switch_values:
        if value:
            yield f"{self.js_cmd} {switch} {value}"
335
+
336
def __format_array_request(self, num_elements: int, resources: ElementResources):
    """
    Build the job-array directive covering items 1 to `num_elements`, with a
    `%<limit>` suffix appended when `resources.max_array_items` is set.
    """
    # TODO: Slurm docs start indices at zero, why are we starting at one?
    # https://slurm.schedmd.com/sbatch.html#OPT_array
    if resources.max_array_items:
        limit_suffix = f"%{resources.max_array_items}"
    else:
        limit_suffix = ""
    index_spec = f"1-{num_elements}{limit_suffix}"
    return f"{self.js_cmd} {self.array_switch} {index_spec}"
341
+
342
def get_stdout_filename(
    self, js_idx: int, job_ID: str, array_idx: int | None = None
) -> str:
    """
    File name of the standard output stream file.

    Parameters
    ----------
    js_idx
        Index of the jobscript.
    job_ID
        Scheduler job identifier.
    array_idx
        Array item index; when given (including zero) it is appended to the
        job ID, separated by a dot.
    """
    stem = f"js_{js_idx}.sh_{job_ID}"
    if array_idx is not None:
        stem = f"{stem}.{array_idx}"
    return f"{stem}.out"
348
+
349
def get_stderr_filename(
    self, js_idx: int, job_ID: str, array_idx: int | None = None
) -> str:
    """
    File name of the standard error stream file.

    Parameters
    ----------
    js_idx
        Index of the jobscript.
    job_ID
        Scheduler job identifier.
    array_idx
        Array item index; when given (including zero) it is appended to the
        job ID, separated by a dot.
    """
    stem = f"js_{js_idx}.sh_{job_ID}"
    if array_idx is not None:
        stem = f"{stem}.{array_idx}"
    return f"{stem}.err"
355
+
356
def __format_std_stream_file_option_lines(
    self, is_array: bool, sub_idx: int, js_idx: int, combine_std: bool
) -> Iterator[str]:
    """
    Generate the directives that set the standard stream file paths.

    The `--output` directive is always produced; the `--error` directive is
    produced only when the streams are not combined. The file name uses the
    scheduler's filename patterns (`%x_%A.%a` for array jobs, `%x_%j`
    otherwise; see the sbatch documentation for their expansion).
    """
    if is_array:
        name_pattern = r"%x_%A.%a"
    else:
        name_pattern = r"%x_%j"
    stream_base = f"./artifacts/submissions/{sub_idx}/js_std/{js_idx}/{name_pattern}"
    yield f"{self.js_cmd} --output {stream_base}.out"
    if combine_std:
        # stderr goes to the same file as stdout; no separate directive
        return
    yield f"{self.js_cmd} --error {stream_base}.err"
364
+
365
@override
def format_directives(
    self,
    resources: ElementResources,
    num_elements: int,
    is_array: bool,
    sub_idx: int,
    js_idx: int,
) -> str:
    """
    Format the directives to the scheduler.

    The lines are emitted in this order: core requests, the array request (if
    `is_array`), the standard stream file options, and finally the
    user-supplied directives.

    Returns
    -------
    str
        All directive lines joined by newlines, with a trailing newline.
    """
    lines: list[str] = list(self.__format_core_request_lines(resources))

    if is_array:
        lines.append(self.__format_array_request(num_elements, resources))

    lines.extend(
        self.__format_std_stream_file_option_lines(
            is_array, sub_idx, js_idx, resources.combine_jobscript_std
        )
    )

    for key, value in self.directives.items():
        if isinstance(value, list):
            # one directive per list item, all sharing the same key
            lines.extend(f"{self.js_cmd} {key} {item}" for item in value)
        elif value is None:
            # flag-like directive: key only
            lines.append(f"{self.js_cmd} {key}")
        elif value:
            lines.append(f"{self.js_cmd} {key} {value}")
        # any other falsy value (e.g. 0, "" or False) produces no directive

    return "\n".join(lines) + "\n"
399
+
400
@override
@TimeIt.decorator
def get_version_info(self) -> VersionInfo:
    """
    Query the scheduler's name and version by running the submit command with
    `--version`.

    Anything written to standard error is printed. Standard output is
    expected to consist of exactly two whitespace-separated words (name, then
    version); any other shape raises `ValueError` from the unpacking.
    """
    result = subprocess.run(
        args=[self.submit_cmd, "--version"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    out_text = result.stdout.decode().strip()
    err_text = result.stderr.decode().strip()
    if err_text:
        print(err_text)
    scheduler_name, scheduler_version = out_text.split()
    return {
        "scheduler_name": scheduler_name,
        "scheduler_version": scheduler_version,
    }
418
+
419
@override
def get_submit_command(
    self,
    shell: Shell,
    js_path: str,
    deps: dict[Any, tuple[Any, ...]],
) -> list[str]:
    """
    Get the command to use to submit a job to the scheduler.

    Parameters
    ----------
    shell
        Not used by this scheduler.
    js_path
        Path of the jobscript to submit; always the final word.
    deps
        Dependencies of this jobscript. Each value is a pair of the job ID
        depended upon and a flag for whether it is an array dependency.

    Returns
    -------
    List of argument words.
    """
    words = [self.submit_cmd, "--parsable"]
    if deps:
        # all dependencies go into a single comma-separated argument
        words += ["--dependency", ",".join(self.__dependency_args(deps))]
    return [*words, js_path]
439
+
440
+ @staticmethod
441
+ def __dependency_args(deps: dict[Any, tuple[Any, ...]]) -> Iterator[str]:
442
+ for job_ID, is_array_dep in deps.values():
443
+ if is_array_dep: # array dependency
444
+ yield f"aftercorr:{job_ID}"
445
+ else:
446
+ yield f"afterany:{job_ID}"
447
+
448
def parse_submission_output(self, stdout: str) -> str:
    """
    Extract the scheduler reference for a newly submitted jobscript.

    Because submission uses `--parsable`, everything before the first `;` is
    returned; when there is no `;`, the whole string is returned unchanged.
    """
    job_ref, _sep, _rest = stdout.partition(";")
    return job_ref
454
+
455
+ @staticmethod
456
+ def _parse_job_IDs(job_ID_str: str) -> tuple[str, None | list[int]]:
457
+ """
458
+ Parse the job ID column from the `squeue` command (the `%i` format option).
459
+
460
+ Returns
461
+ -------
462
+ job_id
463
+ The job identifier.
464
+ array_indices
465
+ The indices into the job array.
466
+ """
467
+ base_job_ID, *arr_idx_data = job_ID_str.split("_")
468
+ if not arr_idx_data:
469
+ return base_job_ID, None
470
+ arr_idx = arr_idx_data[0]
471
+ try:
472
+ return base_job_ID, [int(arr_idx) - 1] # zero-index
473
+ except ValueError:
474
+ pass
475
+ # split on commas (e.g. "[5,8-40]")
476
+ _arr_idx: list[int] = []
477
+ for i_range_str in arr_idx.strip("[]").split(","):
478
+ if "-" in i_range_str:
479
+ _from, _to = i_range_str.split("-")
480
+ if "%" in _to:
481
+ # indicates max concurrent array items; not needed
482
+ _to = _to.split("%")[0]
483
+ _arr_idx.extend(range(int(_from) - 1, int(_to)))
484
+ else:
485
+ _arr_idx.append(int(i_range_str) - 1)
486
+ return base_job_ID, _arr_idx
487
+
488
def __parse_job_states(
    self, stdout: str
) -> dict[str, JobscriptElementState | dict[int, JobscriptElementState]]:
    """
    Parse output from the Slurm `squeue` command with a simple format.

    Each non-empty line is expected to start with a job ID word followed by a
    job state word. State words missing from `self.state_lookup` map to
    `JobscriptElementState.errored`.

    Returns
    -------
    Mapping from base job ID to either a single state (non-array jobs) or a
    mapping from array index to state (array jobs).
    """
    states: dict[str, JobscriptElementState | dict[int, JobscriptElementState]] = {}
    for line in filter(None, stdout.split("\n")):
        job_id, job_state, *_ = line.split()
        base_job_ID, indices = self._parse_job_IDs(job_id)
        state = self.state_lookup.get(job_state, JobscriptElementState.errored)

        if indices is None:
            # not an array job: one state for the whole job
            states[base_job_ID] = state
            continue

        per_element = cast(
            dict[int, JobscriptElementState], states.setdefault(base_job_ID, {})
        )
        per_element.update(dict.fromkeys(indices, state))

    return states
510
+
511
def __query_job_states(self, job_IDs: Iterable[str]) -> tuple[str, str]:
    """
    Run the scheduler's show command for the given job IDs.

    Returns
    -------
    The `(stdout, stderr)` pair produced by `run_cmd`.
    """
    # job ID (<base_job_id>_<index> for array job) and job state
    output_format = R"%200i %30T"
    id_list = ",".join(job_IDs)
    return run_cmd(
        [*self.show_cmd, "--noheader", "--format", output_format, "--jobs", id_list],
        logger=self._app.submission_logger,
    )
522
+
523
def __get_job_valid_IDs(self, job_IDs: Collection[str] | None = None) -> set[str]:
    """
    Get the set of job IDs reported by the scheduler's show command, optionally
    restricted to those that are also in `job_IDs`.

    Raises
    ------
    ValueError
        If the show command writes anything to stderr.
    """
    cmd = [*self.show_cmd, "--noheader", "--format", r"%F"]
    stdout, stderr = run_cmd(cmd, logger=self._app.submission_logger)
    if stderr:
        raise ValueError(
            f"Could not get query Slurm jobs. Command was: {cmd!r}; stderr was: "
            f"{stderr}"
        )

    # one ID per output line; blank lines are dropped
    known_jobs = {job for job in map(str.strip, stdout.split("\n")) if job}
    return known_jobs if job_IDs is None else known_jobs.intersection(job_IDs)
539
+
540
+ @override
541
def get_job_state_info(
    self, *, js_refs: Sequence[str] | None = None
) -> Mapping[str, JobscriptElementState | Mapping[int, JobscriptElementState]]:
    """
    Query the scheduler to get the states of all of this user's jobs, optionally
    filtering by specified job IDs.

    Jobs that are not in the scheduler's status output will not appear in the
    output of this method.

    Raises
    ------
    ValueError
        If the state query reports an error other than an invalid job ID, or if
        the invalid-job-ID error persists after `NUM_STATE_QUERY_TRIES` retries.
    """
    # if job IDs are passed, assume they exist; otherwise retrieve the valid jobs:
    refs: Collection[str] = js_refs or self.__get_job_valid_IDs()

    attempts = 0
    while refs:
        stdout, stderr = self.__query_job_states(refs)
        if not stderr:
            return self.__parse_job_states(stdout)

        can_retry = (
            "Invalid job id specified" in stderr
            and attempts < self.NUM_STATE_QUERY_TRIES
        )
        if not can_retry:
            raise ValueError(f"Could not get Slurm job states. Stderr was: {stderr}")

        # the job might have finished; this only seems to happen if a single
        # non-existent job ID is specified; for multiple non-existent jobs, no
        # error is produced;
        self._app.submission_logger.info(
            "A specified job ID is non-existant; refreshing known job IDs..."
        )
        time.sleep(self.INTER_STATE_QUERY_DELAY)
        refs = self.__get_job_valid_IDs(refs)
        attempts += 1

    # no (remaining) job IDs to query
    return {}
576
+
577
+ @override
578
def cancel_jobs(
    self,
    js_refs: list[str],
    jobscripts: list[Jobscript] | None = None,
) -> None:
    """
    Cancel submitted jobs.

    Runs `self.del_cmd` with all of `js_refs` as arguments. `jobscripts` is
    accepted but not used by this implementation.

    Raises
    ------
    ValueError
        If the cancel command writes anything to stderr.
    """
    cmd = [self.del_cmd, *js_refs]
    self._app.submission_logger.info(
        f"cancelling {self.__class__.__name__} jobscripts with command: {cmd}."
    )
    stdout, stderr = run_cmd(cmd, logger=self._app.submission_logger)
    if stderr:
        # report this as a cancellation failure (the message previously
        # described a job query, which is not what this command does)
        raise ValueError(
            f"Could not cancel {self.__class__.__name__} jobs. Command was: "
            f"{cmd!r}; stderr was: {stderr}"
        )
    self._app.submission_logger.info(
        f"jobscripts cancel command executed; stdout was: {stdout}."
    )
@@ -0,0 +1,25 @@
1
+ """
2
+ Helper for running a subprocess.
3
+ """
4
+
5
+ from __future__ import annotations
6
+ import subprocess
7
+ from typing import TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ from collections.abc import Sequence
11
+ from logging import Logger
12
+
13
+
14
def run_cmd(
    cmd: str | Sequence[str], logger: Logger | None = None, **kwargs
) -> tuple[str, str]:
    """
    Execute a command and return stdout, stderr as strings.

    Parameters
    ----------
    cmd
        The command to run, passed to `subprocess.run` as `args`.
    logger
        If given, the command is logged at debug level before it is run.
    **kwargs
        Extra keyword arguments forwarded to `subprocess.run`. `stdout` and
        `stderr` are always piped by this function, so must not be passed here.

    Returns
    -------
    stdout, stderr
        The captured output streams, decoded to `str`.
    """
    if logger:
        logger.debug("running shell command: %s", cmd)
    proc = subprocess.run(
        args=cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, **kwargs
    )
    # If the caller requested text mode (e.g. `text=True` or `encoding=...`),
    # the streams are already `str` and must not be decoded again.
    stdout = proc.stdout if isinstance(proc.stdout, str) else proc.stdout.decode()
    stderr = proc.stderr if isinstance(proc.stderr, str) else proc.stderr.decode()
    return stdout, stderr
@@ -0,0 +1,52 @@
1
+ """
2
+ Adapters for various shells.
3
+ """
4
+
5
+ from __future__ import annotations
6
+ import os
7
+
8
+ from hpcflow.sdk.core.errors import UnsupportedShellError
9
+
10
+ from hpcflow.sdk.submission.shells.base import Shell
11
+ from hpcflow.sdk.submission.shells.bash import Bash, WSLBash
12
+ from hpcflow.sdk.submission.shells.powershell import WindowsPowerShell
13
+
14
#: All supported shells: shell name -> {OS name (as in ``os.name``) -> shell class}.
ALL_SHELLS: dict[str, dict[str, type[Shell]]] = {
    "bash": {"posix": Bash},
    "powershell": {"nt": WindowsPowerShell},
    "wsl+bash": {"nt": WSLBash},
    # TODO: cast this to wsl+bash in ResourceSpec?
    "wsl": {"nt": WSLBash},
}

#: The default shell in the default config, keyed by OS name (as in ``os.name``).
DEFAULT_SHELL_NAMES: dict[str, str] = {"posix": "bash", "nt": "powershell"}
27
+
28
+
29
def get_supported_shells(os_name: str | None = None) -> dict[str, type[Shell]]:
    """
    Get shells supported on the current or given OS.

    Returns
    -------
    Mapping from shell name to shell class, containing only those entries of
    `ALL_SHELLS` that have a class registered for the OS name.
    """
    target_os = os_name or os.name
    supported: dict[str, type[Shell]] = {}
    for shell_name, classes_by_os in ALL_SHELLS.items():
        shell_cls = classes_by_os.get(target_os)
        if shell_cls:
            supported[shell_name] = shell_cls
    return supported
35
+
36
+
37
def get_shell(shell_name: str | None, os_name: str | None = None, **kwargs) -> Shell:
    """
    Get a shell interface with the given name for a given OS (or the current one).

    Parameters
    ----------
    shell_name
        Name of the shell (case-insensitive). If `None`, the default shell for
        the OS from `DEFAULT_SHELL_NAMES` is used.
    os_name
        OS name (case-insensitive). If not given, `os.name` is used.
    **kwargs
        Passed on to the shell class constructor.

    Raises
    ------
    KeyError
        If `shell_name` is `None` and the OS has no default shell.
    UnsupportedShellError
        If the shell is not supported on the OS.
    """
    # TODO: apply config default shell args?

    # normalise the OS name once, so that the default-shell lookup and the
    # supported-shells lookup agree on case
    os_name = (os_name or os.name).lower()
    if shell_name is None:
        shell_name = DEFAULT_SHELL_NAMES[os_name]
    else:
        shell_name = shell_name.lower()

    supported = get_supported_shells(os_name)
    if not (shell_cls := supported.get(shell_name)):
        raise UnsupportedShellError(shell=shell_name, supported=supported)

    return shell_cls(**kwargs)