hpcflow-new2 0.2.0a190__py3-none-any.whl → 0.2.0a200__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- hpcflow/__pyinstaller/hook-hpcflow.py +1 -0
- hpcflow/_version.py +1 -1
- hpcflow/data/scripts/bad_script.py +2 -0
- hpcflow/data/scripts/do_nothing.py +2 -0
- hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
- hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
- hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
- hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
- hpcflow/data/scripts/input_file_generator_basic.py +3 -0
- hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
- hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
- hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
- hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
- hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
- hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
- hpcflow/data/scripts/output_file_parser_basic.py +3 -0
- hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
- hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
- hpcflow/data/scripts/script_exit_test.py +5 -0
- hpcflow/data/template_components/environments.yaml +1 -1
- hpcflow/sdk/__init__.py +5 -0
- hpcflow/sdk/app.py +166 -92
- hpcflow/sdk/cli.py +263 -84
- hpcflow/sdk/cli_common.py +99 -5
- hpcflow/sdk/config/callbacks.py +38 -1
- hpcflow/sdk/config/config.py +102 -13
- hpcflow/sdk/config/errors.py +19 -5
- hpcflow/sdk/config/types.py +3 -0
- hpcflow/sdk/core/__init__.py +25 -1
- hpcflow/sdk/core/actions.py +914 -262
- hpcflow/sdk/core/cache.py +76 -34
- hpcflow/sdk/core/command_files.py +14 -128
- hpcflow/sdk/core/commands.py +35 -6
- hpcflow/sdk/core/element.py +122 -50
- hpcflow/sdk/core/errors.py +58 -2
- hpcflow/sdk/core/execute.py +207 -0
- hpcflow/sdk/core/loop.py +408 -50
- hpcflow/sdk/core/loop_cache.py +4 -4
- hpcflow/sdk/core/parameters.py +382 -37
- hpcflow/sdk/core/run_dir_files.py +13 -40
- hpcflow/sdk/core/skip_reason.py +7 -0
- hpcflow/sdk/core/task.py +119 -30
- hpcflow/sdk/core/task_schema.py +68 -0
- hpcflow/sdk/core/test_utils.py +66 -27
- hpcflow/sdk/core/types.py +54 -1
- hpcflow/sdk/core/utils.py +136 -19
- hpcflow/sdk/core/workflow.py +1587 -356
- hpcflow/sdk/data/workflow_spec_schema.yaml +2 -0
- hpcflow/sdk/demo/cli.py +7 -0
- hpcflow/sdk/helper/cli.py +1 -0
- hpcflow/sdk/log.py +42 -15
- hpcflow/sdk/persistence/base.py +405 -53
- hpcflow/sdk/persistence/json.py +177 -52
- hpcflow/sdk/persistence/pending.py +237 -69
- hpcflow/sdk/persistence/store_resource.py +3 -2
- hpcflow/sdk/persistence/types.py +15 -4
- hpcflow/sdk/persistence/zarr.py +928 -81
- hpcflow/sdk/submission/jobscript.py +1408 -489
- hpcflow/sdk/submission/schedulers/__init__.py +40 -5
- hpcflow/sdk/submission/schedulers/direct.py +33 -19
- hpcflow/sdk/submission/schedulers/sge.py +51 -16
- hpcflow/sdk/submission/schedulers/slurm.py +44 -16
- hpcflow/sdk/submission/schedulers/utils.py +7 -2
- hpcflow/sdk/submission/shells/base.py +68 -20
- hpcflow/sdk/submission/shells/bash.py +222 -129
- hpcflow/sdk/submission/shells/powershell.py +200 -150
- hpcflow/sdk/submission/submission.py +852 -119
- hpcflow/sdk/submission/types.py +18 -21
- hpcflow/sdk/typing.py +24 -5
- hpcflow/sdk/utils/arrays.py +71 -0
- hpcflow/sdk/utils/deferred_file.py +55 -0
- hpcflow/sdk/utils/hashing.py +16 -0
- hpcflow/sdk/utils/patches.py +12 -0
- hpcflow/sdk/utils/strings.py +33 -0
- hpcflow/tests/api/test_api.py +32 -0
- hpcflow/tests/conftest.py +19 -0
- hpcflow/tests/data/benchmark_script_runner.yaml +26 -0
- hpcflow/tests/data/multi_path_sequences.yaml +29 -0
- hpcflow/tests/data/workflow_test_run_abort.yaml +34 -35
- hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
- hpcflow/tests/scripts/test_input_file_generators.py +282 -0
- hpcflow/tests/scripts/test_main_scripts.py +821 -70
- hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
- hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
- hpcflow/tests/shells/wsl/test_wsl_submission.py +6 -0
- hpcflow/tests/unit/test_action.py +176 -0
- hpcflow/tests/unit/test_app.py +20 -0
- hpcflow/tests/unit/test_cache.py +46 -0
- hpcflow/tests/unit/test_cli.py +133 -0
- hpcflow/tests/unit/test_config.py +122 -1
- hpcflow/tests/unit/test_element_iteration.py +47 -0
- hpcflow/tests/unit/test_jobscript_unit.py +757 -0
- hpcflow/tests/unit/test_loop.py +1332 -27
- hpcflow/tests/unit/test_meta_task.py +325 -0
- hpcflow/tests/unit/test_multi_path_sequences.py +229 -0
- hpcflow/tests/unit/test_parameter.py +13 -0
- hpcflow/tests/unit/test_persistence.py +190 -8
- hpcflow/tests/unit/test_run.py +109 -3
- hpcflow/tests/unit/test_run_directories.py +29 -0
- hpcflow/tests/unit/test_shell.py +20 -0
- hpcflow/tests/unit/test_submission.py +5 -76
- hpcflow/tests/unit/test_workflow_template.py +31 -0
- hpcflow/tests/unit/utils/test_arrays.py +40 -0
- hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
- hpcflow/tests/unit/utils/test_hashing.py +65 -0
- hpcflow/tests/unit/utils/test_patches.py +5 -0
- hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
- hpcflow/tests/workflows/__init__.py +0 -0
- hpcflow/tests/workflows/test_directory_structure.py +31 -0
- hpcflow/tests/workflows/test_jobscript.py +332 -0
- hpcflow/tests/workflows/test_run_status.py +198 -0
- hpcflow/tests/workflows/test_skip_downstream.py +696 -0
- hpcflow/tests/workflows/test_submission.py +140 -0
- hpcflow/tests/workflows/test_workflows.py +142 -2
- hpcflow/tests/workflows/test_zip.py +18 -0
- hpcflow/viz_demo.ipynb +6587 -3
- {hpcflow_new2-0.2.0a190.dist-info → hpcflow_new2-0.2.0a200.dist-info}/METADATA +7 -4
- hpcflow_new2-0.2.0a200.dist-info/RECORD +222 -0
- hpcflow_new2-0.2.0a190.dist-info/RECORD +0 -165
- {hpcflow_new2-0.2.0a190.dist-info → hpcflow_new2-0.2.0a200.dist-info}/LICENSE +0 -0
- {hpcflow_new2-0.2.0a190.dist-info → hpcflow_new2-0.2.0a200.dist-info}/WHEEL +0 -0
- {hpcflow_new2-0.2.0a190.dist-info → hpcflow_new2-0.2.0a200.dist-info}/entry_points.txt +0 -0
hpcflow/sdk/core/loop.py
CHANGED
```diff
@@ -6,24 +6,33 @@ notably looping over a set of values or until a condition holds.
 
 from __future__ import annotations
 
+from collections import defaultdict
 import copy
+from pprint import pp
+import pprint
+from typing import Dict, List, Optional, Tuple, Union, Any
+from warnings import warn
 from collections import defaultdict
 from itertools import chain
-from typing import TYPE_CHECKING
+from typing import cast, TYPE_CHECKING
 from typing_extensions import override
 
 from hpcflow.sdk.core.app_aware import AppAware
+from hpcflow.sdk.core.actions import EARStatus
+from hpcflow.sdk.core.skip_reason import SkipReason
 from hpcflow.sdk.core.errors import LoopTaskSubsetError
 from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
 from hpcflow.sdk.core.loop_cache import LoopCache, LoopIndex
-from hpcflow.sdk.core.enums import InputSourceType
+from hpcflow.sdk.core.enums import InputSourceType, TaskSourceType
 from hpcflow.sdk.core.utils import check_valid_py_identifier, nth_key, nth_value
+from hpcflow.sdk.utils.strings import shorten_list_str
 from hpcflow.sdk.log import TimeIt
 
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator, Mapping, Sequence
     from typing import Any, ClassVar
     from typing_extensions import Self, TypeIs
+    from rich.status import Status
     from ..typing import DataIndex, ParamSource
     from .parameters import SchemaInput, InputSource
     from .rule import Rule
@@ -62,6 +71,8 @@ class Loop(JSONLike):
         Specify input parameters that should not iterate.
     termination: ~hpcflow.app.Rule
         Stopping criterion, expressed as a rule.
+    termination_task: int | ~hpcflow.app.WorkflowTask
+        Task at which to evaluate the termination condition.
     """
 
     _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
@@ -79,6 +90,7 @@ class Loop(JSONLike):
         name: str | None = None,
        non_iterable_parameters: list[str] | None = None,
         termination: Rule | None = None,
+        termination_task: int | WorkflowTask | None = None,
     ) -> None:
         _task_insert_IDs: list[int] = []
         for task in tasks:
@@ -89,14 +101,34 @@ class Loop(JSONLike):
             else:
                 raise TypeError(
                     f"`tasks` must be a list whose elements are either task insert IDs "
-                    f"or WorkflowTask objects, but received the following: {tasks!r}."
+                    f"or `WorkflowTask` objects, but received the following: {tasks!r}."
                 )
 
+        if termination_task is None:
+            _term_task_iID = _task_insert_IDs[-1]  # terminate on final task by default
+        elif self.__is_WorkflowTask(termination_task):
+            _term_task_iID = termination_task.insert_ID
+        elif isinstance(termination_task, int):
+            _term_task_iID = termination_task
+        else:
+            raise TypeError(
+                f"`termination_task` must be a task insert ID or a `WorkflowTask` "
+                f"object, but received the following: {termination_task!r}."
+            )
+
+        if _term_task_iID not in _task_insert_IDs:
+            raise ValueError(
+                f"If specified, `termination_task` (provided: {termination_task!r}) must "
+                f"refer to a task that is part of the loop. Available task insert IDs "
+                f"are: {_task_insert_IDs!r}."
+            )
+
         self._task_insert_IDs = _task_insert_IDs
         self._num_iterations = num_iterations
         self._name = check_valid_py_identifier(name) if name else name
         self._non_iterable_parameters = non_iterable_parameters or []
         self._termination = termination
+        self._termination_task_insert_ID = _term_task_iID
 
         self._workflow_template: WorkflowTemplate | None = (
             None  # assigned by parent WorkflowTemplate
@@ -114,7 +146,15 @@ class Loop(JSONLike):
             insert_IDs = json_like.pop("task_insert_IDs")
         else:
             insert_IDs = json_like.pop("tasks")
-
+
+        if "termination_task_insert_ID" in json_like:
+            tt_iID = json_like.pop("termination_task_insert_ID")
+        elif "termination_task" in json_like:
+            tt_iID = json_like.pop("termination_task")
+        else:
+            tt_iID = None
+
+        return cls(tasks=insert_IDs, termination_task=tt_iID, **json_like)
 
     @property
     def task_insert_IDs(self) -> tuple[int, ...]:
```
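For reference, the two serialised forms accepted by `from_json_like` above can be written out as plain dictionaries. This is a sketch only: the key names come from the hunk above, while the values are hypothetical.

```python
# Persistent (stored) form of a loop:
loop_json_persistent = {
    "task_insert_IDs": [0, 1, 2],
    "termination_task_insert_ID": 1,
    "num_iterations": 5,
}
# Template (user-facing) form; "termination_task" may be omitted entirely,
# in which case the loop terminates on its final task:
loop_json_template = {
    "tasks": [0, 1, 2],
    "termination_task": 1,
    "num_iterations": 5,
}
```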
```diff
@@ -149,6 +189,25 @@ class Loop(JSONLike):
         """
         return self._termination
 
+    @property
+    def termination_task_insert_ID(self) -> int:
+        """
+        The insert ID of the task at which the loop will terminate.
+        """
+        return self._termination_task_insert_ID
+
+    @property
+    def termination_task(self) -> WorkflowTask:
+        """
+        The task at which the loop will terminate.
+        """
+        if (wt := self.workflow_template) is None:
+            raise RuntimeError(
+                "Workflow template must be assigned to retrieve task objects of the loop."
+            )
+        assert wt.workflow
+        return wt.workflow.tasks.get(insert_ID=self.termination_task_insert_ID)
+
     @property
     def workflow_template(self) -> WorkflowTemplate | None:
         """
```
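Taken together, these additions let the termination rule be evaluated at an intermediate loop task rather than always at the loop's final task. A minimal sketch of how this might look when building a loop, assuming the usual `hpcflow.app` entry point `hf`; the rule path and condition here are hypothetical:

```python
import hpcflow.app as hf

# Hypothetical three-task loop: evaluate the stopping rule at the middle
# task (insert ID 1) rather than at the final task in the loop.
loop = hf.Loop(
    name="fitting_loop",
    tasks=[0, 1, 2],             # task insert IDs within the loop
    num_iterations=10,           # upper bound on iterations
    termination=hf.Rule(         # hypothetical convergence criterion
        path="outputs.converged",
        condition={"value.equal_to": True},
    ),
    termination_task=1,          # check the rule after task 1 each iteration
)
assert loop.termination_task_insert_ID == 1
```

Per the constructor hunk above, passing a `termination_task` that is not one of the loop's tasks raises a `ValueError`, and omitting it defaults to the loop's final task.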
```diff
@@ -212,6 +271,7 @@ class Loop(JSONLike):
     def __deepcopy__(self, memo: dict[int, Any]) -> Self:
         kwargs = self.to_dict()
         kwargs["tasks"] = kwargs.pop("task_insert_IDs")
+        kwargs["termination_task"] = kwargs.pop("termination_task_insert_ID")
         obj = self.__class__(**copy.deepcopy(kwargs, memo))
         obj._workflow_template = self._workflow_template
         return obj
@@ -234,6 +294,9 @@ class WorkflowLoop(AppAware):
         Description of what iterations have been added.
     iterable_parameters:
         Description of what parameters are being iterated over.
+    output_parameters:
+        Description of what parameters are output from this loop, and the final task
+        insert ID from which they are output.
     parents: list[str]
         The paths to the parent entities of this loop.
     """
@@ -245,6 +308,7 @@ class WorkflowLoop(AppAware):
         template: Loop,
         num_added_iterations: dict[tuple[int, ...], int],
         iterable_parameters: dict[str, IterableParam],
+        output_parameters: dict[str, int],
         parents: list[str],
     ) -> None:
         self._index = index
@@ -252,10 +316,11 @@ class WorkflowLoop(AppAware):
         self._template = template
         self._num_added_iterations = num_added_iterations
         self._iterable_parameters = iterable_parameters
+        self._output_parameters = output_parameters
         self._parents = parents
 
-        # appended to
-        # reset and added to `self._parents` on dump to disk:
+        # appended to when adding an empty loop to the workflow that is a parent of this
+        # loop; reset and added to `self._parents` on dump to disk:
         self._pending_parents: list[str] = []
 
         # used for `num_added_iterations` when a new loop iteration is added, or when
@@ -273,16 +338,6 @@ class WorkflowLoop(AppAware):
         if task_indices != tuple(range(task_min, task_max + 1)):
             raise LoopTaskSubsetError(self.name, self.task_indices)
 
-        for task in self.downstream_tasks:
-            for param in self.iterable_parameters:
-                if param in task.template.all_schema_input_types:
-                    raise NotImplementedError(
-                        f"Downstream task {task.unique_name!r} of loop {self.name!r} "
-                        f"has as one of its input parameters this loop's iterable "
-                        f"parameter {param!r}. This parameter cannot be sourced "
-                        f"correctly."
-                    )
-
     def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__}(template={self.template!r}, "
@@ -404,12 +459,20 @@ class WorkflowLoop(AppAware):
         return self.template.name
 
     @property
-    def iterable_parameters(self) ->
+    def iterable_parameters(self) -> dict[str, IterableParam]:
         """
         The parameters that are being iterated over.
         """
         return self._iterable_parameters
 
+    @property
+    def output_parameters(self) -> dict[str, int]:
+        """
+        The parameters that are outputs of this loop, and the final task insert ID from
+        which each parameter is output.
+        """
+        return self._output_parameters
+
     @property
     def num_iterations(self) -> int:
         """
@@ -433,7 +496,9 @@ class WorkflowLoop(AppAware):
 
     @staticmethod
     @TimeIt.decorator
-    def
+    def _find_iterable_and_output_parameters(
+        loop_template: Loop,
+    ) -> tuple[dict[str, IterableParam], dict[str, int]]:
         all_inputs_first_idx: dict[str, int] = {}
         all_outputs_idx: dict[str, list[int]] = defaultdict(list)
         for task in loop_template.task_objects:
@@ -442,6 +507,7 @@ class WorkflowLoop(AppAware):
             for typ in task.template.all_schema_output_types:
                 all_outputs_idx[typ].append(task.insert_ID)
 
+        # find input parameters that are also output parameters at a later/same task:
         iterable_params: dict[str, IterableParam] = {}
         for typ, first_idx in all_inputs_first_idx.items():
             if typ in all_outputs_idx and first_idx <= all_outputs_idx[typ][0]:
@@ -453,7 +519,9 @@ class WorkflowLoop(AppAware):
         for non_iter in loop_template.non_iterable_parameters:
             iterable_params.pop(non_iter, None)
 
-
+        final_out_tasks = {k: v[-1] for k, v in all_outputs_idx.items()}
+
+        return iterable_params, final_out_tasks
 
     @classmethod
     @TimeIt.decorator
```
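The discovery logic in `_find_iterable_and_output_parameters` reduces to two dictionary passes. A self-contained sketch with hypothetical schema data follows; the real method works on schema input/output types and returns `IterableParam` records (and also drops any `non_iterable_parameters`), but the core classification is this:

```python
from collections import defaultdict

# Hypothetical schema input/output types per loop task, keyed by insert ID:
task_inputs = {0: ["p1"], 1: ["p2"], 2: ["p2", "p3"]}
task_outputs = {0: ["p2"], 1: ["p3"], 2: ["p1"]}

all_inputs_first_idx: dict[str, int] = {}
all_outputs_idx: dict[str, list[int]] = defaultdict(list)
for t_id in sorted(task_inputs):
    for typ in task_inputs[t_id]:
        all_inputs_first_idx.setdefault(typ, t_id)
    for typ in task_outputs[t_id]:
        all_outputs_idx[typ].append(t_id)

# iterable: consumed at or before the first task that produces it, so each new
# iteration can be re-parametrised from the previous iteration's output:
iterable = {
    typ: first_idx
    for typ, first_idx in all_inputs_first_idx.items()
    if typ in all_outputs_idx and first_idx <= all_outputs_idx[typ][0]
}
# output parameters: each mapped to the *final* loop task that produces it:
final_out_tasks = {typ: idxs[-1] for typ, idxs in all_outputs_idx.items()}

print(iterable)         # {'p1': 0}
print(final_out_tasks)  # {'p2': 0, 'p3': 1, 'p1': 2}
```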
```diff
@@ -478,21 +546,20 @@ class WorkflowLoop(AppAware):
         iter_loop_idx: list[dict]
             Iteration information from parent loops.
         """
-
-
-
-
-
-        num_added_iters = {
-            tuple(l_idx[nm] for nm in parent_names): 1 for l_idx in iter_loop_idx
-        }
+        parent_loops = cls._get_parent_loops(index, workflow, template)
+        parent_names = [i.name for i in parent_loops]
+        num_added_iters: dict[tuple[int, ...], int] = {}
+        for i in iter_loop_idx:
+            num_added_iters[tuple([i[j] for j in parent_names])] = 1
 
+        iter_params, out_params = cls._find_iterable_and_output_parameters(template)
         return cls(
             index=index,
             workflow=workflow,
             template=template,
             num_added_iterations=num_added_iters,
-            iterable_parameters=
+            iterable_parameters=iter_params,
+            output_parameters=out_params,
             parents=parent_names,
         )
 
```
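The `num_added_iterations` mapping initialised here is keyed by tuples of parent-loop iteration indices, one component per parent, in parent order. A minimal sketch with hypothetical parent loops:

```python
# Hypothetical: this loop is nested inside "outer" and then "middle":
parent_names = ["outer", "middle"]
iter_loop_idx = [
    {"outer": 0, "middle": 0},
    {"outer": 0, "middle": 1},
    {"outer": 1, "middle": 0},
]
num_added_iters: dict[tuple[int, ...], int] = {}
for i in iter_loop_idx:
    num_added_iters[tuple(i[j] for j in parent_names)] = 1

print(num_added_iters)  # {(0, 0): 1, (0, 1): 1, (1, 0): 1}
```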
```diff
@@ -551,6 +618,7 @@ class WorkflowLoop(AppAware):
         self,
         parent_loop_indices: Mapping[str, int] | None = None,
         cache: LoopCache | None = None,
+        status: Status | None = None,
     ) -> None:
         """
         Add an iteration to this loop.
```
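The new `status` argument is an optional `rich` status spinner that `add_iteration` updates as fixed-count child-loop iterations are added (see the hunk for line 798 below). A sketch of how a caller might supply one; the `rich` usage is standard, while `wk_loop` is an assumed `WorkflowLoop` instance:

```python
from rich.console import Console

def add_iterations_with_status(wk_loop, n: int) -> None:
    """Sketch: drive `add_iteration` under a rich status spinner."""
    console = Console()
    with console.status(f"Adding {n} loop iterations...") as status:
        for _ in range(n):
            wk_loop.add_iteration(status=status)
```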
```diff
@@ -588,6 +656,13 @@ class WorkflowLoop(AppAware):
                     iters_key_dct.get(j, 0) for j in child.parents
                 )
 
+                # needed for the case where an inner loop has only one iteration, meaning
+                # `add_iteration` will not be called recursively on it:
+                self.workflow._store.update_loop_num_iters(
+                    index=child.index,
+                    num_added_iters=child.num_added_iterations,
+                )
+
         for task in self.task_objects:
             new_loop_idx = LoopIndex(iters_key_dct) + {
                 child.name: 0
@@ -723,21 +798,40 @@ class WorkflowLoop(AppAware):
         # add iterations to fixed-number-iteration children only:
         for child in child_loops[::-1]:
             if child.num_iterations is not None:
-
+                if status:
+                    status_prev = str(status.status).rstrip(".")
+                for iter_idx in range(child.num_iterations - 1):
+                    if status:
+                        status.update(
+                            f"{status_prev} --> ({child.name!r}): iteration "
+                            f"{iter_idx + 2}/{child.num_iterations}."
+                        )
                     par_idx = {parent_name: 0 for parent_name in child.parents}
                     if parent_loop_indices:
                         par_idx.update(parent_loop_indices)
                     par_idx[self.name] = cur_loop_idx + 1
                     child.add_iteration(parent_loop_indices=par_idx, cache=cache)
 
+        self.__update_loop_downstream_data_idx(parent_loop_indices_)
+
     def __get_src_ID_and_groups(
-        self,
+        self,
+        elem_ID: int,
+        iter_dat: IterableParam,
+        inp: SchemaInput,
+        cache: LoopCache,
+        task: WorkflowTask,
     ) -> tuple[int, Sequence[int]]:
-
-
-
-
-
+        # `cache.elements` contains only elements that are part of the
+        # loop, so indexing a dependent element may raise:
+        src_elem_IDs = {}
+        for k, v in cache.element_dependents[elem_ID].items():
+            try:
+                if cache.elements[k]["task_insert_ID"] == iter_dat["output_tasks"][-1]:
+                    src_elem_IDs[k] = v
+            except KeyError:
+                continue
+
         # consider groups
         single_data = inp.single_labelled_data
         assert single_data is not None
```
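The guarded lookup added to `__get_src_ID_and_groups` (and repeated in the hunk for line 872 below) skips dependents that are not loop members, since `cache.elements` holds only elements inside the loop. In sketch form, with hypothetical cache fragments:

```python
# Hypothetical cache fragments; element 9 is outside the loop:
elements = {3: {"task_insert_ID": 1}, 4: {"task_insert_ID": 2}}
element_dependents = {0: {3: None, 4: None, 9: None}}
final_output_task = 1  # iter_dat["output_tasks"][-1] in the real code

src_elem_IDs = {}
for k, v in element_dependents[0].items():
    try:
        if elements[k]["task_insert_ID"] == final_output_task:
            src_elem_IDs[k] = v
    except KeyError:
        continue  # dependent element is not part of the loop

print(list(src_elem_IDs))  # [3]
```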
```diff
@@ -753,7 +847,8 @@ class WorkflowLoop(AppAware):
                         f"Multiple elements found in the iterable parameter "
                         f"{inp!r}'s latest output task (insert ID: "
                         f"{iter_dat['output_tasks'][-1]}) that can be used "
-                        f"to parametrise the next iteration
+                        f"to parametrise the next iteration of task "
+                        f"{task.unique_name!r}: "
                         f"{list(src_elem_IDs)!r}."
                     )
 
@@ -787,11 +882,13 @@ class WorkflowLoop(AppAware):
                 # identify element(s) from which this iterable input should be
                 # parametrised:
                 if task.insert_ID == iter_dat["output_tasks"][-1]:
+                    # single-task loop
                     src_elem_ID = elem_ID
                     grouped_elems: Sequence[int] = []
                 else:
+                    # multi-task loop
                     src_elem_ID, grouped_elems = self.__get_src_ID_and_groups(
-                        elem_ID, iter_dat, inp, cache
+                        elem_ID, iter_dat, inp, cache, task
                     )
 
                 child_loop_max_iters: dict[str, int] = {}
@@ -803,12 +900,13 @@ class WorkflowLoop(AppAware):
                     self.name: cur_loop_idx,
                 }
                 for loop in child_loops:
-
-
-
-
-
-
+                    if iter_dat["output_tasks"][-1] in loop.task_insert_IDs:
+                        i_num_iters = loop.num_added_iterations[
+                            tuple(child_iter_parents[j] for j in loop.parents)
+                        ]
+                        i_max = i_num_iters - 1
+                        child_iter_parents[loop.name] = i_max
+                        child_loop_max_iters[loop.name] = i_max
 
                 loop_idx_key = LoopIndex(child_loop_max_iters)
                 loop_idx_key.update(parent_loop_same_iters)
@@ -872,13 +970,20 @@ class WorkflowLoop(AppAware):
                     src_elem_IDs = cache.element_dependents[
                         self.workflow.tasks.get(insert_ID=tiID).element_IDs[e_idx]
                     ]
-                    #
-
-
-
-
-
-
+                    # `cache.elements` contains only elements that are part of the loop, so
+                    # indexing a dependent element may raise:
+                    src_elem_IDs_i = []
+                    for k, _v in src_elem_IDs.items():
+                        try:
+                            if (
+                                cache.elements[k]["task_insert_ID"] == task.insert_ID
+                                and k == elem_ID
+                                # filter src_elem_IDs_i for matching element IDs
+                            ):
+
+                                src_elem_IDs_i.append(k)
+                        except KeyError:
+                            continue
 
                     if len(src_elem_IDs_i) == 1:
                         new_sources.append((tiID, e_idx))
@@ -902,9 +1007,262 @@ class WorkflowLoop(AppAware):
         else:
             yield from seq
 
+    def __update_loop_downstream_data_idx(
+        self,
+        parent_loop_indices: Mapping[str, int],
+    ):
+        # update data indices of loop-downstream tasks that depend on task outputs from
+        # this loop:
+
+        # keys: iter or run ID, values: dict of param type and new parameter index
+        iter_new_data_idx: dict[int, DataIndex] = defaultdict(dict)
+        run_new_data_idx: dict[int, DataIndex] = defaultdict(dict)
+
+        param_sources = self.workflow.get_all_parameter_sources()
+
+        # keys are parameter type, then task insert ID, then data index keys mapping to
+        # their updated values:
+        all_updates: dict[str, dict[int, dict[int, int]]] = defaultdict(
+            lambda: defaultdict(dict)
+        )
+
+        for task in self.downstream_tasks:
+            for elem in task.elements:
+                for param_typ, param_out_task_iID in self.output_parameters.items():
+                    if param_typ in task.template.all_schema_input_types:
+                        # this element's input *might* need updating, only if it has a
+                        # task input source type that is this loop's output task for this
+                        # parameter:
+                        elem_src = elem.input_sources[f"inputs.{param_typ}"]
+                        if (
+                            elem_src.source_type is InputSourceType.TASK
+                            and elem_src.task_source_type is TaskSourceType.OUTPUT
+                            and elem_src.task_ref == param_out_task_iID
+                        ):
+                            for iter_i in elem.iterations:
+
+                                # do not modify element-iterations of previous iterations
+                                # of the current loop:
+                                skip_iter = False
+                                for k, v in parent_loop_indices.items():
+                                    if iter_i.loop_idx.get(k) != v:
+                                        skip_iter = True
+                                        break
+
+                                if skip_iter:
+                                    continue
+
+                                # update the iteration data index and any pending runs:
+                                iter_old_di = iter_i.data_idx[f"inputs.{param_typ}"]
+
+                                is_group = True
+                                if not isinstance(iter_old_di, list):
+                                    is_group = False
+                                    iter_old_di = [iter_old_di]
+
+                                iter_old_run_source = [
+                                    param_sources[i]["EAR_ID"] for i in iter_old_di
+                                ]
+                                iter_old_run_objs = self.workflow.get_EARs_from_IDs(
+                                    iter_old_run_source
+                                )  # TODO: use cache
+
+                                # need to check the run source is actually from the loop
+                                # output task (it could be from a previous iteration of a
+                                # separate loop in this task):
+                                if any(
+                                    i.task.insert_ID != param_out_task_iID
+                                    for i in iter_old_run_objs
+                                ):
+                                    continue
+
+                                iter_new_iters = [
+                                    i.element.iterations[-1] for i in iter_old_run_objs
+                                ]
+
+                                # note: we can cast to int, because output keys never
+                                # have multiple data indices (unlike input keys):
+                                iter_new_dis = [
+                                    cast("int", i.get_data_idx()[f"outputs.{param_typ}"])
+                                    for i in iter_new_iters
+                                ]
+
+                                # keep track of updates so we can also update task-input
+                                # type sources:
+                                all_updates[param_typ][task.insert_ID].update(
+                                    dict(zip(iter_old_di, iter_new_dis))
+                                )
+
+                                iter_new_data_idx[iter_i.id_][f"inputs.{param_typ}"] = (
+                                    iter_new_dis if is_group else iter_new_dis[0]
+                                )
+
+                                for run_j in iter_i.action_runs:
+                                    if run_j.status is EARStatus.pending:
+                                        try:
+                                            old_di = run_j.data_idx[f"inputs.{param_typ}"]
+                                        except KeyError:
+                                            # not all actions will include this input
+                                            continue
+
+                                        is_group = True
+                                        if not isinstance(old_di, list):
+                                            is_group = False
+                                            old_di = [old_di]
+
+                                        old_run_source = [
+                                            param_sources[i]["EAR_ID"] for i in old_di
+                                        ]
+                                        old_run_objs = self.workflow.get_EARs_from_IDs(
+                                            old_run_source
+                                        )  # TODO: use cache
+
+                                        # need to check the run source is actually from the loop
+                                        # output task (it could be from a previous action in this
+                                        # element-iteration):
+                                        if any(
+                                            i.task.insert_ID != param_out_task_iID
+                                            for i in old_run_objs
+                                        ):
+                                            continue
+
+                                        new_iters = [
+                                            i.element.iterations[-1] for i in old_run_objs
+                                        ]
+
+                                        # note: we can cast to int, because output keys
+                                        # never have multiple data indices (unlike input
+                                        # keys):
+                                        new_dis = [
+                                            cast(
+                                                "int",
+                                                i.get_data_idx()[f"outputs.{param_typ}"],
+                                            )
+                                            for i in new_iters
+                                        ]
+
+                                        run_new_data_idx[run_j.id_][
+                                            f"inputs.{param_typ}"
+                                        ] = (new_dis if is_group else new_dis[0])
+
+                        elif (
+                            elem_src.source_type is InputSourceType.TASK
+                            and elem_src.task_source_type is TaskSourceType.INPUT
+                        ):
+                            # parameters that are sourced from inputs of other tasks
+                            # might need to be updated if those other tasks have
+                            # themselves had their data indices updated:
+                            assert elem_src.task_ref
+                            ups_i = all_updates.get(param_typ, {}).get(elem_src.task_ref)
+
+                            if ups_i:
+                                # if a further-downstream task has a task-input source
+                                # that points to this task, this will also need updating:
+                                all_updates[param_typ][task.insert_ID].update(ups_i)
+
+                            else:
+                                continue
+
+                            for iter_i in elem.iterations:
+
+                                # update the iteration data index and any pending runs:
+                                iter_old_di = iter_i.data_idx[f"inputs.{param_typ}"]
+
+                                is_group = True
+                                if not isinstance(iter_old_di, list):
+                                    is_group = False
+                                    iter_old_di = [iter_old_di]
+
+                                iter_new_dis = [ups_i.get(i, i) for i in iter_old_di]
+
+                                if iter_new_dis != iter_old_di:
+                                    iter_new_data_idx[iter_i.id_][
+                                        f"inputs.{param_typ}"
+                                    ] = (iter_new_dis if is_group else iter_new_dis[0])
+
+                                for run_j in iter_i.action_runs:
+                                    if run_j.status is EARStatus.pending:
+                                        try:
+                                            old_di = run_j.data_idx[f"inputs.{param_typ}"]
+                                        except KeyError:
+                                            # not all actions will include this input
+                                            continue
+
+                                        is_group = True
+                                        if not isinstance(old_di, list):
+                                            is_group = False
+                                            old_di = [old_di]
+
+                                        new_dis = [ups_i.get(i, i) for i in old_di]
+
+                                        if new_dis != old_di:
+                                            run_new_data_idx[run_j.id_][
+                                                f"inputs.{param_typ}"
+                                            ] = (new_dis if is_group else new_dis[0])
+
+        # now update data indices (TODO: including in cache!)
+        if iter_new_data_idx:
+            self.workflow._store.update_iter_data_indices(iter_new_data_idx)
+
+        if run_new_data_idx:
+            self.workflow._store.update_run_data_indices(run_new_data_idx)
+
     def test_termination(self, element_iter) -> bool:
         """Check if a loop should terminate, given the specified completed element
         iteration."""
         if self.template.termination:
             return self.template.termination.test(element_iter)
         return False
+
+    @TimeIt.decorator
+    def get_element_IDs(self):
+        elem_IDs = [
+            j
+            for i in self.task_insert_IDs
+            for j in self.workflow.tasks.get(insert_ID=i).element_IDs
+        ]
+        return elem_IDs
+
+    @TimeIt.decorator
+    def get_elements(self):
+        return self.workflow.get_elements_from_IDs(self.get_element_IDs())
+
+    @TimeIt.decorator
+    def skip_downstream_iterations(self, elem_iter) -> list[int]:
+        """
+        Parameters
+        ----------
+        elem_iter
+            The element iteration whose subsequent iterations should be skipped.
+        dep_element_IDs
+            List of elements that are dependent (recursively) on the element
+            of `elem_iter`.
+        """
+        current_iter_idx = elem_iter.loop_idx[self.name]
+        current_task_iID = elem_iter.task.insert_ID
+        self._app.logger.info(
+            f"setting loop {self.name!r} iterations downstream of current iteration "
+            f"index {current_iter_idx} to skip"
+        )
+        elements = self.get_elements()
+
+        # TODO: fix for multiple loop cycles
+        warn(
+            "skip downstream iterations does not work correctly for multiple loop cycles!"
+        )
+
+        to_skip = []
+        for elem in elements:
+            for iter_i in elem.iterations:
+                if iter_i.loop_idx[self.name] > current_iter_idx or (
+                    iter_i.loop_idx[self.name] == current_iter_idx
+                    and iter_i.task.insert_ID > current_task_iID
+                ):
+                    to_skip.extend(iter_i.EAR_IDs_flat)
+
+        self._app.logger.info(
+            f"{len(to_skip)} runs will be set to skip: {shorten_list_str(to_skip)}"
+        )
+        self.workflow.set_EAR_skip({k: SkipReason.LOOP_TERMINATION for k in to_skip})
+
+        return to_skip
```