hpcflow-new2 0.2.0a189__py3-none-any.whl → 0.2.0a199__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hpcflow/__pyinstaller/hook-hpcflow.py +9 -6
- hpcflow/_version.py +1 -1
- hpcflow/app.py +1 -0
- hpcflow/data/scripts/bad_script.py +2 -0
- hpcflow/data/scripts/do_nothing.py +2 -0
- hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
- hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
- hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
- hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
- hpcflow/data/scripts/input_file_generator_basic.py +3 -0
- hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
- hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
- hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
- hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
- hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
- hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
- hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
- hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
- hpcflow/data/scripts/output_file_parser_basic.py +3 -0
- hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
- hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
- hpcflow/data/scripts/script_exit_test.py +5 -0
- hpcflow/data/template_components/environments.yaml +1 -1
- hpcflow/sdk/__init__.py +26 -15
- hpcflow/sdk/app.py +2192 -768
- hpcflow/sdk/cli.py +506 -296
- hpcflow/sdk/cli_common.py +105 -7
- hpcflow/sdk/config/__init__.py +1 -1
- hpcflow/sdk/config/callbacks.py +115 -43
- hpcflow/sdk/config/cli.py +126 -103
- hpcflow/sdk/config/config.py +674 -318
- hpcflow/sdk/config/config_file.py +131 -95
- hpcflow/sdk/config/errors.py +125 -84
- hpcflow/sdk/config/types.py +148 -0
- hpcflow/sdk/core/__init__.py +25 -1
- hpcflow/sdk/core/actions.py +1771 -1059
- hpcflow/sdk/core/app_aware.py +24 -0
- hpcflow/sdk/core/cache.py +139 -79
- hpcflow/sdk/core/command_files.py +263 -287
- hpcflow/sdk/core/commands.py +145 -112
- hpcflow/sdk/core/element.py +828 -535
- hpcflow/sdk/core/enums.py +192 -0
- hpcflow/sdk/core/environment.py +74 -93
- hpcflow/sdk/core/errors.py +455 -52
- hpcflow/sdk/core/execute.py +207 -0
- hpcflow/sdk/core/json_like.py +540 -272
- hpcflow/sdk/core/loop.py +751 -347
- hpcflow/sdk/core/loop_cache.py +164 -47
- hpcflow/sdk/core/object_list.py +370 -207
- hpcflow/sdk/core/parameters.py +1100 -627
- hpcflow/sdk/core/rule.py +59 -41
- hpcflow/sdk/core/run_dir_files.py +21 -37
- hpcflow/sdk/core/skip_reason.py +7 -0
- hpcflow/sdk/core/task.py +1649 -1339
- hpcflow/sdk/core/task_schema.py +308 -196
- hpcflow/sdk/core/test_utils.py +191 -114
- hpcflow/sdk/core/types.py +440 -0
- hpcflow/sdk/core/utils.py +485 -309
- hpcflow/sdk/core/validation.py +82 -9
- hpcflow/sdk/core/workflow.py +2544 -1178
- hpcflow/sdk/core/zarr_io.py +98 -137
- hpcflow/sdk/data/workflow_spec_schema.yaml +2 -0
- hpcflow/sdk/demo/cli.py +53 -33
- hpcflow/sdk/helper/cli.py +18 -15
- hpcflow/sdk/helper/helper.py +75 -63
- hpcflow/sdk/helper/watcher.py +61 -28
- hpcflow/sdk/log.py +122 -71
- hpcflow/sdk/persistence/__init__.py +8 -31
- hpcflow/sdk/persistence/base.py +1360 -606
- hpcflow/sdk/persistence/defaults.py +6 -0
- hpcflow/sdk/persistence/discovery.py +38 -0
- hpcflow/sdk/persistence/json.py +568 -188
- hpcflow/sdk/persistence/pending.py +382 -179
- hpcflow/sdk/persistence/store_resource.py +39 -23
- hpcflow/sdk/persistence/types.py +318 -0
- hpcflow/sdk/persistence/utils.py +14 -11
- hpcflow/sdk/persistence/zarr.py +1337 -433
- hpcflow/sdk/runtime.py +44 -41
- hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
- hpcflow/sdk/submission/jobscript.py +1651 -692
- hpcflow/sdk/submission/schedulers/__init__.py +167 -39
- hpcflow/sdk/submission/schedulers/direct.py +121 -81
- hpcflow/sdk/submission/schedulers/sge.py +170 -129
- hpcflow/sdk/submission/schedulers/slurm.py +291 -268
- hpcflow/sdk/submission/schedulers/utils.py +12 -2
- hpcflow/sdk/submission/shells/__init__.py +14 -15
- hpcflow/sdk/submission/shells/base.py +150 -29
- hpcflow/sdk/submission/shells/bash.py +283 -173
- hpcflow/sdk/submission/shells/os_version.py +31 -30
- hpcflow/sdk/submission/shells/powershell.py +228 -170
- hpcflow/sdk/submission/submission.py +1014 -335
- hpcflow/sdk/submission/types.py +140 -0
- hpcflow/sdk/typing.py +182 -12
- hpcflow/sdk/utils/arrays.py +71 -0
- hpcflow/sdk/utils/deferred_file.py +55 -0
- hpcflow/sdk/utils/hashing.py +16 -0
- hpcflow/sdk/utils/patches.py +12 -0
- hpcflow/sdk/utils/strings.py +33 -0
- hpcflow/tests/api/test_api.py +32 -0
- hpcflow/tests/conftest.py +27 -6
- hpcflow/tests/data/multi_path_sequences.yaml +29 -0
- hpcflow/tests/data/workflow_test_run_abort.yaml +34 -35
- hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
- hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
- hpcflow/tests/scripts/test_input_file_generators.py +282 -0
- hpcflow/tests/scripts/test_main_scripts.py +866 -85
- hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
- hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
- hpcflow/tests/shells/wsl/test_wsl_submission.py +12 -4
- hpcflow/tests/unit/test_action.py +262 -75
- hpcflow/tests/unit/test_action_rule.py +9 -4
- hpcflow/tests/unit/test_app.py +33 -6
- hpcflow/tests/unit/test_cache.py +46 -0
- hpcflow/tests/unit/test_cli.py +134 -1
- hpcflow/tests/unit/test_command.py +71 -54
- hpcflow/tests/unit/test_config.py +142 -16
- hpcflow/tests/unit/test_config_file.py +21 -18
- hpcflow/tests/unit/test_element.py +58 -62
- hpcflow/tests/unit/test_element_iteration.py +50 -1
- hpcflow/tests/unit/test_element_set.py +29 -19
- hpcflow/tests/unit/test_group.py +4 -2
- hpcflow/tests/unit/test_input_source.py +116 -93
- hpcflow/tests/unit/test_input_value.py +29 -24
- hpcflow/tests/unit/test_jobscript_unit.py +757 -0
- hpcflow/tests/unit/test_json_like.py +44 -35
- hpcflow/tests/unit/test_loop.py +1396 -84
- hpcflow/tests/unit/test_meta_task.py +325 -0
- hpcflow/tests/unit/test_multi_path_sequences.py +229 -0
- hpcflow/tests/unit/test_object_list.py +17 -12
- hpcflow/tests/unit/test_parameter.py +29 -7
- hpcflow/tests/unit/test_persistence.py +237 -42
- hpcflow/tests/unit/test_resources.py +20 -18
- hpcflow/tests/unit/test_run.py +117 -6
- hpcflow/tests/unit/test_run_directories.py +29 -0
- hpcflow/tests/unit/test_runtime.py +2 -1
- hpcflow/tests/unit/test_schema_input.py +23 -15
- hpcflow/tests/unit/test_shell.py +23 -2
- hpcflow/tests/unit/test_slurm.py +8 -7
- hpcflow/tests/unit/test_submission.py +38 -89
- hpcflow/tests/unit/test_task.py +352 -247
- hpcflow/tests/unit/test_task_schema.py +33 -20
- hpcflow/tests/unit/test_utils.py +9 -11
- hpcflow/tests/unit/test_value_sequence.py +15 -12
- hpcflow/tests/unit/test_workflow.py +114 -83
- hpcflow/tests/unit/test_workflow_template.py +0 -1
- hpcflow/tests/unit/utils/test_arrays.py +40 -0
- hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
- hpcflow/tests/unit/utils/test_hashing.py +65 -0
- hpcflow/tests/unit/utils/test_patches.py +5 -0
- hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
- hpcflow/tests/workflows/__init__.py +0 -0
- hpcflow/tests/workflows/test_directory_structure.py +31 -0
- hpcflow/tests/workflows/test_jobscript.py +334 -1
- hpcflow/tests/workflows/test_run_status.py +198 -0
- hpcflow/tests/workflows/test_skip_downstream.py +696 -0
- hpcflow/tests/workflows/test_submission.py +140 -0
- hpcflow/tests/workflows/test_workflows.py +160 -15
- hpcflow/tests/workflows/test_zip.py +18 -0
- hpcflow/viz_demo.ipynb +6587 -3
- {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/METADATA +8 -4
- hpcflow_new2-0.2.0a199.dist-info/RECORD +221 -0
- hpcflow/sdk/core/parallel.py +0 -21
- hpcflow_new2-0.2.0a189.dist-info/RECORD +0 -158
- {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/LICENSE +0 -0
- {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/WHEEL +0 -0
- {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/entry_points.txt +0 -0
hpcflow/sdk/core/loop.py
CHANGED
@@ -6,21 +6,39 @@ notably looping over a set of values or until a condition holds.
 
 from __future__ import annotations
 
+from collections import defaultdict
 import copy
-from
-
-from
+from pprint import pp
+import pprint
+from typing import Dict, List, Optional, Tuple, Union, Any
+from warnings import warn
+from collections import defaultdict
+from itertools import chain
+from typing import cast, TYPE_CHECKING
+from typing_extensions import override
+
+from hpcflow.sdk.core.app_aware import AppAware
+from hpcflow.sdk.core.actions import EARStatus
+from hpcflow.sdk.core.skip_reason import SkipReason
 from hpcflow.sdk.core.errors import LoopTaskSubsetError
 from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
-from hpcflow.sdk.core.loop_cache import LoopCache
-from hpcflow.sdk.core.
-from hpcflow.sdk.core.task import WorkflowTask
+from hpcflow.sdk.core.loop_cache import LoopCache, LoopIndex
+from hpcflow.sdk.core.enums import InputSourceType, TaskSourceType
 from hpcflow.sdk.core.utils import check_valid_py_identifier, nth_key, nth_value
+from hpcflow.sdk.utils.strings import shorten_list_str
 from hpcflow.sdk.log import TimeIt
 
-
-
-
+if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator, Mapping, Sequence
+    from typing import Any, ClassVar
+    from typing_extensions import Self, TypeIs
+    from rich.status import Status
+    from ..typing import DataIndex, ParamSource
+    from .parameters import SchemaInput, InputSource
+    from .rule import Rule
+    from .task import WorkflowTask
+    from .types import IterableParam
+    from .workflow import Workflow, WorkflowTemplate
 
 
 # @dataclass
@@ -53,126 +71,189 @@ class Loop(JSONLike):
         Specify input parameters that should not iterate.
     termination: ~hpcflow.app.Rule
         Stopping criterion, expressed as a rule.
+    termination_task: int | ~hpcflow.app.WorkflowTask
+        Task at which to evaluate the termination condition.
     """
 
-
-
+    _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
+        ChildObjectSpec(name="termination", class_name="Rule"),
+    )
+
+    @classmethod
+    def __is_WorkflowTask(cls, value) -> TypeIs[WorkflowTask]:
+        return isinstance(value, cls._app.WorkflowTask)
 
     def __init__(
         self,
-        tasks:
+        tasks: Iterable[int | WorkflowTask],
         num_iterations: int,
-        name:
-        non_iterable_parameters:
-        termination:
+        name: str | None = None,
+        non_iterable_parameters: list[str] | None = None,
+        termination: Rule | None = None,
+        termination_task: int | WorkflowTask | None = None,
     ) -> None:
-        _task_insert_IDs = []
+        _task_insert_IDs: list[int] = []
         for task in tasks:
-            if
+            if self.__is_WorkflowTask(task):
                 _task_insert_IDs.append(task.insert_ID)
             elif isinstance(task, int):
                 _task_insert_IDs.append(task)
             else:
                 raise TypeError(
                     f"`tasks` must be a list whose elements are either task insert IDs "
-                    f"or WorkflowTask objects, but received the following: {tasks!r}."
+                    f"or `WorkflowTask` objects, but received the following: {tasks!r}."
                 )
 
+        if termination_task is None:
+            _term_task_iID = _task_insert_IDs[-1]  # terminate on final task by default
+        elif self.__is_WorkflowTask(termination_task):
+            _term_task_iID = termination_task.insert_ID
+        elif isinstance(task, int):
+            _term_task_iID = termination_task
+        else:
+            raise TypeError(
+                f"`termination_task` must be a task insert ID or a `WorkflowTask` "
+                f"object, but received the following: {termination_task!r}."
+            )
+
+        if _term_task_iID not in _task_insert_IDs:
+            raise ValueError(
+                f"If specified, `termination_task` (provided: {termination_task!r}) must "
+                f"refer to a task that is part of the loop. Available task insert IDs "
+                f"are: {_task_insert_IDs!r}."
+            )
+
         self._task_insert_IDs = _task_insert_IDs
         self._num_iterations = num_iterations
         self._name = check_valid_py_identifier(name) if name else name
         self._non_iterable_parameters = non_iterable_parameters or []
         self._termination = termination
+        self._termination_task_insert_ID = _term_task_iID
 
-        self._workflow_template
+        self._workflow_template: WorkflowTemplate | None = (
+            None  # assigned by parent WorkflowTemplate
+        )
 
-
-
+    @override
+    def _postprocess_to_dict(self, d: dict[str, Any]) -> dict[str, Any]:
+        out = super()._postprocess_to_dict(d)
         return {k.lstrip("_"): v for k, v in out.items()}
 
     @classmethod
-    def _json_like_constructor(cls, json_like):
+    def _json_like_constructor(cls, json_like: dict) -> Self:
        """Invoked by `JSONLike.from_json_like` instead of `__init__`."""
        if "task_insert_IDs" in json_like:
            insert_IDs = json_like.pop("task_insert_IDs")
        else:
            insert_IDs = json_like.pop("tasks")
-
-
+
+        if "termination_task_insert_ID" in json_like:
+            tt_iID = json_like.pop("termination_task_insert_ID")
+        elif "termination_task" in json_like:
+            tt_iID = json_like.pop("termination_task")
+        else:
+            tt_iID = None
+
+        return cls(tasks=insert_IDs, termination_task=tt_iID, **json_like)
 
     @property
-    def task_insert_IDs(self) ->
+    def task_insert_IDs(self) -> tuple[int, ...]:
         """Get the list of task insert_IDs that define the extent of the loop."""
         return tuple(self._task_insert_IDs)
 
     @property
-    def name(self):
+    def name(self) -> str | None:
         """
         The name of the loop, if one was provided.
         """
         return self._name
 
     @property
-    def num_iterations(self):
+    def num_iterations(self) -> int:
         """
         The number of loop iterations to do.
         """
         return self._num_iterations
 
     @property
-    def non_iterable_parameters(self):
+    def non_iterable_parameters(self) -> Sequence[str]:
         """
         Which parameters are not iterable.
         """
         return self._non_iterable_parameters
 
     @property
-    def termination(self):
+    def termination(self) -> Rule | None:
         """
         A termination rule for the loop, if one is provided.
         """
         return self._termination
 
     @property
-    def
+    def termination_task_insert_ID(self) -> int:
+        """
+        The insert ID of the task at which the loop will terminate.
+        """
+        return self._termination_task_insert_ID
+
+    @property
+    def termination_task(self) -> WorkflowTask:
+        """
+        The task at which the loop will terminate.
+        """
+        if (wt := self.workflow_template) is None:
+            raise RuntimeError(
+                "Workflow template must be assigned to retrieve task objects of the loop."
+            )
+        assert wt.workflow
+        return wt.workflow.tasks.get(insert_ID=self.termination_task_insert_ID)
+
+    @property
+    def workflow_template(self) -> WorkflowTemplate | None:
         """
         The workflow template that contains this loop.
         """
         return self._workflow_template
 
     @workflow_template.setter
-    def workflow_template(self, template:
+    def workflow_template(self, template: WorkflowTemplate):
         self._workflow_template = template
-        self.
+        self.__validate_against_template()
+
+    def __workflow(self) -> None | Workflow:
+        if (wt := self.workflow_template) is None:
+            return None
+        return wt.workflow
 
     @property
-    def task_objects(self) ->
+    def task_objects(self) -> tuple[WorkflowTask, ...]:
         """
         The tasks in the loop.
         """
-        if not self.
+        if not (wf := self.__workflow()):
             raise RuntimeError(
                 "Workflow template must be assigned to retrieve task objects of the loop."
             )
-        return tuple(
-            self.workflow_template.workflow.tasks.get(insert_ID=i)
-            for i in self.task_insert_IDs
-        )
+        return tuple(wf.tasks.get(insert_ID=t_id) for t_id in self.task_insert_IDs)
 
-    def
+    def __validate_against_template(self) -> None:
         """Validate the loop parameters against the associated workflow."""
 
         # insert IDs must exist:
+        if not (wf := self.__workflow()):
+            raise RuntimeError(
+                "workflow cannot be validated against as it is not assigned"
+            )
         for insert_ID in self.task_insert_IDs:
             try:
-
+                wf.tasks.get(insert_ID=insert_ID)
             except ValueError:
                 raise ValueError(
                     f"Loop {self.name!r} has an invalid task insert ID {insert_ID!r}. "
                     f"Such as task does not exist in the associated workflow."
                 )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         num_iterations_str = ""
         if self.num_iterations is not None:
             num_iterations_str = f", num_iterations={self.num_iterations!r}"
@@ -187,15 +268,16 @@ class Loop(JSONLike):
             f")"
         )
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: dict[int, Any]) -> Self:
         kwargs = self.to_dict()
         kwargs["tasks"] = kwargs.pop("task_insert_IDs")
+        kwargs["termination_task"] = kwargs.pop("termination_task_insert_ID")
         obj = self.__class__(**copy.deepcopy(kwargs, memo))
         obj._workflow_template = self._workflow_template
         return obj
 
 
-class WorkflowLoop:
+class WorkflowLoop(AppAware):
     """
     Class to represent a :py:class:`.Loop` that is bound to a
     :py:class:`~hpcflow.app.Workflow`.
@@ -212,59 +294,49 @@ class WorkflowLoop:
         Description of what iterations have been added.
     iterable_parameters:
         Description of what parameters are being iterated over.
+    output_parameters:
+        Decription of what parameter are output from this loop, and the final task insert
+        ID from which they are output.
     parents: list[str]
         The paths to the parent entities of this loop.
     """
 
-    _app_attr = "app"
-
     def __init__(
         self,
         index: int,
-        workflow:
-        template:
-        num_added_iterations:
-        iterable_parameters:
-
-
+        workflow: Workflow,
+        template: Loop,
+        num_added_iterations: dict[tuple[int, ...], int],
+        iterable_parameters: dict[str, IterableParam],
+        output_parameters: dict[str, int],
+        parents: list[str],
+    ) -> None:
         self._index = index
         self._workflow = workflow
         self._template = template
         self._num_added_iterations = num_added_iterations
         self._iterable_parameters = iterable_parameters
+        self._output_parameters = output_parameters
         self._parents = parents
 
-        # appended to
-        # reset and added to `self._parents` on dump to disk:
-        self._pending_parents = []
+        # appended to when adding an empty loop to the workflow that is a parent of this
+        # loop; reset and added to `self._parents` on dump to disk:
+        self._pending_parents: list[str] = []
 
         # used for `num_added_iterations` when a new loop iteration is added, or when
         # parents are append to; reset to None on dump to disk. Each key is a tuple of
         # parent loop indices and each value is the number of pending new iterations:
-        self._pending_num_added_iterations = None
+        self._pending_num_added_iterations: dict[tuple[int, ...], int] | None = None
 
         self._validate()
 
     @TimeIt.decorator
-    def _validate(self):
+    def _validate(self) -> None:
         # task subset must be a contiguous range of task indices:
         task_indices = self.task_indices
         task_min, task_max = task_indices[0], task_indices[-1]
         if task_indices != tuple(range(task_min, task_max + 1)):
-            raise LoopTaskSubsetError(
-                f"Loop {self.name!r}: task subset must be an ascending contiguous range, "
-                f"but specified task indices were: {self.task_indices!r}."
-            )
-
-        for task in self.downstream_tasks:
-            for param in self.iterable_parameters:
-                if param in task.template.all_schema_input_types:
-                    raise NotImplementedError(
-                        f"Downstream task {task.unique_name!r} of loop {self.name!r} "
-                        f"has as one of its input parameters this loop's iterable "
-                        f"parameter {param!r}. This parameter cannot be sourced "
-                        f"correctly."
-                    )
+            raise LoopTaskSubsetError(self.name, self.task_indices)
 
     def __repr__(self) -> str:
         return (
@@ -273,7 +345,7 @@ class WorkflowLoop:
         )
 
     @property
-    def num_added_iterations(self):
+    def num_added_iterations(self) -> Mapping[tuple[int, ...], int]:
         """
         The number of added iterations.
         """
@@ -282,27 +354,30 @@ class WorkflowLoop:
         else:
             return self._num_added_iterations
 
-
+    @property
+    def __pending(self) -> dict[tuple[int, ...], int]:
         if not self._pending_num_added_iterations:
-            self._pending_num_added_iterations =
+            self._pending_num_added_iterations = dict(self._num_added_iterations)
+        return self._pending_num_added_iterations
 
-
-            self._pending_num_added_iterations[added_iters_key] = 1
-
-    def _increment_pending_added_iters(self, added_iters_key):
+    def _initialise_pending_added_iters(self, added_iters: Iterable[int]):
         if not self._pending_num_added_iterations:
-            self._pending_num_added_iterations =
+            self._pending_num_added_iterations = dict(self._num_added_iterations)
+        if (added_iters_key := tuple(added_iters)) not in (pending := self.__pending):
+            pending[added_iters_key] = 1
 
-
+    def _increment_pending_added_iters(self, added_iters_key: Iterable[int]):
+        self.__pending[tuple(added_iters_key)] += 1
 
-    def _update_parents(self, parent:
+    def _update_parents(self, parent: WorkflowLoop):
+        assert parent.name
         self._pending_parents.append(parent.name)
 
-        if not self._pending_num_added_iterations:
-            self._pending_num_added_iterations = copy.deepcopy(self._num_added_iterations)
-
         self._pending_num_added_iterations = {
-
+            (*k, 0): v
+            for k, v in (
+                self._pending_num_added_iterations or self._num_added_iterations
+            ).items()
         }
 
         self.workflow._store.update_loop_parents(
@@ -311,116 +386,129 @@ class WorkflowLoop:
             parents=self.parents,
         )
 
-    def _reset_pending_num_added_iters(self):
+    def _reset_pending_num_added_iters(self) -> None:
         self._pending_num_added_iterations = None
 
-    def _accept_pending_num_added_iters(self):
+    def _accept_pending_num_added_iters(self) -> None:
         if self._pending_num_added_iterations:
-            self._num_added_iterations =
+            self._num_added_iterations = dict(self._pending_num_added_iterations)
             self._reset_pending_num_added_iters()
 
-    def _reset_pending_parents(self):
+    def _reset_pending_parents(self) -> None:
         self._pending_parents = []
 
-    def _accept_pending_parents(self):
+    def _accept_pending_parents(self) -> None:
         self._parents += self._pending_parents
         self._reset_pending_parents()
 
     @property
-    def index(self):
+    def index(self) -> int:
         """
         The index of this loop within its workflow.
         """
         return self._index
 
     @property
-    def task_insert_IDs(self):
+    def task_insert_IDs(self) -> tuple[int, ...]:
         """
         The insertion IDs of the tasks inside this loop.
         """
         return self.template.task_insert_IDs
 
     @property
-    def task_objects(self):
+    def task_objects(self) -> tuple[WorkflowTask, ...]:
         """
         The tasks in this loop.
         """
         return self.template.task_objects
 
     @property
-    def task_indices(self) ->
+    def task_indices(self) -> tuple[int, ...]:
         """
         The list of task indices that define the extent of the loop.
         """
-        return tuple(
+        return tuple(task.index for task in self.task_objects)
 
     @property
-    def workflow(self):
+    def workflow(self) -> Workflow:
         """
         The workflow containing this loop.
         """
         return self._workflow
 
     @property
-    def template(self):
+    def template(self) -> Loop:
         """
         The loop template for this loop.
         """
         return self._template
 
     @property
-    def parents(self) ->
+    def parents(self) -> Sequence[str]:
         """
         The parents of this loop.
         """
         return self._parents + self._pending_parents
 
     @property
-    def name(self):
+    def name(self) -> str:
         """
         The name of this loop, if one is defined.
         """
+        assert self.template.name
         return self.template.name
 
     @property
-    def iterable_parameters(self):
+    def iterable_parameters(self) -> dict[str, IterableParam]:
         """
         The parameters that are being iterated over.
         """
         return self._iterable_parameters
 
     @property
-    def
+    def output_parameters(self) -> dict[str, int]:
+        """
+        The parameters that are outputs of this loop, and the final task insert ID from
+        which each parameter is output.
+        """
+        return self._output_parameters
+
+    @property
+    def num_iterations(self) -> int:
         """
         The number of iterations.
         """
         return self.template.num_iterations
 
     @property
-    def downstream_tasks(self) ->
+    def downstream_tasks(self) -> Iterator[WorkflowTask]:
         """Tasks that are not part of the loop, and downstream from this loop."""
-
+        tasks = self.workflow.tasks
+        for idx in range(self.task_objects[-1].index + 1, len(tasks)):
+            yield tasks[idx]
 
     @property
-    def upstream_tasks(self) ->
+    def upstream_tasks(self) -> Iterator[WorkflowTask]:
         """Tasks that are not part of the loop, and upstream from this loop."""
-
+        tasks = self.workflow.tasks
+        for idx in range(0, self.task_objects[0].index):
+            yield tasks[idx]
 
     @staticmethod
     @TimeIt.decorator
-    def
-
-
+    def _find_iterable_and_output_parameters(
+        loop_template: Loop,
+    ) -> tuple[dict[str, IterableParam], dict[str, int]]:
+        all_inputs_first_idx: dict[str, int] = {}
+        all_outputs_idx: dict[str, list[int]] = defaultdict(list)
         for task in loop_template.task_objects:
             for typ in task.template.all_schema_input_types:
-
-                all_inputs_first_idx[typ] = task.insert_ID
+                all_inputs_first_idx.setdefault(typ, task.insert_ID)
             for typ in task.template.all_schema_output_types:
-                if typ not in all_outputs_idx:
-                    all_outputs_idx[typ] = []
                 all_outputs_idx[typ].append(task.insert_ID)
 
-
+        # find input parameters that are also output parameters at a later/same task:
+        iterable_params: dict[str, IterableParam] = {}
         for typ, first_idx in all_inputs_first_idx.items():
             if typ in all_outputs_idx and first_idx <= all_outputs_idx[typ][0]:
                 iterable_params[typ] = {
@@ -429,20 +517,21 @@ class WorkflowLoop:
                 }
 
         for non_iter in loop_template.non_iterable_parameters:
-
-
+            iterable_params.pop(non_iter, None)
+
+        final_out_tasks = {k: v[-1] for k, v in all_outputs_idx.items()}
 
-        return iterable_params
+        return iterable_params, final_out_tasks
 
     @classmethod
     @TimeIt.decorator
     def new_empty_loop(
         cls,
         index: int,
-        workflow:
-        template:
-        iter_loop_idx:
-    ) ->
+        workflow: Workflow,
+        template: Loop,
+        iter_loop_idx: Sequence[Mapping[str, int]],
+    ) -> WorkflowLoop:
         """
         Make a new empty loop.
 
@@ -459,29 +548,30 @@ class WorkflowLoop:
         """
         parent_loops = cls._get_parent_loops(index, workflow, template)
         parent_names = [i.name for i in parent_loops]
-        num_added_iters = {}
+        num_added_iters: dict[tuple[int, ...], int] = {}
         for i in iter_loop_idx:
             num_added_iters[tuple([i[j] for j in parent_names])] = 1
 
-
+        iter_params, out_params = cls._find_iterable_and_output_parameters(template)
+        return cls(
             index=index,
             workflow=workflow,
             template=template,
             num_added_iterations=num_added_iters,
-            iterable_parameters=
+            iterable_parameters=iter_params,
+            output_parameters=out_params,
             parents=parent_names,
         )
-        return obj
 
     @classmethod
     @TimeIt.decorator
     def _get_parent_loops(
         cls,
         index: int,
-        workflow:
-        template:
-    ) ->
-        parents = []
+        workflow: Workflow,
+        template: Loop,
+    ) -> list[WorkflowLoop]:
+        parents: list[WorkflowLoop] = []
         passed_self = False
         self_tasks = set(template.task_insert_IDs)
         for loop_i in workflow.loops:
@@ -496,18 +586,18 @@ class WorkflowLoop:
         return parents
 
     @TimeIt.decorator
-    def get_parent_loops(self) ->
+    def get_parent_loops(self) -> list[WorkflowLoop]:
         """Get loops whose task subset is a superset of this loop's task subset. If two
         loops have identical task subsets, the first loop in the workflow loop list is
         considered the child."""
         return self._get_parent_loops(self.index, self.workflow, self.template)
 
     @TimeIt.decorator
-    def get_child_loops(self) ->
+    def get_child_loops(self) -> list[WorkflowLoop]:
         """Get loops whose task subset is a subset of this loop's task subset. If two
         loops have identical task subsets, the first loop in the workflow loop list is
         considered the child."""
-        children = []
+        children: list[WorkflowLoop] = []
         passed_self = False
         self_tasks = set(self.task_insert_IDs)
         for loop_i in self.workflow.loops:
@@ -521,11 +611,15 @@ class WorkflowLoop:
                 children.append(loop_i)
 
         # order by depth, so direct child is first:
-
-        return children
+        return sorted(children, key=lambda x: len(next(iter(x.num_added_iterations))))
 
     @TimeIt.decorator
-    def add_iteration(
+    def add_iteration(
+        self,
+        parent_loop_indices: Mapping[str, int] | None = None,
+        cache: LoopCache | None = None,
+        status: Status | None = None,
+    ) -> None:
         """
         Add an iteration to this loop.
 
@@ -539,42 +633,47 @@ class WorkflowLoop:
         """
         if not cache:
             cache = LoopCache.build(self.workflow)
+        assert cache is not None
         parent_loops = self.get_parent_loops()
         child_loops = self.get_child_loops()
-
-
-
+        parent_loop_indices_ = parent_loop_indices or {
+            loop.name: 0 for loop in parent_loops
+        }
 
-        iters_key = tuple([
+        iters_key = tuple(parent_loop_indices_[p_nm] for p_nm in self.parents)
         cur_loop_idx = self.num_added_iterations[iters_key] - 1
-
+
+        # keys are (task.insert_ID and element.index)
+        all_new_data_idx: dict[tuple[int, int], DataIndex] = {}
 
         # initialise a new `num_added_iterations` key on each child loop:
+        iters_key_dct = {
+            **parent_loop_indices_,
+            self.name: cur_loop_idx + 1,
+        }
         for child in child_loops:
-
-
-
-            }
-            added_iters_key_chd = tuple([iters_key_dct.get(j, 0) for j in child.parents])
-            child._initialise_pending_added_iters(added_iters_key_chd)
+            child._initialise_pending_added_iters(
+                iters_key_dct.get(j, 0) for j in child.parents
+            )
 
-
+            # needed for the case where an inner loop has only one iteration, meaning
+            # `add_iteration` will not be called recursively on it:
+            self.workflow._store.update_loop_num_iters(
+                index=child.index,
+                num_added_iters=child.num_added_iterations,
+            )
 
-
-
-
-
-
-                for child in child_loops
-                if task.insert_ID in child.task_insert_IDs
-            },
+        for task in self.task_objects:
+            new_loop_idx = LoopIndex(iters_key_dct) + {
+                child.name: 0
+                for child in child_loops
+                if task.insert_ID in child.task_insert_IDs
             }
-            added_iter_IDs = []
+            added_iter_IDs: list[int] = []
             for elem_idx in range(task.num_elements):
-
                 elem_ID = task.element_IDs[elem_idx]
 
-                new_data_idx = {}
+                new_data_idx: DataIndex = {}
 
                 # copy resources from zeroth iteration:
                 zeroth_iter_ID, zi_iter_data_idx = cache.zeroth_iters[elem_ID]
@@ -587,109 +686,26 @@ class WorkflowLoop:
 
                 for inp in task.template.all_schema_inputs:
                     is_inp_task = False
-                    iter_dat
-                    if iter_dat:
+                    if iter_dat := self.iterable_parameters.get(inp.typ):
                         is_inp_task = task.insert_ID == iter_dat["input_task"]
 
-
-                        # source from final output task of previous iteration, with all parent
-                        # loop indices the same as previous iteration, and all child loop indices
-                        # maximised:
-
-                        # identify element(s) from which this iterable input should be
-                        # parametrised:
-                        if task.insert_ID == iter_dat["output_tasks"][-1]:
-                            src_elem_ID = elem_ID
-                            grouped_elems = None
-                        else:
-                            src_elem_IDs_all = cache.element_dependents[elem_ID]
-                            src_elem_IDs = {
-                                k: v
-                                for k, v in src_elem_IDs_all.items()
-                                if cache.elements[k]["task_insert_ID"]
-                                == iter_dat["output_tasks"][-1]
-                            }
-                            # consider groups
-                            inp_group_name = inp.single_labelled_data.get("group")
-                            grouped_elems = []
-                            for src_elem_j_ID, src_elem_j_dat in src_elem_IDs.items():
-                                i_in_group = any(
-                                    k == inp_group_name
-                                    for k in src_elem_j_dat["group_names"]
-                                )
-                                if i_in_group:
-                                    grouped_elems.append(src_elem_j_ID)
-
-                            if not grouped_elems and len(src_elem_IDs) > 1:
-                                raise NotImplementedError(
-                                    f"Multiple elements found in the iterable parameter "
-                                    f"{inp!r}'s latest output task (insert ID: "
-                                    f"{iter_dat['output_tasks'][-1]}) that can be used "
-                                    f"to parametrise the next iteration: "
-                                    f"{list(src_elem_IDs.keys())!r}."
-                                )
-
-                            elif not src_elem_IDs:
-                                # TODO: maybe OK?
-                                raise NotImplementedError(
-                                    f"No elements found in the iterable parameter "
-                                    f"{inp!r}'s latest output task (insert ID: "
-                                    f"{iter_dat['output_tasks'][-1]}) that can be used "
-                                    f"to parametrise the next iteration."
-                                )
-
-                            else:
-                                src_elem_ID = nth_key(src_elem_IDs, 0)
-
-                        child_loop_max_iters = {}
-                        parent_loop_same_iters = {
-                            i.name: parent_loop_indices[i.name] for i in parent_loops
-                        }
-                        child_iter_parents = {
-                            **parent_loop_same_iters,
-                            self.name: cur_loop_idx,
-                        }
-                        for i in child_loops:
-                            i_num_iters = i.num_added_iterations[
-                                tuple(child_iter_parents[j] for j in i.parents)
-                            ]
-                            i_max = i_num_iters - 1
-                            child_iter_parents[i.name] = i_max
-                            child_loop_max_iters[i.name] = i_max
-
-                        source_iter_loop_idx = {
-                            **child_loop_max_iters,
-                            **parent_loop_same_iters,
-                            self.name: cur_loop_idx,
-                        }
-
-                        # identify the ElementIteration from which this input should be
-                        # parametrised:
-                        loop_idx_key = tuple(sorted(source_iter_loop_idx.items()))
-                        if grouped_elems:
-                            src_data_idx = []
-                            for src_elem_ID in grouped_elems:
-                                src_data_idx.append(
-                                    cache.data_idx[src_elem_ID][loop_idx_key]
-                                )
-                        else:
-                            src_data_idx = cache.data_idx[src_elem_ID][loop_idx_key]
-
-                        if not src_data_idx:
-                            raise RuntimeError(
-                                f"Could not find a source iteration with loop_idx: "
-                                f"{source_iter_loop_idx!r}."
-                            )
-
-                        if grouped_elems:
-                            inp_dat_idx = [i[f"outputs.{inp.typ}"] for i in src_data_idx]
-                        else:
-                            inp_dat_idx = src_data_idx[f"outputs.{inp.typ}"]
-                        new_data_idx[f"inputs.{inp.typ}"] = inp_dat_idx
+                    inp_key = f"inputs.{inp.typ}"
 
+                    if is_inp_task:
+                        assert iter_dat is not None
+                        inp_dat_idx = self.__get_looped_index(
+                            task,
+                            elem_ID,
+                            cache,
+                            iter_dat,
+                            inp,
+                            parent_loops,
+                            parent_loop_indices_,
+                            child_loops,
+                            cur_loop_idx,
+                        )
+                        new_data_idx[inp_key] = inp_dat_idx
                     else:
-                        inp_key = f"inputs.{inp.typ}"
-
                         orig_inp_src = cache.elements[elem_ID]["input_sources"][inp_key]
                         inp_dat_idx = None
 
@@ -709,77 +725,16 @@ class WorkflowLoop:
                             inp_dat_idx = zi_iter_data_idx[inp_key]
 
                         elif orig_inp_src.source_type is InputSourceType.TASK:
-
-
-
-
-
-
-
-
-
-
-                                            skip_iter = True
-                                            break
-                                    if not skip_iter:
-                                        src_data_idx.append(di_k)
-
-                                # could be multiple, but they should all have the same
-                                # data index for this parameter:
-                                src_data_idx = src_data_idx[0]
-                                inp_dat_idx = src_data_idx[inp_key]
-                            else:
-                                is_group = False
-                                if (
-                                    not inp.multiple
-                                    and "group" in inp.single_labelled_data
-                                ):
-                                    # this input is a group, assume for now all elements:
-                                    is_group = True
-
-                                # same task/element, but update iteration to the just-added
-                                # iteration:
-                                key_prefix = orig_inp_src.task_source_type.name.lower()
-                                prev_dat_idx_key = f"{key_prefix}s.{inp.typ}"
-                                new_sources = []
-                                for (
-                                    tiID,
-                                    e_idx,
-                                ), prev_dat_idx in all_new_data_idx.items():
-                                    if tiID == orig_inp_src.task_ref:
-                                        # find which element in that task `element`
-                                        # depends on:
-                                        task_i = self.workflow.tasks.get(insert_ID=tiID)
-                                        elem_i_ID = task_i.element_IDs[e_idx]
-                                        src_elem_IDs_all = cache.element_dependents[
-                                            elem_i_ID
-                                        ]
-                                        src_elem_IDs_i = {
-                                            k: v
-                                            for k, v in src_elem_IDs_all.items()
-                                            if cache.elements[k]["task_insert_ID"]
-                                            == task.insert_ID
-                                        }
-
-                                        # filter src_elem_IDs_i for matching element IDs:
-                                        src_elem_IDs_i = [
-                                            i for i in src_elem_IDs_i if i == elem_ID
-                                        ]
-                                        if (
-                                            len(src_elem_IDs_i) == 1
-                                            and src_elem_IDs_i[0] == elem_ID
-                                        ):
-                                            new_sources.append((tiID, e_idx))
-
-                                if is_group:
-                                    inp_dat_idx = [
-                                        all_new_data_idx[i][prev_dat_idx_key]
-                                        for i in new_sources
-                                    ]
-                                else:
-                                    assert len(new_sources) == 1
-                                    prev_dat_idx = all_new_data_idx[new_sources[0]]
-                                    inp_dat_idx = prev_dat_idx[prev_dat_idx_key]
+                            inp_dat_idx = self.__get_task_index(
+                                task,
+                                orig_inp_src,
+                                cache,
+                                elem_ID,
+                                inp,
+                                inp_key,
+                                parent_loop_indices_,
+                                all_new_data_idx,
+                            )
 
                         if inp_dat_idx is None:
                             raise RuntimeError(
@@ -791,9 +746,8 @@ class WorkflowLoop:
 
                 # add any locally defined sub-parameters:
                 inp_statuses = cache.elements[elem_ID]["input_statuses"]
-                inp_status_inps = set(
-
-                for sub_param_i in sub_params:
+                inp_status_inps = set(f"inputs.{inp}" for inp in inp_statuses)
+                for sub_param_i in inp_status_inps.difference(new_data_idx):
                     sub_param_data_idx_iter_0 = zi_data_idx
                     try:
                         sub_param_data_idx = sub_param_data_idx_iter_0[sub_param_i]
@@ -808,13 +762,11 @@ class WorkflowLoop:
 
                 for out in task.template.all_schema_outputs:
                     path_i = f"outputs.{out.typ}"
-                    p_src = {"type": "EAR_output"}
+                    p_src: ParamSource = {"type": "EAR_output"}
                     new_data_idx[path_i] = self.workflow._add_unset_parameter_data(p_src)
 
-                schema_params = set(
-
-                )
-                all_new_data_idx[(task.insert_ID, elem_idx)] = new_data_idx
+                schema_params = set(i for i in new_data_idx if len(i.split(".")) == 2)
+                all_new_data_idx[task.insert_ID, elem_idx] = new_data_idx
 
                 iter_ID_i = self.workflow._store.add_element_iteration(
                     element_ID=elem_ID,
@@ -835,8 +787,9 @@ class WorkflowLoop:
 
             task.initialise_EARs(iter_IDs=added_iter_IDs)
 
-
-
+        self._increment_pending_added_iters(
+            parent_loop_indices_[p_nm] for p_nm in self.parents
+        )
         self.workflow._store.update_loop_num_iters(
             index=self.index,
             num_added_iters=self.num_added_iterations,
@@ -845,20 +798,471 @@ class WorkflowLoop:
|
|
845
798
|
# add iterations to fixed-number-iteration children only:
|
846
799
|
for child in child_loops[::-1]:
|
847
800
|
if child.num_iterations is not None:
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
801
|
+
if status:
|
802
|
+
status_prev = str(status.status).rstrip(".")
|
803
|
+
for iter_idx in range(child.num_iterations - 1):
|
804
|
+
if status:
|
805
|
+
status.update(
|
806
|
+
f"{status_prev} --> ({child.name!r}): iteration "
|
807
|
+
f"{iter_idx + 2}/{child.num_iterations}."
|
808
|
+
)
|
809
|
+
par_idx = {parent_name: 0 for parent_name in child.parents}
|
810
|
+
if parent_loop_indices:
|
811
|
+
par_idx.update(parent_loop_indices)
|
812
|
+
par_idx[self.name] = cur_loop_idx + 1
|
813
|
+
child.add_iteration(parent_loop_indices=par_idx, cache=cache)
|
814
|
+
|
815
|
+
self.__update_loop_downstream_data_idx(parent_loop_indices_)
|
816
|
+
|
817
|
+
def __get_src_ID_and_groups(
|
818
|
+
self,
|
819
|
+
elem_ID: int,
|
820
|
+
iter_dat: IterableParam,
|
821
|
+
inp: SchemaInput,
|
822
|
+
cache: LoopCache,
|
823
|
+
task: WorkflowTask,
|
824
|
+
) -> tuple[int, Sequence[int]]:
|
825
|
+
# `cache.elements` contains only elements that are part of the
|
826
|
+
# loop, so indexing a dependent element may raise:
|
827
|
+
src_elem_IDs = {}
|
828
|
+
for k, v in cache.element_dependents[elem_ID].items():
|
829
|
+
try:
|
830
|
+
if cache.elements[k]["task_insert_ID"] == iter_dat["output_tasks"][-1]:
|
831
|
+
src_elem_IDs[k] = v
|
832
|
+
except KeyError:
|
833
|
+
continue
|
834
|
+
|
835
|
+
# consider groups
|
836
|
+
single_data = inp.single_labelled_data
|
837
|
+
assert single_data is not None
|
838
|
+
inp_group_name = single_data.get("group")
|
839
|
+
grouped_elems = [
|
840
|
+
src_elem_j_ID
|
841
|
+
for src_elem_j_ID, src_elem_j_dat in src_elem_IDs.items()
|
842
|
+
if any(nm == inp_group_name for nm in src_elem_j_dat["group_names"])
|
843
|
+
]
|
844
|
+
|
845
|
+
if not grouped_elems and len(src_elem_IDs) > 1:
|
846
|
+
raise NotImplementedError(
|
847
|
+
f"Multiple elements found in the iterable parameter "
|
848
|
+
f"{inp!r}'s latest output task (insert ID: "
|
849
|
+
f"{iter_dat['output_tasks'][-1]}) that can be used "
|
850
|
+
f"to parametrise the next iteration of task "
|
851
|
+
f"{task.unique_name!r}: "
|
852
|
+
f"{list(src_elem_IDs)!r}."
|
853
|
+
)
|
854
|
+
|
855
|
+
elif not src_elem_IDs:
|
856
|
+
# TODO: maybe OK?
|
857
|
+
raise NotImplementedError(
|
858
|
+
f"No elements found in the iterable parameter "
|
859
|
+
f"{inp!r}'s latest output task (insert ID: "
|
860
|
+
f"{iter_dat['output_tasks'][-1]}) that can be used "
|
861
|
+
f"to parametrise the next iteration."
|
862
|
+
)
|
863
|
+
|
864
|
+
return nth_key(src_elem_IDs, 0), grouped_elems
|
865
|
+
|
866
|
+
def __get_looped_index(
|
867
|
+
self,
|
868
|
+
task: WorkflowTask,
|
869
|
+
elem_ID: int,
|
870
|
+
cache: LoopCache,
|
871
|
+
iter_dat: IterableParam,
|
872
|
+
inp: SchemaInput,
|
873
|
+
parent_loops: list[WorkflowLoop],
|
874
|
+
parent_loop_indices: Mapping[str, int],
|
875
|
+
child_loops: list[WorkflowLoop],
|
876
|
+
cur_loop_idx: int,
|
877
|
+
):
|
878
|
+
# source from final output task of previous iteration, with all parent
|
879
|
+
# loop indices the same as previous iteration, and all child loop indices
|
880
|
+
# maximised:
|
881
|
+
|
882
|
+
# identify element(s) from which this iterable input should be
|
883
|
+
# parametrised:
|
884
|
+
if task.insert_ID == iter_dat["output_tasks"][-1]:
|
885
|
+
# single-task loop
|
886
|
+
src_elem_ID = elem_ID
|
887
|
+
grouped_elems: Sequence[int] = []
|
888
|
+
else:
|
889
|
+
# multi-task loop
|
890
|
+
src_elem_ID, grouped_elems = self.__get_src_ID_and_groups(
|
891
|
+
elem_ID, iter_dat, inp, cache, task
|
892
|
+
)
|
893
|
+
|
894
|
+
child_loop_max_iters: dict[str, int] = {}
|
895
|
+
parent_loop_same_iters = {
|
896
|
+
loop.name: parent_loop_indices[loop.name] for loop in parent_loops
|
897
|
+
}
|
898
|
+
child_iter_parents = {
|
899
|
+
**parent_loop_same_iters,
|
900
|
+
self.name: cur_loop_idx,
|
901
|
+
}
|
902
|
+
for loop in child_loops:
|
903
|
+
if iter_dat["output_tasks"][-1] in loop.task_insert_IDs:
|
904
|
+
i_num_iters = loop.num_added_iterations[
|
905
|
+
tuple(child_iter_parents[j] for j in loop.parents)
|
906
|
+
]
|
907
|
+
i_max = i_num_iters - 1
|
908
|
+
child_iter_parents[loop.name] = i_max
|
909
|
+
child_loop_max_iters[loop.name] = i_max
|
910
|
+
|
911
|
+
loop_idx_key = LoopIndex(child_loop_max_iters)
|
912
|
+
loop_idx_key.update(parent_loop_same_iters)
|
913
|
+
loop_idx_key[self.name] = cur_loop_idx
|
914
|
+
|
915
|
+
# identify the ElementIteration from which this input should be
|
916
|
+
# parametrised:
|
917
|
+
if grouped_elems:
|
918
|
+
src_data_idx = [
|
919
|
+
cache.data_idx[src_elem_ID][loop_idx_key] for src_elem_ID in grouped_elems
|
920
|
+
]
|
921
|
+
if not src_data_idx:
|
922
|
+
raise RuntimeError(
|
923
|
+
f"Could not find a source iteration with loop_idx: "
|
924
|
+
f"{loop_idx_key!r}."
|
925
|
+
)
|
926
|
+
return [i[f"outputs.{inp.typ}"] for i in src_data_idx]
|
927
|
+
else:
|
928
|
+
return cache.data_idx[src_elem_ID][loop_idx_key][f"outputs.{inp.typ}"]
|
929
|
+
|
930
|
+
def __get_task_index(
|
931
|
+
self,
|
932
|
+
task: WorkflowTask,
|
933
|
+
orig_inp_src: InputSource,
|
934
|
+
cache: LoopCache,
|
935
|
+
elem_ID: int,
|
936
|
+
inp: SchemaInput,
|
937
|
+
inp_key: str,
|
938
|
+
parent_loop_indices: Mapping[str, int],
|
939
|
+
all_new_data_idx: Mapping[tuple[int, int], DataIndex],
|
940
|
+
) -> int | list[int]:
|
941
|
+
if orig_inp_src.task_ref not in self.task_insert_IDs:
|
942
|
+
# source the data_idx from the iteration with same parent
|
943
|
+
# loop indices as the new iteration to add:
|
944
|
+
src_data_idx = next(
|
945
|
+
di_k
|
946
|
+
for li_k, di_k in cache.data_idx[elem_ID].items()
|
947
|
+
if all(li_k.get(p_k) == p_v for p_k, p_v in parent_loop_indices.items())
|
948
|
+
)
|
949
|
+
|
950
|
+
# could be multiple, but they should all have the same
|
951
|
+
# data index for this parameter:
|
952
|
+
return src_data_idx[inp_key]
|
953
|
+
|
954
|
+
is_group = (
|
955
|
+
inp.single_labelled_data is not None
|
956
|
+
and "group" in inp.single_labelled_data
|
957
|
+
# this input is a group, assume for now all elements
|
958
|
+
)
|
959
|
+
|
960
|
+
# same task/element, but update iteration to the just-added
|
961
|
+
# iteration:
|
962
|
+
assert orig_inp_src.task_source_type is not None
|
963
|
+
key_prefix = orig_inp_src.task_source_type.name.lower()
|
964
|
+
prev_dat_idx_key = f"{key_prefix}s.{inp.typ}"
|
965
|
+
new_sources: list[tuple[int, int]] = []
|
966
|
+
for (tiID, e_idx), _ in all_new_data_idx.items():
|
967
|
+
if tiID == orig_inp_src.task_ref:
|
968
|
+
# find which element in that task `element`
|
969
|
+
# depends on:
|
970
|
+
src_elem_IDs = cache.element_dependents[
|
971
|
+
self.workflow.tasks.get(insert_ID=tiID).element_IDs[e_idx]
|
972
|
+
]
|
973
|
+
# `cache.elements` contains only elements that are part of the loop, so
|
974
|
+
# indexing a dependent element may raise:
|
975
|
+
src_elem_IDs_i = []
|
976
|
+
for k, _v in src_elem_IDs.items():
|
977
|
+
try:
|
978
|
+
if (
|
979
|
+
cache.elements[k]["task_insert_ID"] == task.insert_ID
|
980
|
+
and k == elem_ID
|
981
|
+
# filter src_elem_IDs_i for matching element IDs
|
982
|
+
):
|
983
|
+
|
984
|
+
src_elem_IDs_i.append(k)
|
985
|
+
except KeyError:
|
986
|
+
continue
|
987
|
+
|
988
|
+
if len(src_elem_IDs_i) == 1:
|
989
|
+
new_sources.append((tiID, e_idx))
|
990
|
+
|
991
|
+
if is_group:
|
992
|
+
# Convert into simple list of indices
|
993
|
+
return list(
|
994
|
+
chain.from_iterable(
|
995
|
+
self.__as_sequence(all_new_data_idx[src][prev_dat_idx_key])
|
996
|
+
for src in new_sources
|
997
|
+
)
|
998
|
+
)
|
999
|
+
else:
|
1000
|
+
assert len(new_sources) == 1
|
1001
|
+
return all_new_data_idx[new_sources[0]][prev_dat_idx_key]
|
1002
|
+
|
1003
|
+
@staticmethod
|
1004
|
+
def __as_sequence(seq: int | Iterable[int]) -> Iterable[int]:
|
1005
|
+
if isinstance(seq, int):
|
1006
|
+
yield seq
|
1007
|
+
else:
|
1008
|
+
yield from seq
|
1009
|
+
|
1010
|
+
def __update_loop_downstream_data_idx(
|
1011
|
+
self,
|
1012
|
+
parent_loop_indices: Mapping[str, int],
|
1013
|
+
):
|
1014
|
+
# update data indices of loop-downstream tasks that depend on task outputs from
|
1015
|
+
# this loop:
|
1016
|
+
|
1017
|
+
# keys: iter or run ID, values: dict of param type and new parameter index
|
1018
|
+
iter_new_data_idx: dict[int, DataIndex] = defaultdict(dict)
|
1019
|
+
run_new_data_idx: dict[int, DataIndex] = defaultdict(dict)
|
1020
|
+
|
1021
|
+
param_sources = self.workflow.get_all_parameter_sources()
|
1022
|
+
|
1023
|
+
# keys are parameter type, then task insert ID, then data index keys mapping to
|
1024
|
+
# their updated values:
|
1025
|
+
all_updates: dict[str, dict[int, dict[int, int]]] = defaultdict(
|
1026
|
+
lambda: defaultdict(dict)
|
1027
|
+
)
|
1028
|
+
|
1029
|
+
for task in self.downstream_tasks:
|
1030
|
+
for elem in task.elements:
|
1031
|
+
for param_typ, param_out_task_iID in self.output_parameters.items():
|
1032
|
+
if param_typ in task.template.all_schema_input_types:
|
1033
|
+
# this element's input *might* need updating, only if it has a
|
1034
|
+
# task input source type that is this loop's output task for this
|
1035
|
+
# parameter:
|
1036
|
+
elem_src = elem.input_sources[f"inputs.{param_typ}"]
|
1037
|
+
if (
|
1038
|
+
elem_src.source_type is InputSourceType.TASK
|
1039
|
+
and elem_src.task_source_type is TaskSourceType.OUTPUT
|
1040
|
+
and elem_src.task_ref == param_out_task_iID
|
1041
|
+
):
|
1042
|
+
for iter_i in elem.iterations:
|
1043
|
+
|
1044
|
+
# do not modify element-iterations of previous iterations
|
1045
|
+
# of the current loop:
|
1046
|
+
skip_iter = False
|
1047
|
+
for k, v in parent_loop_indices.items():
|
1048
|
+
if iter_i.loop_idx.get(k) != v:
|
1049
|
+
skip_iter = True
|
1050
|
+
break
|
1051
|
+
|
1052
|
+
if skip_iter:
|
1053
|
+
continue
|
1054
|
+
+                                # update the iteration data index and any pending runs:
+                                iter_old_di = iter_i.data_idx[f"inputs.{param_typ}"]
+
+                                is_group = True
+                                if not isinstance(iter_old_di, list):
+                                    is_group = False
+                                    iter_old_di = [iter_old_di]
+
+                                iter_old_run_source = [
+                                    param_sources[i]["EAR_ID"] for i in iter_old_di
+                                ]
+                                iter_old_run_objs = self.workflow.get_EARs_from_IDs(
+                                    iter_old_run_source
+                                )  # TODO: use cache
+
+                                # need to check the run source is actually from the
+                                # loop output task (it could be from a previous
+                                # iteration of a separate loop in this task):
+                                if any(
+                                    i.task.insert_ID != param_out_task_iID
+                                    for i in iter_old_run_objs
+                                ):
+                                    continue
+
+                                iter_new_iters = [
+                                    i.element.iterations[-1] for i in iter_old_run_objs
+                                ]
+
+                                # note: we can cast to int, because output keys never
+                                # have multiple data indices (unlike input keys):
+                                iter_new_dis = [
+                                    cast("int", i.get_data_idx()[f"outputs.{param_typ}"])
+                                    for i in iter_new_iters
+                                ]
+
+                                # keep track of updates so we can also update task-input
+                                # type sources:
+                                all_updates[param_typ][task.insert_ID].update(
+                                    dict(zip(iter_old_di, iter_new_dis))
+                                )
+
+                                iter_new_data_idx[iter_i.id_][f"inputs.{param_typ}"] = (
+                                    iter_new_dis if is_group else iter_new_dis[0]
+                                )
+
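Two idioms above are worth isolating: normalising a scalar-or-group data index (wrap, compute, unwrap), and building a remap table with `dict(zip(old, new))`. A standalone sketch, with invented indices standing in for the real parameter lookups:

    old_di = 7  # a non-group data index; a group would be a list
    is_group = isinstance(old_di, list)
    old_list = old_di if is_group else [old_di]

    new_list = [i + 100 for i in old_list]  # stand-in for the new-output lookup
    remap = dict(zip(old_list, new_list))   # old index -> new index
    assert remap == {7: 107}

    new_di = new_list if is_group else new_list[0]  # unwrap if not a group
    assert new_di == 107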
+                                for run_j in iter_i.action_runs:
+                                    if run_j.status is EARStatus.pending:
+                                        try:
+                                            old_di = run_j.data_idx[f"inputs.{param_typ}"]
+                                        except KeyError:
+                                            # not all actions will include this input
+                                            continue
+
+                                        is_group = True
+                                        if not isinstance(old_di, list):
+                                            is_group = False
+                                            old_di = [old_di]
+
+                                        old_run_source = [
+                                            param_sources[i]["EAR_ID"] for i in old_di
+                                        ]
+                                        old_run_objs = self.workflow.get_EARs_from_IDs(
+                                            old_run_source
+                                        )  # TODO: use cache
+
+                                        # need to check the run source is actually from
+                                        # the loop output task (it could be from a
+                                        # previous action in this element-iteration):
+                                        if any(
+                                            i.task.insert_ID != param_out_task_iID
+                                            for i in old_run_objs
+                                        ):
+                                            continue
+
+                                        new_iters = [
+                                            i.element.iterations[-1] for i in old_run_objs
+                                        ]
+
+                                        # note: we can cast to int, because output keys
+                                        # never have multiple data indices (unlike input
+                                        # keys):
+                                        new_dis = [
+                                            cast(
+                                                "int",
+                                                i.get_data_idx()[f"outputs.{param_typ}"],
+                                            )
+                                            for i in new_iters
+                                        ]
+
+                                        run_new_data_idx[run_j.id_][
+                                            f"inputs.{param_typ}"
+                                        ] = (new_dis if is_group else new_dis[0])
+
+                    elif (
+                        elem_src.source_type is InputSourceType.TASK
+                        and elem_src.task_source_type is TaskSourceType.INPUT
+                    ):
+                        # parameters that are sourced from inputs of other tasks
+                        # might need to be updated if those other tasks have
+                        # themselves had their data indices updated:
+                        assert elem_src.task_ref
+                        ups_i = all_updates.get(param_typ, {}).get(elem_src.task_ref)
+
+                        if ups_i:
+                            # if a further-downstream task has a task-input source
+                            # that points to this task, this will also need updating:
+                            all_updates[param_typ][task.insert_ID].update(ups_i)
+                        else:
+                            continue
+
+                        for iter_i in elem.iterations:
+                            # update the iteration data index and any pending runs:
+                            iter_old_di = iter_i.data_idx[f"inputs.{param_typ}"]
+
+                            is_group = True
+                            if not isinstance(iter_old_di, list):
+                                is_group = False
+                                iter_old_di = [iter_old_di]
+
+                            iter_new_dis = [ups_i.get(i, i) for i in iter_old_di]
+
+                            if iter_new_dis != iter_old_di:
+                                iter_new_data_idx[iter_i.id_][
+                                    f"inputs.{param_typ}"
+                                ] = (iter_new_dis if is_group else iter_new_dis[0])
+
+                            for run_j in iter_i.action_runs:
+                                if run_j.status is EARStatus.pending:
+                                    try:
+                                        old_di = run_j.data_idx[f"inputs.{param_typ}"]
+                                    except KeyError:
+                                        # not all actions will include this input
+                                        continue
+
+                                    is_group = True
+                                    if not isinstance(old_di, list):
+                                        is_group = False
+                                        old_di = [old_di]

-
+                                    new_dis = [ups_i.get(i, i) for i in old_di]
+
+                                    if new_dis != old_di:
+                                        run_new_data_idx[run_j.id_][
+                                            f"inputs.{param_typ}"
+                                        ] = (new_dis if is_group else new_dis[0])
+
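Applying the remap uses `dict.get(i, i)`: replace an index when an update is known, otherwise keep it, and only record a change when something actually differs (values invented):

    ups_i = {10: 42}  # illustrative old -> new index updates
    old_di = [10, 11]
    new_dis = [ups_i.get(i, i) for i in old_di]  # 11 has no update; kept as-is
    assert new_dis == [42, 11]
    assert new_dis != old_di  # changed, so an update would be recorded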
+        # now update data indices (TODO: including in cache!)
+        if iter_new_data_idx:
+            self.workflow._store.update_iter_data_indices(iter_new_data_idx)
+
+        if run_new_data_idx:
+            self.workflow._store.update_run_data_indices(run_new_data_idx)
+
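Changes are accumulated into `iter_new_data_idx` and `run_new_data_idx` first and flushed in at most two store calls, rather than writing per iteration or per run. A minimal sketch of that accumulate-then-flush shape (`FakeStore` is a stand-in, not the real store API):

    from collections import defaultdict

    class FakeStore:
        # stand-in for the workflow store, for illustration only
        def update_iter_data_indices(self, updates):
            print(f"updating {len(updates)} iterations in one call")

    pending = defaultdict(dict)
    pending[0]["inputs.p1"] = 42
    pending[3]["inputs.p1"] = 43

    if pending:  # one batched write instead of one write per iteration
        FakeStore().update_iter_data_indices(pending)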
+    def test_termination(self, element_iter) -> bool:
         """Check if a loop should terminate, given the specified completed element
         iteration."""
         if self.template.termination:
             return self.template.termination.test(element_iter)
         return False
+
+    @TimeIt.decorator
+    def get_element_IDs(self):
+        elem_IDs = [
+            j
+            for i in self.task_insert_IDs
+            for j in self.workflow.tasks.get(insert_ID=i).element_IDs
+        ]
+        return elem_IDs
+
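`get_element_IDs` flattens the per-task ID lists with a nested comprehension; equivalently (with invented IDs):

    per_task_element_IDs = [[0, 1], [2], [3, 4]]  # illustrative
    assert [j for ids in per_task_element_IDs for j in ids] == [0, 1, 2, 3, 4]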
+    @TimeIt.decorator
+    def get_elements(self):
+        return self.workflow.get_elements_from_IDs(self.get_element_IDs())
+
+    @TimeIt.decorator
+    def skip_downstream_iterations(self, elem_iter) -> list[int]:
+        """
+        Parameters
+        ----------
+        elem_iter
+            The element iteration whose subsequent iterations should be skipped.
+        """
+        current_iter_idx = elem_iter.loop_idx[self.name]
+        current_task_iID = elem_iter.task.insert_ID
+        self._app.logger.info(
+            f"setting loop {self.name!r} iterations downstream of current iteration "
+            f"index {current_iter_idx} to skip"
+        )
+        elements = self.get_elements()
+
+        # TODO: fix for multiple loop cycles
+        warn(
+            "skip downstream iterations does not work correctly for multiple loop cycles!"
+        )
+
+        to_skip = []
+        for elem in elements:
+            for iter_i in elem.iterations:
+                if iter_i.loop_idx[self.name] > current_iter_idx or (
+                    iter_i.loop_idx[self.name] == current_iter_idx
+                    and iter_i.task.insert_ID > current_task_iID
+                ):
+                    to_skip.extend(iter_i.EAR_IDs_flat)
+
+        self._app.logger.info(
+            f"{len(to_skip)} runs will be set to skip: {shorten_list_str(to_skip)}"
+        )
+        self.workflow.set_EAR_skip({k: SkipReason.LOOP_TERMINATION for k in to_skip})
+
+        return to_skip
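The "downstream" test in `skip_downstream_iterations` is lexicographic ordering on `(iteration index, task insert ID)`, so it can be read as a tuple comparison (the values and the `is_downstream` helper below are invented for illustration):

    current = (2, 5)  # (current iteration index, current task insert ID)

    def is_downstream(iter_idx: int, task_insert_ID: int) -> bool:
        # tuple comparison: later iteration, or same iteration but a later task
        return (iter_idx, task_insert_ID) > current

    assert is_downstream(3, 0)      # later iteration
    assert is_downstream(2, 6)      # same iteration, later task
    assert not is_downstream(2, 5)  # the current iteration itself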