hpcflow 0.1.15__py3-none-any.whl → 0.2.0a271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (275)
  1. hpcflow/__init__.py +2 -11
  2. hpcflow/__pyinstaller/__init__.py +5 -0
  3. hpcflow/__pyinstaller/hook-hpcflow.py +40 -0
  4. hpcflow/_version.py +1 -1
  5. hpcflow/app.py +43 -0
  6. hpcflow/cli.py +2 -461
  7. hpcflow/data/demo_data_manifest/__init__.py +3 -0
  8. hpcflow/data/demo_data_manifest/demo_data_manifest.json +6 -0
  9. hpcflow/data/jinja_templates/test/test_template.txt +8 -0
  10. hpcflow/data/programs/hello_world/README.md +1 -0
  11. hpcflow/data/programs/hello_world/hello_world.c +87 -0
  12. hpcflow/data/programs/hello_world/linux/hello_world +0 -0
  13. hpcflow/data/programs/hello_world/macos/hello_world +0 -0
  14. hpcflow/data/programs/hello_world/win/hello_world.exe +0 -0
  15. hpcflow/data/scripts/__init__.py +1 -0
  16. hpcflow/data/scripts/bad_script.py +2 -0
  17. hpcflow/data/scripts/demo_task_1_generate_t1_infile_1.py +8 -0
  18. hpcflow/data/scripts/demo_task_1_generate_t1_infile_2.py +8 -0
  19. hpcflow/data/scripts/demo_task_1_parse_p3.py +7 -0
  20. hpcflow/data/scripts/do_nothing.py +2 -0
  21. hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
  22. hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
  23. hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
  24. hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
  25. hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
  26. hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
  27. hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
  28. hpcflow/data/scripts/generate_t1_file_01.py +7 -0
  29. hpcflow/data/scripts/import_future_script.py +7 -0
  30. hpcflow/data/scripts/input_file_generator_basic.py +3 -0
  31. hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
  32. hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
  33. hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
  34. hpcflow/data/scripts/main_script_test_direct_in_direct_out.py +6 -0
  35. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
  36. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
  37. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
  38. hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
  39. hpcflow/data/scripts/main_script_test_direct_in_direct_out_all_iters_test.py +15 -0
  40. hpcflow/data/scripts/main_script_test_direct_in_direct_out_env_spec.py +7 -0
  41. hpcflow/data/scripts/main_script_test_direct_in_direct_out_labels.py +8 -0
  42. hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
  43. hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
  44. hpcflow/data/scripts/main_script_test_direct_sub_param_in_direct_out.py +6 -0
  45. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +12 -0
  46. hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
  47. hpcflow/data/scripts/main_script_test_hdf5_in_obj_group.py +12 -0
  48. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +11 -0
  49. hpcflow/data/scripts/main_script_test_json_and_direct_in_json_out.py +14 -0
  50. hpcflow/data/scripts/main_script_test_json_in_json_and_direct_out.py +17 -0
  51. hpcflow/data/scripts/main_script_test_json_in_json_out.py +14 -0
  52. hpcflow/data/scripts/main_script_test_json_in_json_out_labels.py +16 -0
  53. hpcflow/data/scripts/main_script_test_json_in_obj.py +12 -0
  54. hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
  55. hpcflow/data/scripts/main_script_test_json_out_obj.py +10 -0
  56. hpcflow/data/scripts/main_script_test_json_sub_param_in_json_out_labels.py +16 -0
  57. hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
  58. hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
  59. hpcflow/data/scripts/output_file_parser_basic.py +3 -0
  60. hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
  61. hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
  62. hpcflow/data/scripts/parse_t1_file_01.py +4 -0
  63. hpcflow/data/scripts/script_exit_test.py +5 -0
  64. hpcflow/data/template_components/__init__.py +1 -0
  65. hpcflow/data/template_components/command_files.yaml +26 -0
  66. hpcflow/data/template_components/environments.yaml +13 -0
  67. hpcflow/data/template_components/parameters.yaml +14 -0
  68. hpcflow/data/template_components/task_schemas.yaml +139 -0
  69. hpcflow/data/workflows/workflow_1.yaml +5 -0
  70. hpcflow/examples.ipynb +1037 -0
  71. hpcflow/sdk/__init__.py +149 -0
  72. hpcflow/sdk/app.py +4266 -0
  73. hpcflow/sdk/cli.py +1479 -0
  74. hpcflow/sdk/cli_common.py +385 -0
  75. hpcflow/sdk/config/__init__.py +5 -0
  76. hpcflow/sdk/config/callbacks.py +246 -0
  77. hpcflow/sdk/config/cli.py +388 -0
  78. hpcflow/sdk/config/config.py +1410 -0
  79. hpcflow/sdk/config/config_file.py +501 -0
  80. hpcflow/sdk/config/errors.py +272 -0
  81. hpcflow/sdk/config/types.py +150 -0
  82. hpcflow/sdk/core/__init__.py +38 -0
  83. hpcflow/sdk/core/actions.py +3857 -0
  84. hpcflow/sdk/core/app_aware.py +25 -0
  85. hpcflow/sdk/core/cache.py +224 -0
  86. hpcflow/sdk/core/command_files.py +814 -0
  87. hpcflow/sdk/core/commands.py +424 -0
  88. hpcflow/sdk/core/element.py +2071 -0
  89. hpcflow/sdk/core/enums.py +221 -0
  90. hpcflow/sdk/core/environment.py +256 -0
  91. hpcflow/sdk/core/errors.py +1043 -0
  92. hpcflow/sdk/core/execute.py +207 -0
  93. hpcflow/sdk/core/json_like.py +809 -0
  94. hpcflow/sdk/core/loop.py +1320 -0
  95. hpcflow/sdk/core/loop_cache.py +282 -0
  96. hpcflow/sdk/core/object_list.py +933 -0
  97. hpcflow/sdk/core/parameters.py +3371 -0
  98. hpcflow/sdk/core/rule.py +196 -0
  99. hpcflow/sdk/core/run_dir_files.py +57 -0
  100. hpcflow/sdk/core/skip_reason.py +7 -0
  101. hpcflow/sdk/core/task.py +3792 -0
  102. hpcflow/sdk/core/task_schema.py +993 -0
  103. hpcflow/sdk/core/test_utils.py +538 -0
  104. hpcflow/sdk/core/types.py +447 -0
  105. hpcflow/sdk/core/utils.py +1207 -0
  106. hpcflow/sdk/core/validation.py +87 -0
  107. hpcflow/sdk/core/values.py +477 -0
  108. hpcflow/sdk/core/workflow.py +4820 -0
  109. hpcflow/sdk/core/zarr_io.py +206 -0
  110. hpcflow/sdk/data/__init__.py +13 -0
  111. hpcflow/sdk/data/config_file_schema.yaml +34 -0
  112. hpcflow/sdk/data/config_schema.yaml +260 -0
  113. hpcflow/sdk/data/environments_spec_schema.yaml +21 -0
  114. hpcflow/sdk/data/files_spec_schema.yaml +5 -0
  115. hpcflow/sdk/data/parameters_spec_schema.yaml +7 -0
  116. hpcflow/sdk/data/task_schema_spec_schema.yaml +3 -0
  117. hpcflow/sdk/data/workflow_spec_schema.yaml +22 -0
  118. hpcflow/sdk/demo/__init__.py +3 -0
  119. hpcflow/sdk/demo/cli.py +242 -0
  120. hpcflow/sdk/helper/__init__.py +3 -0
  121. hpcflow/sdk/helper/cli.py +137 -0
  122. hpcflow/sdk/helper/helper.py +300 -0
  123. hpcflow/sdk/helper/watcher.py +192 -0
  124. hpcflow/sdk/log.py +288 -0
  125. hpcflow/sdk/persistence/__init__.py +18 -0
  126. hpcflow/sdk/persistence/base.py +2817 -0
  127. hpcflow/sdk/persistence/defaults.py +6 -0
  128. hpcflow/sdk/persistence/discovery.py +39 -0
  129. hpcflow/sdk/persistence/json.py +954 -0
  130. hpcflow/sdk/persistence/pending.py +948 -0
  131. hpcflow/sdk/persistence/store_resource.py +203 -0
  132. hpcflow/sdk/persistence/types.py +309 -0
  133. hpcflow/sdk/persistence/utils.py +73 -0
  134. hpcflow/sdk/persistence/zarr.py +2388 -0
  135. hpcflow/sdk/runtime.py +320 -0
  136. hpcflow/sdk/submission/__init__.py +3 -0
  137. hpcflow/sdk/submission/enums.py +70 -0
  138. hpcflow/sdk/submission/jobscript.py +2379 -0
  139. hpcflow/sdk/submission/schedulers/__init__.py +281 -0
  140. hpcflow/sdk/submission/schedulers/direct.py +233 -0
  141. hpcflow/sdk/submission/schedulers/sge.py +376 -0
  142. hpcflow/sdk/submission/schedulers/slurm.py +598 -0
  143. hpcflow/sdk/submission/schedulers/utils.py +25 -0
  144. hpcflow/sdk/submission/shells/__init__.py +52 -0
  145. hpcflow/sdk/submission/shells/base.py +229 -0
  146. hpcflow/sdk/submission/shells/bash.py +504 -0
  147. hpcflow/sdk/submission/shells/os_version.py +115 -0
  148. hpcflow/sdk/submission/shells/powershell.py +352 -0
  149. hpcflow/sdk/submission/submission.py +1402 -0
  150. hpcflow/sdk/submission/types.py +140 -0
  151. hpcflow/sdk/typing.py +194 -0
  152. hpcflow/sdk/utils/arrays.py +69 -0
  153. hpcflow/sdk/utils/deferred_file.py +55 -0
  154. hpcflow/sdk/utils/hashing.py +16 -0
  155. hpcflow/sdk/utils/patches.py +31 -0
  156. hpcflow/sdk/utils/strings.py +69 -0
  157. hpcflow/tests/api/test_api.py +32 -0
  158. hpcflow/tests/conftest.py +123 -0
  159. hpcflow/tests/data/__init__.py +0 -0
  160. hpcflow/tests/data/benchmark_N_elements.yaml +6 -0
  161. hpcflow/tests/data/benchmark_script_runner.yaml +26 -0
  162. hpcflow/tests/data/multi_path_sequences.yaml +29 -0
  163. hpcflow/tests/data/workflow_1.json +10 -0
  164. hpcflow/tests/data/workflow_1.yaml +5 -0
  165. hpcflow/tests/data/workflow_1_slurm.yaml +8 -0
  166. hpcflow/tests/data/workflow_1_wsl.yaml +8 -0
  167. hpcflow/tests/data/workflow_test_run_abort.yaml +42 -0
  168. hpcflow/tests/jinja_templates/test_jinja_templates.py +161 -0
  169. hpcflow/tests/programs/test_programs.py +180 -0
  170. hpcflow/tests/schedulers/direct_linux/test_direct_linux_submission.py +12 -0
  171. hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
  172. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +14 -0
  173. hpcflow/tests/scripts/test_input_file_generators.py +282 -0
  174. hpcflow/tests/scripts/test_main_scripts.py +1361 -0
  175. hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
  176. hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
  177. hpcflow/tests/shells/wsl/test_wsl_submission.py +14 -0
  178. hpcflow/tests/unit/test_action.py +1066 -0
  179. hpcflow/tests/unit/test_action_rule.py +24 -0
  180. hpcflow/tests/unit/test_app.py +132 -0
  181. hpcflow/tests/unit/test_cache.py +46 -0
  182. hpcflow/tests/unit/test_cli.py +172 -0
  183. hpcflow/tests/unit/test_command.py +377 -0
  184. hpcflow/tests/unit/test_config.py +195 -0
  185. hpcflow/tests/unit/test_config_file.py +162 -0
  186. hpcflow/tests/unit/test_element.py +666 -0
  187. hpcflow/tests/unit/test_element_iteration.py +88 -0
  188. hpcflow/tests/unit/test_element_set.py +158 -0
  189. hpcflow/tests/unit/test_group.py +115 -0
  190. hpcflow/tests/unit/test_input_source.py +1479 -0
  191. hpcflow/tests/unit/test_input_value.py +398 -0
  192. hpcflow/tests/unit/test_jobscript_unit.py +757 -0
  193. hpcflow/tests/unit/test_json_like.py +1247 -0
  194. hpcflow/tests/unit/test_loop.py +2674 -0
  195. hpcflow/tests/unit/test_meta_task.py +325 -0
  196. hpcflow/tests/unit/test_multi_path_sequences.py +259 -0
  197. hpcflow/tests/unit/test_object_list.py +116 -0
  198. hpcflow/tests/unit/test_parameter.py +243 -0
  199. hpcflow/tests/unit/test_persistence.py +664 -0
  200. hpcflow/tests/unit/test_resources.py +243 -0
  201. hpcflow/tests/unit/test_run.py +286 -0
  202. hpcflow/tests/unit/test_run_directories.py +29 -0
  203. hpcflow/tests/unit/test_runtime.py +9 -0
  204. hpcflow/tests/unit/test_schema_input.py +372 -0
  205. hpcflow/tests/unit/test_shell.py +129 -0
  206. hpcflow/tests/unit/test_slurm.py +39 -0
  207. hpcflow/tests/unit/test_submission.py +502 -0
  208. hpcflow/tests/unit/test_task.py +2560 -0
  209. hpcflow/tests/unit/test_task_schema.py +182 -0
  210. hpcflow/tests/unit/test_utils.py +616 -0
  211. hpcflow/tests/unit/test_value_sequence.py +549 -0
  212. hpcflow/tests/unit/test_values.py +91 -0
  213. hpcflow/tests/unit/test_workflow.py +827 -0
  214. hpcflow/tests/unit/test_workflow_template.py +186 -0
  215. hpcflow/tests/unit/utils/test_arrays.py +40 -0
  216. hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
  217. hpcflow/tests/unit/utils/test_hashing.py +65 -0
  218. hpcflow/tests/unit/utils/test_patches.py +5 -0
  219. hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
  220. hpcflow/tests/unit/utils/test_strings.py +97 -0
  221. hpcflow/tests/workflows/__init__.py +0 -0
  222. hpcflow/tests/workflows/test_directory_structure.py +31 -0
  223. hpcflow/tests/workflows/test_jobscript.py +355 -0
  224. hpcflow/tests/workflows/test_run_status.py +198 -0
  225. hpcflow/tests/workflows/test_skip_downstream.py +696 -0
  226. hpcflow/tests/workflows/test_submission.py +140 -0
  227. hpcflow/tests/workflows/test_workflows.py +564 -0
  228. hpcflow/tests/workflows/test_zip.py +18 -0
  229. hpcflow/viz_demo.ipynb +6794 -0
  230. hpcflow-0.2.0a271.dist-info/LICENSE +375 -0
  231. hpcflow-0.2.0a271.dist-info/METADATA +65 -0
  232. hpcflow-0.2.0a271.dist-info/RECORD +237 -0
  233. {hpcflow-0.1.15.dist-info → hpcflow-0.2.0a271.dist-info}/WHEEL +4 -5
  234. hpcflow-0.2.0a271.dist-info/entry_points.txt +6 -0
  235. hpcflow/api.py +0 -490
  236. hpcflow/archive/archive.py +0 -307
  237. hpcflow/archive/cloud/cloud.py +0 -45
  238. hpcflow/archive/cloud/errors.py +0 -9
  239. hpcflow/archive/cloud/providers/dropbox.py +0 -427
  240. hpcflow/archive/errors.py +0 -5
  241. hpcflow/base_db.py +0 -4
  242. hpcflow/config.py +0 -233
  243. hpcflow/copytree.py +0 -66
  244. hpcflow/data/examples/_config.yml +0 -14
  245. hpcflow/data/examples/damask/demo/1.run.yml +0 -4
  246. hpcflow/data/examples/damask/demo/2.process.yml +0 -29
  247. hpcflow/data/examples/damask/demo/geom.geom +0 -2052
  248. hpcflow/data/examples/damask/demo/load.load +0 -1
  249. hpcflow/data/examples/damask/demo/material.config +0 -185
  250. hpcflow/data/examples/damask/inputs/geom.geom +0 -2052
  251. hpcflow/data/examples/damask/inputs/load.load +0 -1
  252. hpcflow/data/examples/damask/inputs/material.config +0 -185
  253. hpcflow/data/examples/damask/profiles/_variable_lookup.yml +0 -21
  254. hpcflow/data/examples/damask/profiles/damask.yml +0 -4
  255. hpcflow/data/examples/damask/profiles/damask_process.yml +0 -8
  256. hpcflow/data/examples/damask/profiles/damask_run.yml +0 -5
  257. hpcflow/data/examples/damask/profiles/default.yml +0 -6
  258. hpcflow/data/examples/thinking.yml +0 -177
  259. hpcflow/errors.py +0 -2
  260. hpcflow/init_db.py +0 -37
  261. hpcflow/models.py +0 -2595
  262. hpcflow/nesting.py +0 -9
  263. hpcflow/profiles.py +0 -455
  264. hpcflow/project.py +0 -81
  265. hpcflow/scheduler.py +0 -322
  266. hpcflow/utils.py +0 -103
  267. hpcflow/validation.py +0 -166
  268. hpcflow/variables.py +0 -543
  269. hpcflow-0.1.15.dist-info/METADATA +0 -168
  270. hpcflow-0.1.15.dist-info/RECORD +0 -45
  271. hpcflow-0.1.15.dist-info/entry_points.txt +0 -8
  272. hpcflow-0.1.15.dist-info/top_level.txt +0 -1
  273. /hpcflow/{archive → data/jinja_templates}/__init__.py +0 -0
  274. /hpcflow/{archive/cloud → data/programs}/__init__.py +0 -0
  275. /hpcflow/{archive/cloud/providers → data/workflows}/__init__.py +0 -0
hpcflow/sdk/core/loop.py
@@ -0,0 +1,1320 @@
+"""
+A looping construct for a workflow.
+There are multiple types of loop,
+notably looping over a set of values or until a condition holds.
+"""
+
+from __future__ import annotations
+
+from collections import defaultdict
+import copy
+from pprint import pp
+import pprint
+from typing import Dict, List, Optional, Tuple, Union, Any
+from warnings import warn
+from collections import defaultdict
+from itertools import chain
+from typing import cast, TYPE_CHECKING
+from typing_extensions import override
+
+from hpcflow.sdk.core.app_aware import AppAware
+from hpcflow.sdk.core.actions import EARStatus
+from hpcflow.sdk.core.skip_reason import SkipReason
+from hpcflow.sdk.core.errors import LoopTaskSubsetError
+from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
+from hpcflow.sdk.core.loop_cache import LoopCache, LoopIndex
+from hpcflow.sdk.core.enums import InputSourceType, TaskSourceType
+from hpcflow.sdk.core.utils import check_valid_py_identifier, nth_key, nth_value
+from hpcflow.sdk.utils.strings import shorten_list_str
+from hpcflow.sdk.log import TimeIt
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator, Mapping, Sequence
+    from typing import Any, ClassVar
+    from typing_extensions import Self, TypeIs
+    from rich.status import Status
+    from ..typing import DataIndex, ParamSource
+    from .parameters import SchemaInput, InputSource
+    from .rule import Rule
+    from .task import WorkflowTask
+    from .types import IterableParam
+    from .workflow import Workflow, WorkflowTemplate
+
+
+# @dataclass
+# class StoppingCriterion:
+#     parameter: Parameter
+#     condition: ConditionLike
+
+
+# @dataclass
+# class Loop:
+#     parameter: Parameter
+#     stopping_criteria: StoppingCriterion
+#     # TODO: should be a logical combination of these (maybe provide a superclass in valida to re-use some logic there?)
+#     maximum_iterations: int
+
+
+class Loop(JSONLike):
+    """
+    A loop in a workflow template.
+
+    Parameters
+    ----------
+    tasks: list[int | ~hpcflow.app.WorkflowTask]
+        List of task insert IDs or workflow tasks.
+    num_iterations:
+        Number of iterations to perform.
+    name: str
+        Loop name.
+    non_iterable_parameters: list[str]
+        Specify input parameters that should not iterate.
+    termination: ~hpcflow.app.Rule
+        Stopping criterion, expressed as a rule.
+    termination_task: int | ~hpcflow.app.WorkflowTask
+        Task at which to evaluate the termination condition.
+    """
+
+    _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
+        ChildObjectSpec(name="termination", class_name="Rule"),
+    )
+
+    @classmethod
+    def __is_WorkflowTask(cls, value) -> TypeIs[WorkflowTask]:
+        return isinstance(value, cls._app.WorkflowTask)
+
+    def __init__(
+        self,
+        tasks: Iterable[int | str | WorkflowTask],
+        num_iterations: int,
+        name: str | None = None,
+        non_iterable_parameters: list[str] | None = None,
+        termination: Rule | None = None,
+        termination_task: int | str | WorkflowTask | None = None,
+    ) -> None:
+        _task_refs: list[int | str] = [
+            ref_i.insert_ID if self.__is_WorkflowTask(ref_i) else ref_i for ref_i in tasks
+        ]
+
+        if termination_task is None:
+            _term_task_ref = _task_refs[-1]  # terminate on final task by default
+        elif self.__is_WorkflowTask(termination_task):
+            _term_task_ref = termination_task.insert_ID
+        else:
+            _term_task_ref = termination_task
+
+        if _term_task_ref not in _task_refs:
+            raise ValueError(
+                f"If specified, `termination_task` (provided: {termination_task!r}) must "
+                f"refer to a task that is part of the loop. Available task references "
+                f"are: {_task_refs!r}."
+            )
+
+        self._task_refs = _task_refs
+        self._task_insert_IDs = (
+            cast("list[int]", _task_refs)
+            if all(isinstance(ref_i, int) for ref_i in _task_refs)
+            else None
+        )
+
+        self._num_iterations = num_iterations
+        self._name = check_valid_py_identifier(name) if name else name
+        self._non_iterable_parameters = non_iterable_parameters or []
+        self._termination = termination
+        self._termination_task_ref = _term_task_ref
+        self._termination_task_insert_ID = (
+            _term_task_ref if isinstance(_term_task_ref, int) else None
+        )
+
+        self._workflow_template: WorkflowTemplate | None = (
+            None  # assigned by parent WorkflowTemplate
+        )
+
+    @override
+    def _postprocess_to_dict(self, d: dict[str, Any]) -> dict[str, Any]:
+        out = super()._postprocess_to_dict(d)
+        return {k.lstrip("_"): v for k, v in out.items()}
+
+    @classmethod
+    def _json_like_constructor(cls, json_like: dict) -> Self:
+        """Invoked by `JSONLike.from_json_like` instead of `__init__`."""
+        if "task_insert_IDs" in json_like:
+            insert_IDs = json_like.pop("task_insert_IDs")
+        else:
+            insert_IDs = json_like.pop("tasks")  # e.g. from YAML
+
+        if "termination_task_insert_ID" in json_like:
+            tt_iID = json_like.pop("termination_task_insert_ID")
+        elif "termination_task" in json_like:
+            tt_iID = json_like.pop("termination_task")  # e.g. from YAML
+        else:
+            tt_iID = None
+
+        task_refs = None
+        term_task_ref = None
+        if "task_refs" in json_like:
+            task_refs = json_like.pop("task_refs")
+        if "termination_task_ref" in json_like:
+            term_task_ref = json_like.pop("termination_task_ref")
+
+        obj = cls(tasks=insert_IDs, termination_task=tt_iID, **json_like)
+        if task_refs:
+            obj._task_refs = task_refs
+        if term_task_ref is not None:
+            obj._termination_task_ref = term_task_ref
+
+        return obj
+
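The constructor above accepts either the persisted keys (`task_insert_IDs`, `termination_task_insert_ID`) or the template-level keys (`tasks`, `termination_task`) as they appear in YAML/JSON workflow templates. A minimal sketch of the second path, assuming `hf` is the hpcflow application object (the loop and task names are hypothetical):

```python
from hpcflow.app import app as hf  # assumed entry point; adjust to your install

# Keys mirror `_json_like_constructor` above: "tasks" may hold task unique
# names (e.g. from YAML); "termination_task" defaults to the final task.
loop_spec = {
    "name": "fit",
    "tasks": ["simulate", "fit_params"],
    "num_iterations": 3,
    "termination_task": "fit_params",
}
loop = hf.Loop.from_json_like(loop_spec)
```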
+    @property
+    def task_refs(self) -> tuple[int | str, ...]:
+        """Get the list of task references (insert IDs or task unique names) that define
+        the extent of the loop."""
+        return tuple(self._task_refs)
+
+    @property
+    def task_insert_IDs(self) -> tuple[int, ...]:
+        """Get the list of task insert_IDs that define the extent of the loop."""
+        assert self._task_insert_IDs
+        return tuple(self._task_insert_IDs)
+
+    @property
+    def name(self) -> str | None:
+        """
+        The name of the loop, if one was provided.
+        """
+        return self._name
+
+    @property
+    def num_iterations(self) -> int:
+        """
+        The number of loop iterations to do.
+        """
+        return self._num_iterations
+
+    @property
+    def non_iterable_parameters(self) -> Sequence[str]:
+        """
+        Which parameters are not iterable.
+        """
+        return self._non_iterable_parameters
+
+    @property
+    def termination(self) -> Rule | None:
+        """
+        A termination rule for the loop, if one is provided.
+        """
+        return self._termination
+
+    @property
+    def termination_task_ref(self) -> int | str:
+        """
+        The reference (insert ID or unique name) of the task at which the loop will
+        terminate.
+        """
+        return self._termination_task_ref
+
+    @property
+    def termination_task_insert_ID(self) -> int:
+        """
+        The insert ID of the task at which the loop will terminate.
+        """
+        assert self._termination_task_insert_ID is not None
+        return self._termination_task_insert_ID
+
+    @property
+    def termination_task(self) -> WorkflowTask:
+        """
+        The task at which the loop will terminate.
+        """
+        if (wt := self.workflow_template) is None:
+            raise RuntimeError(
+                "Workflow template must be assigned to retrieve task objects of the loop."
+            )
+        assert wt.workflow
+        return wt.workflow.tasks.get(insert_ID=self.termination_task_insert_ID)
+
+    @property
+    def workflow_template(self) -> WorkflowTemplate | None:
+        """
+        The workflow template that contains this loop.
+        """
+        return self._workflow_template
+
+    @workflow_template.setter
+    def workflow_template(self, template: WorkflowTemplate):
+        self._workflow_template = template
+
+    def __workflow(self) -> None | Workflow:
+        if (wt := self.workflow_template) is None:
+            return None
+        return wt.workflow
+
+    @property
+    def task_objects(self) -> tuple[WorkflowTask, ...]:
+        """
+        The tasks in the loop.
+        """
+        if not (wf := self.__workflow()):
+            raise RuntimeError(
+                "Workflow template must be assigned to retrieve task objects of the loop."
+            )
+        return tuple(wf.tasks.get(insert_ID=t_id) for t_id in self.task_insert_IDs)
+
+    def _validate_against_workflow(self, workflow: Workflow) -> None:
+        """Validate the loop parameters against the associated workflow."""
+
+        names = workflow.get_task_unique_names(map_to_insert_ID=True)
+        if self._task_insert_IDs is None:
+            err = lambda ref_i: (
+                f"Loop {self.name + ' ' if self.name else ''}has an invalid "
+                f"task reference {ref_i!r}. Such a task does not exist in "
+                f"the associated workflow, which has task names/insert IDs: "
+                f"{names!r}."
+            )
+            self._task_insert_IDs = []
+            for ref_i in self._task_refs:
+                if isinstance(ref_i, str):
+                    # reference is task unique name, so retrieve the insert ID:
+                    try:
+                        iID = names[ref_i]
+                    except KeyError:
+                        raise ValueError(err(ref_i)) from None
+                else:
+                    try:
+                        workflow.tasks.get(insert_ID=ref_i)
+                    except ValueError:
+                        raise ValueError(err(ref_i)) from None
+                    iID = ref_i
+                self._task_insert_IDs.append(iID)
+
+        if self._termination_task_insert_ID is None:
+            tt_ref = self._termination_task_ref
+            if isinstance(tt_ref, str):
+                tt_iID = names[tt_ref]
+            else:
+                tt_iID = tt_ref
+            self._termination_task_insert_ID = tt_iID
+
+    def __repr__(self) -> str:
+        num_iterations_str = ""
+        if self.num_iterations is not None:
+            num_iterations_str = f", num_iterations={self.num_iterations!r}"
+
+        name_str = ""
+        if self.name:
+            name_str = f", name={self.name!r}"
+
+        return (
+            f"{self.__class__.__name__}("
+            f"task_insert_IDs={self.task_insert_IDs!r}{num_iterations_str}{name_str}"
+            f")"
+        )
+
+    def __deepcopy__(self, memo: dict[int, Any]) -> Self:
+        kwargs = self.to_dict()
+        kwargs["tasks"] = kwargs.pop("task_insert_IDs")
+        kwargs["termination_task"] = kwargs.pop("termination_task_insert_ID")
+        task_refs = kwargs.pop("task_refs", None)
+        term_task_ref = kwargs.pop("termination_task_ref", None)
+        obj = self.__class__(**copy.deepcopy(kwargs, memo))
+        obj._workflow_template = self._workflow_template
+        obj._task_refs = task_refs
+        obj._termination_task_ref = term_task_ref
+        return obj
+
+
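As the module docstring notes, a loop either runs a fixed number of iterations or runs until a condition holds. A hedged sketch of both constructions; the `hf` app object, task names, and the valida-style condition dict are assumptions based on the parameter documentation above:

```python
from hpcflow.app import app as hf  # assumed entry point; adjust to your install

# Fixed iteration count over a single (hypothetical) task:
fixed = hf.Loop(name="outer", tasks=["simulate"], num_iterations=5)

# Conditional loop: iterate until the termination rule passes (evaluated at
# `termination_task`), or until the iteration budget is exhausted.
conditional = hf.Loop(
    name="converge",
    tasks=["simulate", "check"],
    num_iterations=10,  # upper bound on iterations
    termination=hf.Rule(
        path="outputs.error",
        condition={"value.less_than": 1e-6},  # assumed valida-style condition
    ),
    termination_task="check",  # defaults to the final task if omitted
)
```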
+class WorkflowLoop(AppAware):
+    """
+    Class to represent a :py:class:`.Loop` that is bound to a
+    :py:class:`~hpcflow.app.Workflow`.
+
+    Parameters
+    ----------
+    index: int
+        The index of this loop in the workflow.
+    workflow: ~hpcflow.app.Workflow
+        The workflow containing this loop.
+    template: Loop
+        The loop that this was generated from.
+    num_added_iterations:
+        Description of what iterations have been added.
+    iterable_parameters:
+        Description of what parameters are being iterated over.
+    output_parameters:
+        Description of what parameters are output from this loop, and the final task
+        insert ID from which they are output.
+    parents: list[str]
+        The paths to the parent entities of this loop.
+    """
+
+    def __init__(
+        self,
+        index: int,
+        workflow: Workflow,
+        template: Loop,
+        num_added_iterations: dict[tuple[int, ...], int],
+        iterable_parameters: dict[str, IterableParam],
+        output_parameters: dict[str, int],
+        parents: list[str],
+    ) -> None:
+        self._index = index
+        self._workflow = workflow
+        self._template = template
+        self._num_added_iterations = num_added_iterations
+        self._iterable_parameters = iterable_parameters
+        self._output_parameters = output_parameters
+        self._parents = parents
+
+        # appended to when adding an empty loop to the workflow that is a parent of this
+        # loop; reset and added to `self._parents` on dump to disk:
+        self._pending_parents: list[str] = []
+
+        # used for `num_added_iterations` when a new loop iteration is added, or when
+        # parents are appended to; reset to None on dump to disk. Each key is a tuple of
+        # parent loop indices and each value is the number of pending new iterations:
+        self._pending_num_added_iterations: dict[tuple[int, ...], int] | None = None
+
+        self._validate()
+
+    @TimeIt.decorator
+    def _validate(self) -> None:
+        # task subset must be a contiguous range of task indices:
+        task_indices = self.task_indices
+        task_min, task_max = task_indices[0], task_indices[-1]
+        if task_indices != tuple(range(task_min, task_max + 1)):
+            raise LoopTaskSubsetError(self.name, self.task_indices)
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(template={self.template!r}, "
+            f"num_added_iterations={self.num_added_iterations!r})"
+        )
+
+    @property
+    def num_added_iterations(self) -> Mapping[tuple[int, ...], int]:
+        """
+        The number of added iterations.
+        """
+        if self._pending_num_added_iterations:
+            return self._pending_num_added_iterations
+        else:
+            return self._num_added_iterations
+
+    @property
+    def __pending(self) -> dict[tuple[int, ...], int]:
+        if not self._pending_num_added_iterations:
+            self._pending_num_added_iterations = dict(self._num_added_iterations)
+        return self._pending_num_added_iterations
+
+    def _initialise_pending_added_iters(self, added_iters: Iterable[int]):
+        if not self._pending_num_added_iterations:
+            self._pending_num_added_iterations = dict(self._num_added_iterations)
+        if (added_iters_key := tuple(added_iters)) not in (pending := self.__pending):
+            pending[added_iters_key] = 1
+
+    def _increment_pending_added_iters(self, added_iters_key: Iterable[int]):
+        self.__pending[tuple(added_iters_key)] += 1
+
+    def _update_parents(self, parent: WorkflowLoop):
+        assert parent.name
+        self._pending_parents.append(parent.name)
+
+        self._pending_num_added_iterations = {
+            (*k, 0): v
+            for k, v in (
+                self._pending_num_added_iterations or self._num_added_iterations
+            ).items()
+        }
+
+        self.workflow._store.update_loop_parents(
+            index=self.index,
+            num_added_iters=self.num_added_iterations,
+            parents=self.parents,
+        )
+
+    def _reset_pending_num_added_iters(self) -> None:
+        self._pending_num_added_iterations = None
+
+    def _accept_pending_num_added_iters(self) -> None:
+        if self._pending_num_added_iterations:
+            self._num_added_iterations = dict(self._pending_num_added_iterations)
+            self._reset_pending_num_added_iters()
+
+    def _reset_pending_parents(self) -> None:
+        self._pending_parents = []
+
+    def _accept_pending_parents(self) -> None:
+        self._parents += self._pending_parents
+        self._reset_pending_parents()
+
+    @property
+    def index(self) -> int:
+        """
+        The index of this loop within its workflow.
+        """
+        return self._index
+
+    @property
+    def task_refs(self) -> tuple[int | str, ...]:
+        """
+        The list of task references (insert IDs or task unique names) that define the
+        extent of the loop."""
+        return self.template.task_refs
+
+    @property
+    def task_insert_IDs(self) -> tuple[int, ...]:
+        """
+        The insertion IDs of the tasks inside this loop.
+        """
+        return self.template.task_insert_IDs
+
+    @property
+    def task_objects(self) -> tuple[WorkflowTask, ...]:
+        """
+        The tasks in this loop.
+        """
+        return self.template.task_objects
+
+    @property
+    def task_indices(self) -> tuple[int, ...]:
+        """
+        The list of task indices that define the extent of the loop.
+        """
+        return tuple(task.index for task in self.task_objects)
+
+    @property
+    def workflow(self) -> Workflow:
+        """
+        The workflow containing this loop.
+        """
+        return self._workflow
+
+    @property
+    def template(self) -> Loop:
+        """
+        The loop template for this loop.
+        """
+        return self._template
+
+    @property
+    def parents(self) -> Sequence[str]:
+        """
+        The parents of this loop.
+        """
+        return self._parents + self._pending_parents
+
+    @property
+    def name(self) -> str:
+        """
+        The name of this loop, if one is defined.
+        """
+        assert self.template.name
+        return self.template.name
+
+    @property
+    def iterable_parameters(self) -> dict[str, IterableParam]:
+        """
+        The parameters that are being iterated over.
+        """
+        return self._iterable_parameters
+
+    @property
+    def output_parameters(self) -> dict[str, int]:
+        """
+        The parameters that are outputs of this loop, and the final task insert ID from
+        which each parameter is output.
+        """
+        return self._output_parameters
+
+    @property
+    def num_iterations(self) -> int:
+        """
+        The number of iterations.
+        """
+        return self.template.num_iterations
+
+    @property
+    def downstream_tasks(self) -> Iterator[WorkflowTask]:
+        """Tasks that are not part of the loop, and downstream from this loop."""
+        tasks = self.workflow.tasks
+        for idx in range(self.task_objects[-1].index + 1, len(tasks)):
+            yield tasks[idx]
+
+    @property
+    def upstream_tasks(self) -> Iterator[WorkflowTask]:
+        """Tasks that are not part of the loop, and upstream from this loop."""
+        tasks = self.workflow.tasks
+        for idx in range(0, self.task_objects[0].index):
+            yield tasks[idx]
+
+    @staticmethod
+    @TimeIt.decorator
+    def _find_iterable_and_output_parameters(
+        loop_template: Loop,
+    ) -> tuple[dict[str, IterableParam], dict[str, int]]:
+        all_inputs_first_idx: dict[str, int] = {}
+        all_outputs_idx: dict[str, list[int]] = defaultdict(list)
+        for task in loop_template.task_objects:
+            for typ in task.template.all_schema_input_types:
+                all_inputs_first_idx.setdefault(typ, task.insert_ID)
+            for typ in task.template.all_schema_output_types:
+                all_outputs_idx[typ].append(task.insert_ID)
+
+        # find input parameters that are also output parameters at a later/same task:
+        iterable_params: dict[str, IterableParam] = {}
+        for typ, first_idx in all_inputs_first_idx.items():
+            if typ in all_outputs_idx and first_idx <= all_outputs_idx[typ][0]:
+                iterable_params[typ] = {
+                    "input_task": first_idx,
+                    "output_tasks": all_outputs_idx[typ],
+                }
+
+        for non_iter in loop_template.non_iterable_parameters:
+            iterable_params.pop(non_iter, None)
+
+        final_out_tasks = {k: v[-1] for k, v in all_outputs_idx.items()}
+
+        return iterable_params, final_out_tasks
+
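The scan above marks a parameter as iterable when the first task that consumes it as an input comes at or before the first task that produces it as an output, i.e. the value is fed back around the loop. A self-contained restatement of that bookkeeping with plain dicts (the three task schemas are hypothetical):

```python
from collections import defaultdict

# Hypothetical loop of three tasks: (insert ID, input types, output types).
tasks = [
    (0, {"p1"}, {"p2"}),  # consumes p1, produces p2
    (1, {"p2"}, {"p3"}),  # feeds p2 forward, not back
    (2, {"p3"}, {"p1"}),  # closes the cycle on p1
]

first_input: dict[str, int] = {}
outputs_idx: dict[str, list[int]] = defaultdict(list)
for iID, inps, outs in tasks:
    for typ in inps:
        first_input.setdefault(typ, iID)
    for typ in outs:
        outputs_idx[typ].append(iID)

iterable = {
    typ: {"input_task": first, "output_tasks": outputs_idx[typ]}
    for typ, first in first_input.items()
    if typ in outputs_idx and first <= outputs_idx[typ][0]
}
# Only p1 qualifies: p2 and p3 are produced before they are first consumed.
print(iterable)  # {'p1': {'input_task': 0, 'output_tasks': [2]}}
```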
+    @classmethod
+    @TimeIt.decorator
+    def new_empty_loop(
+        cls,
+        index: int,
+        workflow: Workflow,
+        template: Loop,
+        iter_loop_idx: Sequence[Mapping[str, int]],
+    ) -> WorkflowLoop:
+        """
+        Make a new empty loop.
+
+        Parameters
+        ----------
+        index: int
+            The index of the loop to create.
+        workflow: ~hpcflow.app.Workflow
+            The workflow that will contain the loop.
+        template: Loop
+            The template for the loop.
+        iter_loop_idx: list[dict]
+            Iteration information from parent loops.
+        """
+        parent_loops = cls._get_parent_loops(index, workflow, template)
+        parent_names = [i.name for i in parent_loops]
+        num_added_iters: dict[tuple[int, ...], int] = {}
+        for i in iter_loop_idx:
+            num_added_iters[tuple([i[j] for j in parent_names])] = 1
+
+        iter_params, out_params = cls._find_iterable_and_output_parameters(template)
+        return cls(
+            index=index,
+            workflow=workflow,
+            template=template,
+            num_added_iterations=num_added_iters,
+            iterable_parameters=iter_params,
+            output_parameters=out_params,
+            parents=parent_names,
+        )
+
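`num_added_iterations` is keyed by a tuple of parent-loop iteration indices, so a new empty loop starts with one added iteration per existing combination of parent iterations. A small sketch of that key construction (loop names are hypothetical):

```python
# Parent loop "outer" has two iterations; the new inner loop exists once
# under each of them.
parent_names = ["outer"]
iter_loop_idx = [{"outer": 0}, {"outer": 1}]

num_added_iters: dict[tuple[int, ...], int] = {}
for i in iter_loop_idx:
    num_added_iters[tuple(i[j] for j in parent_names)] = 1

print(num_added_iters)  # {(0,): 1, (1,): 1}
```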
+    @classmethod
+    @TimeIt.decorator
+    def _get_parent_loops(
+        cls,
+        index: int,
+        workflow: Workflow,
+        template: Loop,
+    ) -> list[WorkflowLoop]:
+        parents: list[WorkflowLoop] = []
+        passed_self = False
+        self_tasks = set(template.task_insert_IDs)
+        for loop_i in workflow.loops:
+            if loop_i.index == index:
+                passed_self = True
+                continue
+            other_tasks = set(loop_i.task_insert_IDs)
+            if self_tasks.issubset(other_tasks):
+                if (self_tasks == other_tasks) and not passed_self:
+                    continue
+                parents.append(loop_i)
+        return parents
+
+    @TimeIt.decorator
+    def get_parent_loops(self) -> list[WorkflowLoop]:
+        """Get loops whose task subset is a superset of this loop's task subset. If two
+        loops have identical task subsets, the first loop in the workflow loop list is
+        considered the child."""
+        return self._get_parent_loops(self.index, self.workflow, self.template)
+
+    @TimeIt.decorator
+    def get_child_loops(self) -> list[WorkflowLoop]:
+        """Get loops whose task subset is a subset of this loop's task subset. If two
+        loops have identical task subsets, the first loop in the workflow loop list is
+        considered the child."""
+        children: list[WorkflowLoop] = []
+        passed_self = False
+        self_tasks = set(self.task_insert_IDs)
+        for loop_i in self.workflow.loops:
+            if loop_i.index == self.index:
+                passed_self = True
+                continue
+            other_tasks = set(loop_i.task_insert_IDs)
+            if self_tasks.issuperset(other_tasks):
+                if (self_tasks == other_tasks) and passed_self:
+                    continue
+                children.append(loop_i)
+
+        # order by depth, so direct child is first:
+        return sorted(children, key=lambda x: len(next(iter(x.num_added_iterations))))
+
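Parent/child relations are purely set relations on the loops' task-insert-ID subsets, with position in the workflow's loop list breaking ties for identical subsets. A standalone sketch of the child-finding rule (the loop definitions are hypothetical, and the depth-sort of the real method is omitted):

```python
# Each entry is (loop index, task insert IDs); loops 1 and 2 share a subset.
loops = [(0, {1, 2, 3}), (1, {2, 3}), (2, {2, 3})]

def children_of(self_idx: int) -> list[int]:
    self_tasks = dict(loops)[self_idx]
    passed_self = False
    out = []
    for idx, tasks in loops:
        if idx == self_idx:
            passed_self = True
            continue
        if self_tasks.issuperset(tasks):
            # an identical subset appearing after `self` is a parent, not a child:
            if self_tasks == tasks and passed_self:
                continue
            out.append(idx)
    return out

print(children_of(0))  # [1, 2]: both are proper subsets of loop 0's tasks
print(children_of(2))  # [1]: loop 1 precedes loop 2, so loop 1 is the child
```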
+    @TimeIt.decorator
+    def add_iteration(
+        self,
+        parent_loop_indices: Mapping[str, int] | None = None,
+        cache: LoopCache | None = None,
+        status: Status | None = None,
+    ) -> None:
+        """
+        Add an iteration to this loop.
+
+        Parameters
+        ----------
+        parent_loop_indices:
+            The current iteration indices of any parent loops.
+        cache:
+            A cache used to make adding the iteration more efficient.
+            One will be created if it is not supplied.
+        status:
+            An optional status indicator to update with progress.
+        """
+        if not cache:
+            cache = LoopCache.build(self.workflow)
+        assert cache is not None
+        parent_loops = self.get_parent_loops()
+        child_loops = self.get_child_loops()
+        parent_loop_indices_ = parent_loop_indices or {
+            loop.name: 0 for loop in parent_loops
+        }
+
+        iters_key = tuple(parent_loop_indices_[p_nm] for p_nm in self.parents)
+        cur_loop_idx = self.num_added_iterations[iters_key] - 1
+
+        # keys are (task.insert_ID, element.index):
+        all_new_data_idx: dict[tuple[int, int], DataIndex] = {}
+
+        # initialise a new `num_added_iterations` key on each child loop:
+        iters_key_dct = {
+            **parent_loop_indices_,
+            self.name: cur_loop_idx + 1,
+        }
+        for child in child_loops:
+            child._initialise_pending_added_iters(
+                iters_key_dct.get(j, 0) for j in child.parents
+            )
+
+            # needed for the case where an inner loop has only one iteration, meaning
+            # `add_iteration` will not be called recursively on it:
+            self.workflow._store.update_loop_num_iters(
+                index=child.index,
+                num_added_iters=child.num_added_iterations,
+            )
+
+        for task in self.task_objects:
+            new_loop_idx = LoopIndex(iters_key_dct) + {
+                child.name: 0
+                for child in child_loops
+                if task.insert_ID in child.task_insert_IDs
+            }
+            added_iter_IDs: list[int] = []
+            for elem_idx in range(task.num_elements):
+                elem_ID = task.element_IDs[elem_idx]
+
+                new_data_idx: DataIndex = {}
+
+                # copy resources from zeroth iteration:
+                zeroth_iter_ID, zi_iter_data_idx = cache.zeroth_iters[elem_ID]
+                zi_elem_ID, zi_idx = cache.iterations[zeroth_iter_ID]
+                zi_data_idx = nth_value(cache.data_idx[zi_elem_ID], zi_idx)
+
+                for key, val in zi_data_idx.items():
+                    if key.startswith("resources."):
+                        new_data_idx[key] = val
+
+                for inp in task.template.all_schema_inputs:
+                    is_inp_task = False
+                    if iter_dat := self.iterable_parameters.get(inp.typ):
+                        is_inp_task = task.insert_ID == iter_dat["input_task"]
+
+                    inp_key = f"inputs.{inp.typ}"
+
+                    if is_inp_task:
+                        assert iter_dat is not None
+                        inp_dat_idx = self.__get_looped_index(
+                            task,
+                            elem_ID,
+                            cache,
+                            iter_dat,
+                            inp,
+                            parent_loops,
+                            parent_loop_indices_,
+                            child_loops,
+                            cur_loop_idx,
+                        )
+                        new_data_idx[inp_key] = inp_dat_idx
+                    else:
+                        orig_inp_src = cache.elements[elem_ID]["input_sources"][inp_key]
+                        inp_dat_idx = None
+
+                        if orig_inp_src.source_type is InputSourceType.LOCAL:
+                            # keep locally defined inputs from original element
+                            inp_dat_idx = zi_data_idx[inp_key]
+
+                        elif orig_inp_src.source_type is InputSourceType.DEFAULT:
+                            # keep default value from original element
+                            try:
+                                inp_dat_idx = zi_data_idx[inp_key]
+                            except KeyError:
+                                # if this input is required by a conditional action, and
+                                # that condition is not met, then this input will not
+                                # exist in the action-run data index, so use the initial
+                                # iteration data index:
+                                inp_dat_idx = zi_iter_data_idx[inp_key]
+
+                        elif orig_inp_src.source_type is InputSourceType.TASK:
+                            inp_dat_idx = self.__get_task_index(
+                                task,
+                                orig_inp_src,
+                                cache,
+                                elem_ID,
+                                inp,
+                                inp_key,
+                                parent_loop_indices_,
+                                all_new_data_idx,
+                            )
+
+                        if inp_dat_idx is None:
+                            raise RuntimeError(
+                                f"Could not find a source for parameter {inp.typ} "
+                                f"when adding a new iteration for task {task!r}."
+                            )
+
+                        new_data_idx[inp_key] = inp_dat_idx
+
+                # add any locally defined sub-parameters:
+                inp_statuses = cache.elements[elem_ID]["input_statuses"]
+                inp_status_inps = set(f"inputs.{inp}" for inp in inp_statuses)
+                for sub_param_i in inp_status_inps.difference(new_data_idx):
+                    sub_param_data_idx_iter_0 = zi_data_idx
+                    try:
+                        sub_param_data_idx = sub_param_data_idx_iter_0[sub_param_i]
+                    except KeyError:
+                        # as before, if this input is required by a conditional action,
+                        # and that condition is not met, then this input will not exist
+                        # in the action-run data index, so use the initial iteration
+                        # data index:
+                        sub_param_data_idx = zi_iter_data_idx[sub_param_i]
+
+                    new_data_idx[sub_param_i] = sub_param_data_idx
+
+                for out in task.template.all_schema_outputs:
+                    path_i = f"outputs.{out.typ}"
+                    p_src: ParamSource = {"type": "EAR_output"}
+                    new_data_idx[path_i] = self.workflow._add_unset_parameter_data(p_src)
+
+                schema_params = set(i for i in new_data_idx if len(i.split(".")) == 2)
+                all_new_data_idx[task.insert_ID, elem_idx] = new_data_idx
+
+                iter_ID_i = self.workflow._store.add_element_iteration(
+                    element_ID=elem_ID,
+                    data_idx=new_data_idx,
+                    schema_parameters=list(schema_params),
+                    loop_idx=new_loop_idx,
+                )
+                if cache:
+                    cache.add_iteration(
+                        iter_ID=iter_ID_i,
+                        task_insert_ID=task.insert_ID,
+                        element_ID=elem_ID,
+                        loop_idx=new_loop_idx,
+                        data_idx=new_data_idx,
+                    )
+
+                added_iter_IDs.append(iter_ID_i)
+
+            task.initialise_EARs(iter_IDs=added_iter_IDs)
+
+        self._increment_pending_added_iters(
+            parent_loop_indices_[p_nm] for p_nm in self.parents
+        )
+        self.workflow._store.update_loop_num_iters(
+            index=self.index,
+            num_added_iters=self.num_added_iterations,
+        )
+
+        # add iterations to fixed-number-iteration children only:
+        for child in child_loops[::-1]:
+            if child.num_iterations is not None:
+                if status:
+                    status_prev = str(status.status).rstrip(".")
+                for iter_idx in range(child.num_iterations - 1):
+                    if status:
+                        status.update(
+                            f"{status_prev} --> ({child.name!r}): iteration "
+                            f"{iter_idx + 2}/{child.num_iterations}."
+                        )
+                    par_idx = {parent_name: 0 for parent_name in child.parents}
+                    if parent_loop_indices:
+                        par_idx.update(parent_loop_indices)
+                    par_idx[self.name] = cur_loop_idx + 1
+                    child.add_iteration(parent_loop_indices=par_idx, cache=cache)
+
+        self.__update_loop_downstream_data_idx(parent_loop_indices_)
+
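The bookkeeping pivot in `add_iteration` is `iters_key`: the tuple of current parent-loop iteration indices under which this loop's own iteration count is tracked and later incremented. A toy restatement of that lookup (the counts and loop names are hypothetical):

```python
# This loop has one parent, "outer"; it has run twice under outer-iteration 0
# and once under outer-iteration 1.
parents = ["outer"]
num_added_iterations = {(0,): 2, (1,): 1}

parent_loop_indices = {"outer": 1}
iters_key = tuple(parent_loop_indices[p] for p in parents)
cur_loop_idx = num_added_iterations[iters_key] - 1  # index of latest iteration: 0

# after the new iteration's elements are initialised, the count is bumped:
num_added_iterations[iters_key] += 1
print(cur_loop_idx, num_added_iterations)  # 0 {(0,): 2, (1,): 2}
```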
+    def __get_src_ID_and_groups(
+        self,
+        elem_ID: int,
+        iter_dat: IterableParam,
+        inp: SchemaInput,
+        cache: LoopCache,
+        task: WorkflowTask,
+    ) -> tuple[int, Sequence[int]]:
+        # `cache.elements` contains only elements that are part of the
+        # loop, so indexing a dependent element may raise:
+        src_elem_IDs = {}
+        for k, v in cache.element_dependents[elem_ID].items():
+            try:
+                if cache.elements[k]["task_insert_ID"] == iter_dat["output_tasks"][-1]:
+                    src_elem_IDs[k] = v
+            except KeyError:
+                continue
+
+        # consider groups
+        single_data = inp.single_labelled_data
+        assert single_data is not None
+        inp_group_name = single_data.get("group")
+        grouped_elems = [
+            src_elem_j_ID
+            for src_elem_j_ID, src_elem_j_dat in src_elem_IDs.items()
+            if any(nm == inp_group_name for nm in src_elem_j_dat["group_names"])
+        ]
+
+        if not grouped_elems and len(src_elem_IDs) > 1:
+            raise NotImplementedError(
+                f"Multiple elements found in the iterable parameter "
+                f"{inp!r}'s latest output task (insert ID: "
+                f"{iter_dat['output_tasks'][-1]}) that can be used "
+                f"to parametrise the next iteration of task "
+                f"{task.unique_name!r}: "
+                f"{list(src_elem_IDs)!r}."
+            )
+
+        elif not src_elem_IDs:
+            # TODO: maybe OK?
+            raise NotImplementedError(
+                f"No elements found in the iterable parameter "
+                f"{inp!r}'s latest output task (insert ID: "
+                f"{iter_dat['output_tasks'][-1]}) that can be used "
+                f"to parametrise the next iteration."
+            )
+
+        return nth_key(src_elem_IDs, 0), grouped_elems
+
+    def __get_looped_index(
+        self,
+        task: WorkflowTask,
+        elem_ID: int,
+        cache: LoopCache,
+        iter_dat: IterableParam,
+        inp: SchemaInput,
+        parent_loops: list[WorkflowLoop],
+        parent_loop_indices: Mapping[str, int],
+        child_loops: list[WorkflowLoop],
+        cur_loop_idx: int,
+    ):
+        # source from final output task of previous iteration, with all parent
+        # loop indices the same as previous iteration, and all child loop indices
+        # maximised:
+
+        # identify element(s) from which this iterable input should be
+        # parametrised:
+        if task.insert_ID == iter_dat["output_tasks"][-1]:
+            # single-task loop
+            src_elem_ID = elem_ID
+            grouped_elems: Sequence[int] = []
+        else:
+            # multi-task loop
+            src_elem_ID, grouped_elems = self.__get_src_ID_and_groups(
+                elem_ID, iter_dat, inp, cache, task
+            )
+
+        child_loop_max_iters: dict[str, int] = {}
+        parent_loop_same_iters = {
+            loop.name: parent_loop_indices[loop.name] for loop in parent_loops
+        }
+        child_iter_parents = {
+            **parent_loop_same_iters,
+            self.name: cur_loop_idx,
+        }
+        for loop in child_loops:
+            if iter_dat["output_tasks"][-1] in loop.task_insert_IDs:
+                i_num_iters = loop.num_added_iterations[
+                    tuple(child_iter_parents[j] for j in loop.parents)
+                ]
+                i_max = i_num_iters - 1
+                child_iter_parents[loop.name] = i_max
+                child_loop_max_iters[loop.name] = i_max
+
+        loop_idx_key = LoopIndex(child_loop_max_iters)
+        loop_idx_key.update(parent_loop_same_iters)
+        loop_idx_key[self.name] = cur_loop_idx
+
+        # identify the ElementIteration from which this input should be
+        # parametrised:
+        if grouped_elems:
+            src_data_idx = [
+                cache.data_idx[src_elem_ID][loop_idx_key] for src_elem_ID in grouped_elems
+            ]
+            if not src_data_idx:
+                raise RuntimeError(
+                    f"Could not find a source iteration with loop_idx: "
+                    f"{loop_idx_key!r}."
+                )
+            return [i[f"outputs.{inp.typ}"] for i in src_data_idx]
+        else:
+            return cache.data_idx[src_elem_ID][loop_idx_key][f"outputs.{inp.typ}"]
+
+    def __get_task_index(
+        self,
+        task: WorkflowTask,
+        orig_inp_src: InputSource,
+        cache: LoopCache,
+        elem_ID: int,
+        inp: SchemaInput,
+        inp_key: str,
+        parent_loop_indices: Mapping[str, int],
+        all_new_data_idx: Mapping[tuple[int, int], DataIndex],
+    ) -> int | list[int]:
+        if orig_inp_src.task_ref not in self.task_insert_IDs:
+            # source the data_idx from the iteration with same parent
+            # loop indices as the new iteration to add:
+            src_data_idx = next(
+                di_k
+                for li_k, di_k in cache.data_idx[elem_ID].items()
+                if all(li_k.get(p_k) == p_v for p_k, p_v in parent_loop_indices.items())
+            )
+
+            # could be multiple, but they should all have the same
+            # data index for this parameter:
+            return src_data_idx[inp_key]
+
+        is_group = (
+            inp.single_labelled_data is not None
+            and "group" in inp.single_labelled_data
+            # this input is a group, assume for now all elements
+        )
+
+        # same task/element, but update iteration to the just-added
+        # iteration:
+        assert orig_inp_src.task_source_type is not None
+        key_prefix = orig_inp_src.task_source_type.name.lower()
+        prev_dat_idx_key = f"{key_prefix}s.{inp.typ}"
+        new_sources: list[tuple[int, int]] = []
+        for (tiID, e_idx), _ in all_new_data_idx.items():
+            if tiID == orig_inp_src.task_ref:
+                # find which element in that task `element`
+                # depends on:
+                src_elem_IDs = cache.element_dependents[
+                    self.workflow.tasks.get(insert_ID=tiID).element_IDs[e_idx]
+                ]
+                # `cache.elements` contains only elements that are part of the loop, so
+                # indexing a dependent element may raise:
+                src_elem_IDs_i = []
+                for k, _v in src_elem_IDs.items():
+                    try:
+                        if (
+                            cache.elements[k]["task_insert_ID"] == task.insert_ID
+                            and k == elem_ID
+                            # filter src_elem_IDs_i for matching element IDs
+                        ):
+                            src_elem_IDs_i.append(k)
+                    except KeyError:
+                        continue
+
+                if len(src_elem_IDs_i) == 1:
+                    new_sources.append((tiID, e_idx))
+
+        if is_group:
+            # Convert into simple list of indices
+            return list(
+                chain.from_iterable(
+                    self.__as_sequence(all_new_data_idx[src][prev_dat_idx_key])
+                    for src in new_sources
+                )
+            )
+        else:
+            assert len(new_sources) == 1
+            return all_new_data_idx[new_sources[0]][prev_dat_idx_key]
+
+    @staticmethod
+    def __as_sequence(seq: int | Iterable[int]) -> Iterable[int]:
+        if isinstance(seq, int):
+            yield seq
+        else:
+            yield from seq
+
+    def __update_loop_downstream_data_idx(
+        self,
+        parent_loop_indices: Mapping[str, int],
+    ):
+        # update data indices of loop-downstream tasks that depend on task outputs from
+        # this loop:
+
+        # keys: iter or run ID, values: dict of param type and new parameter index
+        iter_new_data_idx: dict[int, DataIndex] = defaultdict(dict)
+        run_new_data_idx: dict[int, DataIndex] = defaultdict(dict)
+
+        param_sources = self.workflow.get_all_parameter_sources()
+
+        # keys are parameter type, then task insert ID, then data index keys mapping to
+        # their updated values:
+        all_updates: dict[str, dict[int, dict[int, int]]] = defaultdict(
+            lambda: defaultdict(dict)
+        )
+
+        for task in self.downstream_tasks:
+            for elem in task.elements:
+                for param_typ, param_out_task_iID in self.output_parameters.items():
+                    if param_typ in task.template.all_schema_input_types:
+                        # this element's input *might* need updating, only if it has a
+                        # task input source type that is this loop's output task for this
+                        # parameter:
+                        elem_src = elem.input_sources[f"inputs.{param_typ}"]
+                        if (
+                            elem_src.source_type is InputSourceType.TASK
+                            and elem_src.task_source_type is TaskSourceType.OUTPUT
+                            and elem_src.task_ref == param_out_task_iID
+                        ):
+                            for iter_i in elem.iterations:
+
+                                # do not modify element-iterations of previous iterations
+                                # of the current loop:
+                                skip_iter = False
+                                for k, v in parent_loop_indices.items():
+                                    if iter_i.loop_idx.get(k) != v:
+                                        skip_iter = True
+                                        break
+
+                                if skip_iter:
+                                    continue
+
+                                # update the iteration data index and any pending runs:
+                                iter_old_di = iter_i.data_idx[f"inputs.{param_typ}"]
+
+                                is_group = True
+                                if not isinstance(iter_old_di, list):
+                                    is_group = False
+                                    iter_old_di = [iter_old_di]
+
+                                iter_old_run_source = [
+                                    param_sources[i]["EAR_ID"] for i in iter_old_di
+                                ]
+                                iter_old_run_objs = self.workflow.get_EARs_from_IDs(
+                                    iter_old_run_source
+                                )  # TODO: use cache
+
+                                # need to check the run source is actually from the loop
+                                # output task (it could be from a previous iteration of a
+                                # separate loop in this task):
+                                if any(
+                                    i.task.insert_ID != param_out_task_iID
+                                    for i in iter_old_run_objs
+                                ):
+                                    continue
+
+                                iter_new_iters = [
+                                    i.element.iterations[-1] for i in iter_old_run_objs
+                                ]
+
+                                # note: we can cast to int, because output keys never
+                                # have multiple data indices (unlike input keys):
+                                iter_new_dis = [
+                                    cast("int", i.get_data_idx()[f"outputs.{param_typ}"])
+                                    for i in iter_new_iters
+                                ]
+
+                                # keep track of updates so we can also update task-input
+                                # type sources:
+                                all_updates[param_typ][task.insert_ID].update(
+                                    dict(zip(iter_old_di, iter_new_dis))
+                                )
+
+                                iter_new_data_idx[iter_i.id_][f"inputs.{param_typ}"] = (
+                                    iter_new_dis if is_group else iter_new_dis[0]
+                                )
+
+                                for run_j in iter_i.action_runs:
+                                    if run_j.status is EARStatus.pending:
+                                        try:
+                                            old_di = run_j.data_idx[f"inputs.{param_typ}"]
+                                        except KeyError:
+                                            # not all actions will include this input
+                                            continue
+
+                                        is_group = True
+                                        if not isinstance(old_di, list):
+                                            is_group = False
+                                            old_di = [old_di]
+
+                                        old_run_source = [
+                                            param_sources[i]["EAR_ID"] for i in old_di
+                                        ]
+                                        old_run_objs = self.workflow.get_EARs_from_IDs(
+                                            old_run_source
+                                        )  # TODO: use cache
+
+                                        # need to check the run source is actually from
+                                        # the loop output task (it could be from a
+                                        # previous action in this element-iteration):
+                                        if any(
+                                            i.task.insert_ID != param_out_task_iID
+                                            for i in old_run_objs
+                                        ):
+                                            continue
+
+                                        new_iters = [
+                                            i.element.iterations[-1] for i in old_run_objs
+                                        ]
+
+                                        # note: we can cast to int, because output keys
+                                        # never have multiple data indices (unlike input
+                                        # keys):
+                                        new_dis = [
+                                            cast(
+                                                "int",
+                                                i.get_data_idx()[f"outputs.{param_typ}"],
+                                            )
+                                            for i in new_iters
+                                        ]
+
+                                        run_new_data_idx[run_j.id_][
+                                            f"inputs.{param_typ}"
+                                        ] = (new_dis if is_group else new_dis[0])
+
+                        elif (
+                            elem_src.source_type is InputSourceType.TASK
+                            and elem_src.task_source_type is TaskSourceType.INPUT
+                        ):
+                            # parameters that are sourced from inputs of other tasks
+                            # might need to be updated if those other tasks have
+                            # themselves had their data indices updated:
+                            assert elem_src.task_ref
+                            ups_i = all_updates.get(param_typ, {}).get(elem_src.task_ref)
+
+                            if ups_i:
+                                # if a further-downstream task has a task-input source
+                                # that points to this task, this will also need updating:
+                                all_updates[param_typ][task.insert_ID].update(ups_i)
+
+                            else:
+                                continue
+
+                            for iter_i in elem.iterations:
+
+                                # update the iteration data index and any pending runs:
+                                iter_old_di = iter_i.data_idx[f"inputs.{param_typ}"]
+
+                                is_group = True
+                                if not isinstance(iter_old_di, list):
+                                    is_group = False
+                                    iter_old_di = [iter_old_di]
+
+                                iter_new_dis = [ups_i.get(i, i) for i in iter_old_di]
+
+                                if iter_new_dis != iter_old_di:
+                                    iter_new_data_idx[iter_i.id_][
+                                        f"inputs.{param_typ}"
+                                    ] = (iter_new_dis if is_group else iter_new_dis[0])
+
+                                for run_j in iter_i.action_runs:
+                                    if run_j.status is EARStatus.pending:
+                                        try:
+                                            old_di = run_j.data_idx[f"inputs.{param_typ}"]
+                                        except KeyError:
+                                            # not all actions will include this input
+                                            continue
+
+                                        is_group = True
+                                        if not isinstance(old_di, list):
+                                            is_group = False
+                                            old_di = [old_di]
+
+                                        new_dis = [ups_i.get(i, i) for i in old_di]
+
+                                        if new_dis != old_di:
+                                            run_new_data_idx[run_j.id_][
+                                                f"inputs.{param_typ}"
+                                            ] = (new_dis if is_group else new_dis[0])
+
+        # now update data indices (TODO: including in cache!)
+        if iter_new_data_idx:
+            self.workflow._store.update_iter_data_indices(iter_new_data_idx)
+
+        if run_new_data_idx:
+            self.workflow._store.update_run_data_indices(run_new_data_idx)
+
+    def test_termination(self, element_iter) -> bool:
+        """Check if a loop should terminate, given the specified completed element
+        iteration."""
+        if self.template.termination:
+            return self.template.termination.test(element_iter)
+        return False
+
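`test_termination` simply delegates to the template's rule, so whether the loop stops early depends on the rule evaluating truthy against the just-completed element iteration. A hedged sketch of the intended call pattern (the `workflow`, `elem_iter`, and loop-lookup names are hypothetical):

```python
# Sketch only: `workflow` is a bound Workflow and `elem_iter` the completed
# ElementIteration of the loop's termination task.
wf_loop = workflow.loops[0]  # the bound WorkflowLoop for our Loop template
if wf_loop.test_termination(elem_iter):
    # stop early: mark every remaining run of later iterations as skipped
    skipped_run_IDs = wf_loop.skip_downstream_iterations(elem_iter)
```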
+    @TimeIt.decorator
+    def get_element_IDs(self):
+        elem_IDs = [
+            j
+            for i in self.task_insert_IDs
+            for j in self.workflow.tasks.get(insert_ID=i).element_IDs
+        ]
+        return elem_IDs
+
+    @TimeIt.decorator
+    def get_elements(self):
+        return self.workflow.get_elements_from_IDs(self.get_element_IDs())
+
+    @TimeIt.decorator
+    def skip_downstream_iterations(self, elem_iter) -> list[int]:
+        """
+        Set all iterations downstream of the given element iteration to skip.
+
+        Parameters
+        ----------
+        elem_iter
+            The element iteration whose subsequent iterations should be skipped.
+        """
+        current_iter_idx = elem_iter.loop_idx[self.name]
+        current_task_iID = elem_iter.task.insert_ID
+        self._app.logger.info(
+            f"setting loop {self.name!r} iterations downstream of current iteration "
+            f"index {current_iter_idx} to skip"
+        )
+        elements = self.get_elements()
+
+        # TODO: fix for multiple loop cycles
+        warn(
+            "skip downstream iterations does not work correctly for multiple loop cycles!"
+        )
+
+        to_skip = []
+        for elem in elements:
+            for iter_i in elem.iterations:
+                if iter_i.loop_idx[self.name] > current_iter_idx or (
+                    iter_i.loop_idx[self.name] == current_iter_idx
+                    and iter_i.task.insert_ID > current_task_iID
+                ):
+                    to_skip.extend(iter_i.EAR_IDs_flat)
+
+        self._app.logger.info(
+            f"{len(to_skip)} runs will be set to skip: {shorten_list_str(to_skip)}"
+        )
+        self.workflow.set_EAR_skip({k: SkipReason.LOOP_TERMINATION for k in to_skip})
+
+        return to_skip
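The skip predicate above selects any iteration strictly beyond the current one in loop order: a higher loop iteration index, or the same index but a later task. A pure-Python restatement (the IDs are hypothetical):

```python
# (loop_idx, task_insert_ID) pairs for candidate iterations; the iteration
# that triggered termination is at loop index 1 in the task with insert ID 10.
current = (1, 10)
candidates = [(0, 11), (1, 10), (1, 11), (2, 9)]

to_skip = [
    c for c in candidates
    if c[0] > current[0] or (c[0] == current[0] and c[1] > current[1])
]
print(to_skip)  # [(1, 11), (2, 9)]: later task same iteration, and all later iterations
```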