hpcflow 0.1.9__py3-none-any.whl → 0.2.0a271__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- hpcflow/__init__.py +2 -11
- hpcflow/__pyinstaller/__init__.py +5 -0
- hpcflow/__pyinstaller/hook-hpcflow.py +40 -0
- hpcflow/_version.py +1 -1
- hpcflow/app.py +43 -0
- hpcflow/cli.py +2 -462
- hpcflow/data/demo_data_manifest/__init__.py +3 -0
- hpcflow/data/demo_data_manifest/demo_data_manifest.json +6 -0
- hpcflow/data/jinja_templates/test/test_template.txt +8 -0
- hpcflow/data/programs/hello_world/README.md +1 -0
- hpcflow/data/programs/hello_world/hello_world.c +87 -0
- hpcflow/data/programs/hello_world/linux/hello_world +0 -0
- hpcflow/data/programs/hello_world/macos/hello_world +0 -0
- hpcflow/data/programs/hello_world/win/hello_world.exe +0 -0
- hpcflow/data/scripts/__init__.py +1 -0
- hpcflow/data/scripts/bad_script.py +2 -0
- hpcflow/data/scripts/demo_task_1_generate_t1_infile_1.py +8 -0
- hpcflow/data/scripts/demo_task_1_generate_t1_infile_2.py +8 -0
- hpcflow/data/scripts/demo_task_1_parse_p3.py +7 -0
- hpcflow/data/scripts/do_nothing.py +2 -0
- hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
- hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
- hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
- hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
- hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
- hpcflow/data/scripts/generate_t1_file_01.py +7 -0
- hpcflow/data/scripts/import_future_script.py +7 -0
- hpcflow/data/scripts/input_file_generator_basic.py +3 -0
- hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
- hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
- hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_all_iters_test.py +15 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_env_spec.py +7 -0
- hpcflow/data/scripts/main_script_test_direct_in_direct_out_labels.py +8 -0
- hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
- hpcflow/data/scripts/main_script_test_direct_sub_param_in_direct_out.py +6 -0
- hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +12 -0
- hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
- hpcflow/data/scripts/main_script_test_hdf5_in_obj_group.py +12 -0
- hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +11 -0
- hpcflow/data/scripts/main_script_test_json_and_direct_in_json_out.py +14 -0
- hpcflow/data/scripts/main_script_test_json_in_json_and_direct_out.py +17 -0
- hpcflow/data/scripts/main_script_test_json_in_json_out.py +14 -0
- hpcflow/data/scripts/main_script_test_json_in_json_out_labels.py +16 -0
- hpcflow/data/scripts/main_script_test_json_in_obj.py +12 -0
- hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
- hpcflow/data/scripts/main_script_test_json_out_obj.py +10 -0
- hpcflow/data/scripts/main_script_test_json_sub_param_in_json_out_labels.py +16 -0
- hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
- hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
- hpcflow/data/scripts/output_file_parser_basic.py +3 -0
- hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
- hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
- hpcflow/data/scripts/parse_t1_file_01.py +4 -0
- hpcflow/data/scripts/script_exit_test.py +5 -0
- hpcflow/data/template_components/__init__.py +1 -0
- hpcflow/data/template_components/command_files.yaml +26 -0
- hpcflow/data/template_components/environments.yaml +13 -0
- hpcflow/data/template_components/parameters.yaml +14 -0
- hpcflow/data/template_components/task_schemas.yaml +139 -0
- hpcflow/data/workflows/workflow_1.yaml +5 -0
- hpcflow/examples.ipynb +1037 -0
- hpcflow/sdk/__init__.py +149 -0
- hpcflow/sdk/app.py +4266 -0
- hpcflow/sdk/cli.py +1479 -0
- hpcflow/sdk/cli_common.py +385 -0
- hpcflow/sdk/config/__init__.py +5 -0
- hpcflow/sdk/config/callbacks.py +246 -0
- hpcflow/sdk/config/cli.py +388 -0
- hpcflow/sdk/config/config.py +1410 -0
- hpcflow/sdk/config/config_file.py +501 -0
- hpcflow/sdk/config/errors.py +272 -0
- hpcflow/sdk/config/types.py +150 -0
- hpcflow/sdk/core/__init__.py +38 -0
- hpcflow/sdk/core/actions.py +3857 -0
- hpcflow/sdk/core/app_aware.py +25 -0
- hpcflow/sdk/core/cache.py +224 -0
- hpcflow/sdk/core/command_files.py +814 -0
- hpcflow/sdk/core/commands.py +424 -0
- hpcflow/sdk/core/element.py +2071 -0
- hpcflow/sdk/core/enums.py +221 -0
- hpcflow/sdk/core/environment.py +256 -0
- hpcflow/sdk/core/errors.py +1043 -0
- hpcflow/sdk/core/execute.py +207 -0
- hpcflow/sdk/core/json_like.py +809 -0
- hpcflow/sdk/core/loop.py +1320 -0
- hpcflow/sdk/core/loop_cache.py +282 -0
- hpcflow/sdk/core/object_list.py +933 -0
- hpcflow/sdk/core/parameters.py +3371 -0
- hpcflow/sdk/core/rule.py +196 -0
- hpcflow/sdk/core/run_dir_files.py +57 -0
- hpcflow/sdk/core/skip_reason.py +7 -0
- hpcflow/sdk/core/task.py +3792 -0
- hpcflow/sdk/core/task_schema.py +993 -0
- hpcflow/sdk/core/test_utils.py +538 -0
- hpcflow/sdk/core/types.py +447 -0
- hpcflow/sdk/core/utils.py +1207 -0
- hpcflow/sdk/core/validation.py +87 -0
- hpcflow/sdk/core/values.py +477 -0
- hpcflow/sdk/core/workflow.py +4820 -0
- hpcflow/sdk/core/zarr_io.py +206 -0
- hpcflow/sdk/data/__init__.py +13 -0
- hpcflow/sdk/data/config_file_schema.yaml +34 -0
- hpcflow/sdk/data/config_schema.yaml +260 -0
- hpcflow/sdk/data/environments_spec_schema.yaml +21 -0
- hpcflow/sdk/data/files_spec_schema.yaml +5 -0
- hpcflow/sdk/data/parameters_spec_schema.yaml +7 -0
- hpcflow/sdk/data/task_schema_spec_schema.yaml +3 -0
- hpcflow/sdk/data/workflow_spec_schema.yaml +22 -0
- hpcflow/sdk/demo/__init__.py +3 -0
- hpcflow/sdk/demo/cli.py +242 -0
- hpcflow/sdk/helper/__init__.py +3 -0
- hpcflow/sdk/helper/cli.py +137 -0
- hpcflow/sdk/helper/helper.py +300 -0
- hpcflow/sdk/helper/watcher.py +192 -0
- hpcflow/sdk/log.py +288 -0
- hpcflow/sdk/persistence/__init__.py +18 -0
- hpcflow/sdk/persistence/base.py +2817 -0
- hpcflow/sdk/persistence/defaults.py +6 -0
- hpcflow/sdk/persistence/discovery.py +39 -0
- hpcflow/sdk/persistence/json.py +954 -0
- hpcflow/sdk/persistence/pending.py +948 -0
- hpcflow/sdk/persistence/store_resource.py +203 -0
- hpcflow/sdk/persistence/types.py +309 -0
- hpcflow/sdk/persistence/utils.py +73 -0
- hpcflow/sdk/persistence/zarr.py +2388 -0
- hpcflow/sdk/runtime.py +320 -0
- hpcflow/sdk/submission/__init__.py +3 -0
- hpcflow/sdk/submission/enums.py +70 -0
- hpcflow/sdk/submission/jobscript.py +2379 -0
- hpcflow/sdk/submission/schedulers/__init__.py +281 -0
- hpcflow/sdk/submission/schedulers/direct.py +233 -0
- hpcflow/sdk/submission/schedulers/sge.py +376 -0
- hpcflow/sdk/submission/schedulers/slurm.py +598 -0
- hpcflow/sdk/submission/schedulers/utils.py +25 -0
- hpcflow/sdk/submission/shells/__init__.py +52 -0
- hpcflow/sdk/submission/shells/base.py +229 -0
- hpcflow/sdk/submission/shells/bash.py +504 -0
- hpcflow/sdk/submission/shells/os_version.py +115 -0
- hpcflow/sdk/submission/shells/powershell.py +352 -0
- hpcflow/sdk/submission/submission.py +1402 -0
- hpcflow/sdk/submission/types.py +140 -0
- hpcflow/sdk/typing.py +194 -0
- hpcflow/sdk/utils/arrays.py +69 -0
- hpcflow/sdk/utils/deferred_file.py +55 -0
- hpcflow/sdk/utils/hashing.py +16 -0
- hpcflow/sdk/utils/patches.py +31 -0
- hpcflow/sdk/utils/strings.py +69 -0
- hpcflow/tests/api/test_api.py +32 -0
- hpcflow/tests/conftest.py +123 -0
- hpcflow/tests/data/__init__.py +0 -0
- hpcflow/tests/data/benchmark_N_elements.yaml +6 -0
- hpcflow/tests/data/benchmark_script_runner.yaml +26 -0
- hpcflow/tests/data/multi_path_sequences.yaml +29 -0
- hpcflow/tests/data/workflow_1.json +10 -0
- hpcflow/tests/data/workflow_1.yaml +5 -0
- hpcflow/tests/data/workflow_1_slurm.yaml +8 -0
- hpcflow/tests/data/workflow_1_wsl.yaml +8 -0
- hpcflow/tests/data/workflow_test_run_abort.yaml +42 -0
- hpcflow/tests/jinja_templates/test_jinja_templates.py +161 -0
- hpcflow/tests/programs/test_programs.py +180 -0
- hpcflow/tests/schedulers/direct_linux/test_direct_linux_submission.py +12 -0
- hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
- hpcflow/tests/schedulers/slurm/test_slurm_submission.py +14 -0
- hpcflow/tests/scripts/test_input_file_generators.py +282 -0
- hpcflow/tests/scripts/test_main_scripts.py +1361 -0
- hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
- hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
- hpcflow/tests/shells/wsl/test_wsl_submission.py +14 -0
- hpcflow/tests/unit/test_action.py +1066 -0
- hpcflow/tests/unit/test_action_rule.py +24 -0
- hpcflow/tests/unit/test_app.py +132 -0
- hpcflow/tests/unit/test_cache.py +46 -0
- hpcflow/tests/unit/test_cli.py +172 -0
- hpcflow/tests/unit/test_command.py +377 -0
- hpcflow/tests/unit/test_config.py +195 -0
- hpcflow/tests/unit/test_config_file.py +162 -0
- hpcflow/tests/unit/test_element.py +666 -0
- hpcflow/tests/unit/test_element_iteration.py +88 -0
- hpcflow/tests/unit/test_element_set.py +158 -0
- hpcflow/tests/unit/test_group.py +115 -0
- hpcflow/tests/unit/test_input_source.py +1479 -0
- hpcflow/tests/unit/test_input_value.py +398 -0
- hpcflow/tests/unit/test_jobscript_unit.py +757 -0
- hpcflow/tests/unit/test_json_like.py +1247 -0
- hpcflow/tests/unit/test_loop.py +2674 -0
- hpcflow/tests/unit/test_meta_task.py +325 -0
- hpcflow/tests/unit/test_multi_path_sequences.py +259 -0
- hpcflow/tests/unit/test_object_list.py +116 -0
- hpcflow/tests/unit/test_parameter.py +243 -0
- hpcflow/tests/unit/test_persistence.py +664 -0
- hpcflow/tests/unit/test_resources.py +243 -0
- hpcflow/tests/unit/test_run.py +286 -0
- hpcflow/tests/unit/test_run_directories.py +29 -0
- hpcflow/tests/unit/test_runtime.py +9 -0
- hpcflow/tests/unit/test_schema_input.py +372 -0
- hpcflow/tests/unit/test_shell.py +129 -0
- hpcflow/tests/unit/test_slurm.py +39 -0
- hpcflow/tests/unit/test_submission.py +502 -0
- hpcflow/tests/unit/test_task.py +2560 -0
- hpcflow/tests/unit/test_task_schema.py +182 -0
- hpcflow/tests/unit/test_utils.py +616 -0
- hpcflow/tests/unit/test_value_sequence.py +549 -0
- hpcflow/tests/unit/test_values.py +91 -0
- hpcflow/tests/unit/test_workflow.py +827 -0
- hpcflow/tests/unit/test_workflow_template.py +186 -0
- hpcflow/tests/unit/utils/test_arrays.py +40 -0
- hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
- hpcflow/tests/unit/utils/test_hashing.py +65 -0
- hpcflow/tests/unit/utils/test_patches.py +5 -0
- hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
- hpcflow/tests/unit/utils/test_strings.py +97 -0
- hpcflow/tests/workflows/__init__.py +0 -0
- hpcflow/tests/workflows/test_directory_structure.py +31 -0
- hpcflow/tests/workflows/test_jobscript.py +355 -0
- hpcflow/tests/workflows/test_run_status.py +198 -0
- hpcflow/tests/workflows/test_skip_downstream.py +696 -0
- hpcflow/tests/workflows/test_submission.py +140 -0
- hpcflow/tests/workflows/test_workflows.py +564 -0
- hpcflow/tests/workflows/test_zip.py +18 -0
- hpcflow/viz_demo.ipynb +6794 -0
- hpcflow-0.2.0a271.dist-info/LICENSE +375 -0
- hpcflow-0.2.0a271.dist-info/METADATA +65 -0
- hpcflow-0.2.0a271.dist-info/RECORD +237 -0
- {hpcflow-0.1.9.dist-info → hpcflow-0.2.0a271.dist-info}/WHEEL +4 -5
- hpcflow-0.2.0a271.dist-info/entry_points.txt +6 -0
- hpcflow/api.py +0 -458
- hpcflow/archive/archive.py +0 -308
- hpcflow/archive/cloud/cloud.py +0 -47
- hpcflow/archive/cloud/errors.py +0 -9
- hpcflow/archive/cloud/providers/dropbox.py +0 -432
- hpcflow/archive/errors.py +0 -5
- hpcflow/base_db.py +0 -4
- hpcflow/config.py +0 -232
- hpcflow/copytree.py +0 -66
- hpcflow/data/examples/_config.yml +0 -14
- hpcflow/data/examples/damask/demo/1.run.yml +0 -4
- hpcflow/data/examples/damask/demo/2.process.yml +0 -29
- hpcflow/data/examples/damask/demo/geom.geom +0 -2052
- hpcflow/data/examples/damask/demo/load.load +0 -1
- hpcflow/data/examples/damask/demo/material.config +0 -185
- hpcflow/data/examples/damask/inputs/geom.geom +0 -2052
- hpcflow/data/examples/damask/inputs/load.load +0 -1
- hpcflow/data/examples/damask/inputs/material.config +0 -185
- hpcflow/data/examples/damask/profiles/_variable_lookup.yml +0 -21
- hpcflow/data/examples/damask/profiles/damask.yml +0 -4
- hpcflow/data/examples/damask/profiles/damask_process.yml +0 -8
- hpcflow/data/examples/damask/profiles/damask_run.yml +0 -5
- hpcflow/data/examples/damask/profiles/default.yml +0 -6
- hpcflow/data/examples/thinking.yml +0 -177
- hpcflow/errors.py +0 -2
- hpcflow/init_db.py +0 -37
- hpcflow/models.py +0 -2549
- hpcflow/nesting.py +0 -9
- hpcflow/profiles.py +0 -455
- hpcflow/project.py +0 -81
- hpcflow/scheduler.py +0 -323
- hpcflow/utils.py +0 -103
- hpcflow/validation.py +0 -167
- hpcflow/variables.py +0 -544
- hpcflow-0.1.9.dist-info/METADATA +0 -168
- hpcflow-0.1.9.dist-info/RECORD +0 -45
- hpcflow-0.1.9.dist-info/entry_points.txt +0 -8
- hpcflow-0.1.9.dist-info/top_level.txt +0 -1
- /hpcflow/{archive → data/jinja_templates}/__init__.py +0 -0
- /hpcflow/{archive/cloud → data/programs}/__init__.py +0 -0
- /hpcflow/{archive/cloud/providers → data/workflows}/__init__.py +0 -0
hpcflow/sdk/core/loop.py
ADDED
|
@@ -0,0 +1,1320 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A looping construct for a workflow.
|
|
3
|
+
There are multiple types of loop,
|
|
4
|
+
notably looping over a set of values or until a condition holds.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from collections import defaultdict
|
|
10
|
+
import copy
|
|
11
|
+
from pprint import pp
|
|
12
|
+
import pprint
|
|
13
|
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
|
14
|
+
from warnings import warn
|
|
15
|
+
from collections import defaultdict
|
|
16
|
+
from itertools import chain
|
|
17
|
+
from typing import cast, TYPE_CHECKING
|
|
18
|
+
from typing_extensions import override
|
|
19
|
+
|
|
20
|
+
from hpcflow.sdk.core.app_aware import AppAware
|
|
21
|
+
from hpcflow.sdk.core.actions import EARStatus
|
|
22
|
+
from hpcflow.sdk.core.skip_reason import SkipReason
|
|
23
|
+
from hpcflow.sdk.core.errors import LoopTaskSubsetError
|
|
24
|
+
from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
|
|
25
|
+
from hpcflow.sdk.core.loop_cache import LoopCache, LoopIndex
|
|
26
|
+
from hpcflow.sdk.core.enums import InputSourceType, TaskSourceType
|
|
27
|
+
from hpcflow.sdk.core.utils import check_valid_py_identifier, nth_key, nth_value
|
|
28
|
+
from hpcflow.sdk.utils.strings import shorten_list_str
|
|
29
|
+
from hpcflow.sdk.log import TimeIt
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
|
33
|
+
from typing import Any, ClassVar
|
|
34
|
+
from typing_extensions import Self, TypeIs
|
|
35
|
+
from rich.status import Status
|
|
36
|
+
from ..typing import DataIndex, ParamSource
|
|
37
|
+
from .parameters import SchemaInput, InputSource
|
|
38
|
+
from .rule import Rule
|
|
39
|
+
from .task import WorkflowTask
|
|
40
|
+
from .types import IterableParam
|
|
41
|
+
from .workflow import Workflow, WorkflowTemplate
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# @dataclass
|
|
45
|
+
# class StoppingCriterion:
|
|
46
|
+
# parameter: Parameter
|
|
47
|
+
# condition: ConditionLike
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# @dataclass
|
|
51
|
+
# class Loop:
|
|
52
|
+
# parameter: Parameter
|
|
53
|
+
# stopping_criteria: StoppingCriterion
|
|
54
|
+
# # TODO: should be a logical combination of these (maybe provide a superclass in valida to re-use some logic there?)
|
|
55
|
+
# maximum_iterations: int
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Loop(JSONLike):
|
|
59
|
+
"""
|
|
60
|
+
A loop in a workflow template.
|
|
61
|
+
|
|
62
|
+
Parameters
|
|
63
|
+
----------
|
|
64
|
+
tasks: list[int | ~hpcflow.app.WorkflowTask]
|
|
65
|
+
List of task insert IDs or workflow tasks.
|
|
66
|
+
num_iterations:
|
|
67
|
+
Number of iterations to perform.
|
|
68
|
+
name: str
|
|
69
|
+
Loop name.
|
|
70
|
+
non_iterable_parameters: list[str]
|
|
71
|
+
Specify input parameters that should not iterate.
|
|
72
|
+
termination: v~hpcflow.app.Rule
|
|
73
|
+
Stopping criterion, expressed as a rule.
|
|
74
|
+
termination_task: int | ~hpcflow.app.WorkflowTask
|
|
75
|
+
Task at which to evaluate the termination condition.
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
_child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
|
|
79
|
+
ChildObjectSpec(name="termination", class_name="Rule"),
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
@classmethod
|
|
83
|
+
def __is_WorkflowTask(cls, value) -> TypeIs[WorkflowTask]:
|
|
84
|
+
return isinstance(value, cls._app.WorkflowTask)
|
|
85
|
+
|
|
86
|
+
def __init__(
|
|
87
|
+
self,
|
|
88
|
+
tasks: Iterable[int | str | WorkflowTask],
|
|
89
|
+
num_iterations: int,
|
|
90
|
+
name: str | None = None,
|
|
91
|
+
non_iterable_parameters: list[str] | None = None,
|
|
92
|
+
termination: Rule | None = None,
|
|
93
|
+
termination_task: int | str | WorkflowTask | None = None,
|
|
94
|
+
) -> None:
|
|
95
|
+
_task_refs: list[int | str] = [
|
|
96
|
+
ref_i.insert_ID if self.__is_WorkflowTask(ref_i) else ref_i for ref_i in tasks
|
|
97
|
+
]
|
|
98
|
+
|
|
99
|
+
if termination_task is None:
|
|
100
|
+
_term_task_ref = _task_refs[-1] # terminate on final task by default
|
|
101
|
+
elif self.__is_WorkflowTask(termination_task):
|
|
102
|
+
_term_task_ref = termination_task.insert_ID
|
|
103
|
+
else:
|
|
104
|
+
_term_task_ref = termination_task
|
|
105
|
+
|
|
106
|
+
if _term_task_ref not in _task_refs:
|
|
107
|
+
raise ValueError(
|
|
108
|
+
f"If specified, `termination_task` (provided: {termination_task!r}) must "
|
|
109
|
+
f"refer to a task that is part of the loop. Available task references "
|
|
110
|
+
f"are: {_task_refs!r}."
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
self._task_refs = _task_refs
|
|
114
|
+
self._task_insert_IDs = (
|
|
115
|
+
cast("list[int]", _task_refs)
|
|
116
|
+
if all(isinstance(ref_i, int) for ref_i in _task_refs)
|
|
117
|
+
else None
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
self._num_iterations = num_iterations
|
|
121
|
+
self._name = check_valid_py_identifier(name) if name else name
|
|
122
|
+
self._non_iterable_parameters = non_iterable_parameters or []
|
|
123
|
+
self._termination = termination
|
|
124
|
+
self._termination_task_ref = _term_task_ref
|
|
125
|
+
self._termination_task_insert_ID = (
|
|
126
|
+
_term_task_ref if isinstance(_term_task_ref, int) else None
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
self._workflow_template: WorkflowTemplate | None = (
|
|
130
|
+
None # assigned by parent WorkflowTemplate
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
@override
|
|
134
|
+
def _postprocess_to_dict(self, d: dict[str, Any]) -> dict[str, Any]:
|
|
135
|
+
out = super()._postprocess_to_dict(d)
|
|
136
|
+
return {k.lstrip("_"): v for k, v in out.items()}
|
|
137
|
+
|
|
138
|
+
@classmethod
|
|
139
|
+
def _json_like_constructor(cls, json_like: dict) -> Self:
|
|
140
|
+
"""Invoked by `JSONLike.from_json_like` instead of `__init__`."""
|
|
141
|
+
if "task_insert_IDs" in json_like:
|
|
142
|
+
insert_IDs = json_like.pop("task_insert_IDs")
|
|
143
|
+
else:
|
|
144
|
+
insert_IDs = json_like.pop("tasks") # e.g. from YAML
|
|
145
|
+
|
|
146
|
+
if "termination_task_insert_ID" in json_like:
|
|
147
|
+
tt_iID = json_like.pop("termination_task_insert_ID")
|
|
148
|
+
elif "termination_task" in json_like:
|
|
149
|
+
tt_iID = json_like.pop("termination_task") # e.g. from YAML
|
|
150
|
+
else:
|
|
151
|
+
tt_iID = None
|
|
152
|
+
|
|
153
|
+
task_refs = None
|
|
154
|
+
term_task_ref = None
|
|
155
|
+
if "task_refs" in json_like:
|
|
156
|
+
task_refs = json_like.pop("task_refs")
|
|
157
|
+
if "termination_task_ref" in json_like:
|
|
158
|
+
term_task_ref = json_like.pop("termination_task_ref")
|
|
159
|
+
|
|
160
|
+
obj = cls(tasks=insert_IDs, termination_task=tt_iID, **json_like)
|
|
161
|
+
if task_refs:
|
|
162
|
+
obj._task_refs = task_refs
|
|
163
|
+
if term_task_ref is not None:
|
|
164
|
+
obj._termination_task_ref = term_task_ref
|
|
165
|
+
|
|
166
|
+
return obj
|
|
167
|
+
|
|
168
|
+
@property
|
|
169
|
+
def task_refs(self) -> tuple[int | str, ...]:
|
|
170
|
+
"""Get the list of task references (insert IDs or task unique names) that define
|
|
171
|
+
the extent of the loop."""
|
|
172
|
+
return tuple(self._task_refs)
|
|
173
|
+
|
|
174
|
+
@property
|
|
175
|
+
def task_insert_IDs(self) -> tuple[int, ...]:
|
|
176
|
+
"""Get the list of task insert_IDs that define the extent of the loop."""
|
|
177
|
+
assert self._task_insert_IDs
|
|
178
|
+
return tuple(self._task_insert_IDs)
|
|
179
|
+
|
|
180
|
+
@property
|
|
181
|
+
def name(self) -> str | None:
|
|
182
|
+
"""
|
|
183
|
+
The name of the loop, if one was provided.
|
|
184
|
+
"""
|
|
185
|
+
return self._name
|
|
186
|
+
|
|
187
|
+
@property
|
|
188
|
+
def num_iterations(self) -> int:
|
|
189
|
+
"""
|
|
190
|
+
The number of loop iterations to do.
|
|
191
|
+
"""
|
|
192
|
+
return self._num_iterations
|
|
193
|
+
|
|
194
|
+
@property
|
|
195
|
+
def non_iterable_parameters(self) -> Sequence[str]:
|
|
196
|
+
"""
|
|
197
|
+
Which parameters are not iterable.
|
|
198
|
+
"""
|
|
199
|
+
return self._non_iterable_parameters
|
|
200
|
+
|
|
201
|
+
@property
|
|
202
|
+
def termination(self) -> Rule | None:
|
|
203
|
+
"""
|
|
204
|
+
A termination rule for the loop, if one is provided.
|
|
205
|
+
"""
|
|
206
|
+
return self._termination
|
|
207
|
+
|
|
208
|
+
@property
|
|
209
|
+
def termination_task_ref(self) -> int | str:
|
|
210
|
+
"""
|
|
211
|
+
The unique name of the task at which the loop will terminate.
|
|
212
|
+
"""
|
|
213
|
+
return self._termination_task_ref
|
|
214
|
+
|
|
215
|
+
@property
|
|
216
|
+
def termination_task_insert_ID(self) -> int:
|
|
217
|
+
"""
|
|
218
|
+
The insert ID of the task at which the loop will terminate.
|
|
219
|
+
"""
|
|
220
|
+
assert self._termination_task_insert_ID is not None
|
|
221
|
+
return self._termination_task_insert_ID
|
|
222
|
+
|
|
223
|
+
@property
|
|
224
|
+
def termination_task(self) -> WorkflowTask:
|
|
225
|
+
"""
|
|
226
|
+
The task at which the loop will terminate.
|
|
227
|
+
"""
|
|
228
|
+
if (wt := self.workflow_template) is None:
|
|
229
|
+
raise RuntimeError(
|
|
230
|
+
"Workflow template must be assigned to retrieve task objects of the loop."
|
|
231
|
+
)
|
|
232
|
+
assert wt.workflow
|
|
233
|
+
return wt.workflow.tasks.get(insert_ID=self.termination_task_insert_ID)
|
|
234
|
+
|
|
235
|
+
@property
|
|
236
|
+
def workflow_template(self) -> WorkflowTemplate | None:
|
|
237
|
+
"""
|
|
238
|
+
The workflow template that contains this loop.
|
|
239
|
+
"""
|
|
240
|
+
return self._workflow_template
|
|
241
|
+
|
|
242
|
+
@workflow_template.setter
|
|
243
|
+
def workflow_template(self, template: WorkflowTemplate):
|
|
244
|
+
self._workflow_template = template
|
|
245
|
+
|
|
246
|
+
def __workflow(self) -> None | Workflow:
|
|
247
|
+
if (wt := self.workflow_template) is None:
|
|
248
|
+
return None
|
|
249
|
+
return wt.workflow
|
|
250
|
+
|
|
251
|
+
@property
|
|
252
|
+
def task_objects(self) -> tuple[WorkflowTask, ...]:
|
|
253
|
+
"""
|
|
254
|
+
The tasks in the loop.
|
|
255
|
+
"""
|
|
256
|
+
if not (wf := self.__workflow()):
|
|
257
|
+
raise RuntimeError(
|
|
258
|
+
"Workflow template must be assigned to retrieve task objects of the loop."
|
|
259
|
+
)
|
|
260
|
+
return tuple(wf.tasks.get(insert_ID=t_id) for t_id in self.task_insert_IDs)
|
|
261
|
+
|
|
262
|
+
def _validate_against_workflow(self, workflow: Workflow) -> None:
|
|
263
|
+
"""Validate the loop parameters against the associated workflow."""
|
|
264
|
+
|
|
265
|
+
names = workflow.get_task_unique_names(map_to_insert_ID=True)
|
|
266
|
+
if self._task_insert_IDs is None:
|
|
267
|
+
err = lambda ref_i: (
|
|
268
|
+
f"Loop {self.name + ' ' if self.name else ''}has an invalid "
|
|
269
|
+
f"task reference {ref_i!r}. Such a task does not exist in "
|
|
270
|
+
f"the associated workflow, which has task names/insert IDs: "
|
|
271
|
+
f"{names!r}."
|
|
272
|
+
)
|
|
273
|
+
self._task_insert_IDs = []
|
|
274
|
+
for ref_i in self._task_refs:
|
|
275
|
+
if isinstance(ref_i, str):
|
|
276
|
+
# reference is task unique name, so retrieve the insert ID:
|
|
277
|
+
try:
|
|
278
|
+
iID = names[ref_i]
|
|
279
|
+
except KeyError:
|
|
280
|
+
raise ValueError(err(ref_i)) from None
|
|
281
|
+
else:
|
|
282
|
+
try:
|
|
283
|
+
workflow.tasks.get(insert_ID=ref_i)
|
|
284
|
+
except ValueError:
|
|
285
|
+
raise ValueError(err(ref_i)) from None
|
|
286
|
+
iID = ref_i
|
|
287
|
+
self._task_insert_IDs.append(iID)
|
|
288
|
+
|
|
289
|
+
if self._termination_task_insert_ID is None:
|
|
290
|
+
tt_ref = self._termination_task_ref
|
|
291
|
+
if isinstance(tt_ref, str):
|
|
292
|
+
tt_iID = names[tt_ref]
|
|
293
|
+
else:
|
|
294
|
+
tt_iID = tt_ref
|
|
295
|
+
self._termination_task_insert_ID = tt_iID
|
|
296
|
+
|
|
297
|
+
def __repr__(self) -> str:
|
|
298
|
+
num_iterations_str = ""
|
|
299
|
+
if self.num_iterations is not None:
|
|
300
|
+
num_iterations_str = f", num_iterations={self.num_iterations!r}"
|
|
301
|
+
|
|
302
|
+
name_str = ""
|
|
303
|
+
if self.name:
|
|
304
|
+
name_str = f", name={self.name!r}"
|
|
305
|
+
|
|
306
|
+
return (
|
|
307
|
+
f"{self.__class__.__name__}("
|
|
308
|
+
f"task_insert_IDs={self.task_insert_IDs!r}{num_iterations_str}{name_str}"
|
|
309
|
+
f")"
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
def __deepcopy__(self, memo: dict[int, Any]) -> Self:
|
|
313
|
+
kwargs = self.to_dict()
|
|
314
|
+
kwargs["tasks"] = kwargs.pop("task_insert_IDs")
|
|
315
|
+
kwargs["termination_task"] = kwargs.pop("termination_task_insert_ID")
|
|
316
|
+
task_refs = kwargs.pop("task_refs", None)
|
|
317
|
+
term_task_ref = kwargs.pop("termination_task_ref", None)
|
|
318
|
+
obj = self.__class__(**copy.deepcopy(kwargs, memo))
|
|
319
|
+
obj._workflow_template = self._workflow_template
|
|
320
|
+
obj._task_refs = task_refs
|
|
321
|
+
obj._termination_task_ref = term_task_ref
|
|
322
|
+
return obj
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class WorkflowLoop(AppAware):
|
|
326
|
+
"""
|
|
327
|
+
Class to represent a :py:class:`.Loop` that is bound to a
|
|
328
|
+
:py:class:`~hpcflow.app.Workflow`.
|
|
329
|
+
|
|
330
|
+
Parameters
|
|
331
|
+
----------
|
|
332
|
+
index: int
|
|
333
|
+
The index of this loop in the workflow.
|
|
334
|
+
workflow: ~hpcflow.app.Workflow
|
|
335
|
+
The workflow containing this loop.
|
|
336
|
+
template: Loop
|
|
337
|
+
The loop that this was generated from.
|
|
338
|
+
num_added_iterations:
|
|
339
|
+
Description of what iterations have been added.
|
|
340
|
+
iterable_parameters:
|
|
341
|
+
Description of what parameters are being iterated over.
|
|
342
|
+
output_parameters:
|
|
343
|
+
Description of what parameter are output from this loop, and the final task insert
|
|
344
|
+
ID from which they are output.
|
|
345
|
+
parents: list[str]
|
|
346
|
+
The paths to the parent entities of this loop.
|
|
347
|
+
"""
|
|
348
|
+
|
|
349
|
+
def __init__(
|
|
350
|
+
self,
|
|
351
|
+
index: int,
|
|
352
|
+
workflow: Workflow,
|
|
353
|
+
template: Loop,
|
|
354
|
+
num_added_iterations: dict[tuple[int, ...], int],
|
|
355
|
+
iterable_parameters: dict[str, IterableParam],
|
|
356
|
+
output_parameters: dict[str, int],
|
|
357
|
+
parents: list[str],
|
|
358
|
+
) -> None:
|
|
359
|
+
self._index = index
|
|
360
|
+
self._workflow = workflow
|
|
361
|
+
self._template = template
|
|
362
|
+
self._num_added_iterations = num_added_iterations
|
|
363
|
+
self._iterable_parameters = iterable_parameters
|
|
364
|
+
self._output_parameters = output_parameters
|
|
365
|
+
self._parents = parents
|
|
366
|
+
|
|
367
|
+
# appended to when adding an empty loop to the workflow that is a parent of this
|
|
368
|
+
# loop; reset and added to `self._parents` on dump to disk:
|
|
369
|
+
self._pending_parents: list[str] = []
|
|
370
|
+
|
|
371
|
+
# used for `num_added_iterations` when a new loop iteration is added, or when
|
|
372
|
+
# parents are append to; reset to None on dump to disk. Each key is a tuple of
|
|
373
|
+
# parent loop indices and each value is the number of pending new iterations:
|
|
374
|
+
self._pending_num_added_iterations: dict[tuple[int, ...], int] | None = None
|
|
375
|
+
|
|
376
|
+
self._validate()
|
|
377
|
+
|
|
378
|
+
@TimeIt.decorator
|
|
379
|
+
def _validate(self) -> None:
|
|
380
|
+
# task subset must be a contiguous range of task indices:
|
|
381
|
+
task_indices = self.task_indices
|
|
382
|
+
task_min, task_max = task_indices[0], task_indices[-1]
|
|
383
|
+
if task_indices != tuple(range(task_min, task_max + 1)):
|
|
384
|
+
raise LoopTaskSubsetError(self.name, self.task_indices)
|
|
385
|
+
|
|
386
|
+
def __repr__(self) -> str:
|
|
387
|
+
return (
|
|
388
|
+
f"{self.__class__.__name__}(template={self.template!r}, "
|
|
389
|
+
f"num_added_iterations={self.num_added_iterations!r})"
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
@property
|
|
393
|
+
def num_added_iterations(self) -> Mapping[tuple[int, ...], int]:
|
|
394
|
+
"""
|
|
395
|
+
The number of added iterations.
|
|
396
|
+
"""
|
|
397
|
+
if self._pending_num_added_iterations:
|
|
398
|
+
return self._pending_num_added_iterations
|
|
399
|
+
else:
|
|
400
|
+
return self._num_added_iterations
|
|
401
|
+
|
|
402
|
+
@property
|
|
403
|
+
def __pending(self) -> dict[tuple[int, ...], int]:
|
|
404
|
+
if not self._pending_num_added_iterations:
|
|
405
|
+
self._pending_num_added_iterations = dict(self._num_added_iterations)
|
|
406
|
+
return self._pending_num_added_iterations
|
|
407
|
+
|
|
408
|
+
def _initialise_pending_added_iters(self, added_iters: Iterable[int]):
|
|
409
|
+
if not self._pending_num_added_iterations:
|
|
410
|
+
self._pending_num_added_iterations = dict(self._num_added_iterations)
|
|
411
|
+
if (added_iters_key := tuple(added_iters)) not in (pending := self.__pending):
|
|
412
|
+
pending[added_iters_key] = 1
|
|
413
|
+
|
|
414
|
+
def _increment_pending_added_iters(self, added_iters_key: Iterable[int]):
|
|
415
|
+
self.__pending[tuple(added_iters_key)] += 1
|
|
416
|
+
|
|
417
|
+
def _update_parents(self, parent: WorkflowLoop):
|
|
418
|
+
assert parent.name
|
|
419
|
+
self._pending_parents.append(parent.name)
|
|
420
|
+
|
|
421
|
+
self._pending_num_added_iterations = {
|
|
422
|
+
(*k, 0): v
|
|
423
|
+
for k, v in (
|
|
424
|
+
self._pending_num_added_iterations or self._num_added_iterations
|
|
425
|
+
).items()
|
|
426
|
+
}
|
|
427
|
+
|
|
428
|
+
self.workflow._store.update_loop_parents(
|
|
429
|
+
index=self.index,
|
|
430
|
+
num_added_iters=self.num_added_iterations,
|
|
431
|
+
parents=self.parents,
|
|
432
|
+
)
|
|
433
|
+
|
|
434
|
+
def _reset_pending_num_added_iters(self) -> None:
|
|
435
|
+
self._pending_num_added_iterations = None
|
|
436
|
+
|
|
437
|
+
def _accept_pending_num_added_iters(self) -> None:
|
|
438
|
+
if self._pending_num_added_iterations:
|
|
439
|
+
self._num_added_iterations = dict(self._pending_num_added_iterations)
|
|
440
|
+
self._reset_pending_num_added_iters()
|
|
441
|
+
|
|
442
|
+
def _reset_pending_parents(self) -> None:
|
|
443
|
+
self._pending_parents = []
|
|
444
|
+
|
|
445
|
+
def _accept_pending_parents(self) -> None:
|
|
446
|
+
self._parents += self._pending_parents
|
|
447
|
+
self._reset_pending_parents()
|
|
448
|
+
|
|
449
|
+
@property
|
|
450
|
+
def index(self) -> int:
|
|
451
|
+
"""
|
|
452
|
+
The index of this loop within its workflow.
|
|
453
|
+
"""
|
|
454
|
+
return self._index
|
|
455
|
+
|
|
456
|
+
@property
|
|
457
|
+
def task_refs(self) -> tuple[int | str, ...]:
|
|
458
|
+
"""
|
|
459
|
+
The list of task references (insert IDs or task unique names) that define the
|
|
460
|
+
extent of the loop."""
|
|
461
|
+
return self.template.task_refs
|
|
462
|
+
|
|
463
|
+
@property
|
|
464
|
+
def task_insert_IDs(self) -> tuple[int, ...]:
|
|
465
|
+
"""
|
|
466
|
+
The insertion IDs of the tasks inside this loop.
|
|
467
|
+
"""
|
|
468
|
+
return self.template.task_insert_IDs
|
|
469
|
+
|
|
470
|
+
@property
|
|
471
|
+
def task_objects(self) -> tuple[WorkflowTask, ...]:
|
|
472
|
+
"""
|
|
473
|
+
The tasks in this loop.
|
|
474
|
+
"""
|
|
475
|
+
return self.template.task_objects
|
|
476
|
+
|
|
477
|
+
@property
|
|
478
|
+
def task_indices(self) -> tuple[int, ...]:
|
|
479
|
+
"""
|
|
480
|
+
The list of task indices that define the extent of the loop.
|
|
481
|
+
"""
|
|
482
|
+
return tuple(task.index for task in self.task_objects)
|
|
483
|
+
|
|
484
|
+
@property
|
|
485
|
+
def workflow(self) -> Workflow:
|
|
486
|
+
"""
|
|
487
|
+
The workflow containing this loop.
|
|
488
|
+
"""
|
|
489
|
+
return self._workflow
|
|
490
|
+
|
|
491
|
+
@property
|
|
492
|
+
def template(self) -> Loop:
|
|
493
|
+
"""
|
|
494
|
+
The loop template for this loop.
|
|
495
|
+
"""
|
|
496
|
+
return self._template
|
|
497
|
+
|
|
498
|
+
@property
|
|
499
|
+
def parents(self) -> Sequence[str]:
|
|
500
|
+
"""
|
|
501
|
+
The parents of this loop.
|
|
502
|
+
"""
|
|
503
|
+
return self._parents + self._pending_parents
|
|
504
|
+
|
|
505
|
+
@property
|
|
506
|
+
def name(self) -> str:
|
|
507
|
+
"""
|
|
508
|
+
The name of this loop, if one is defined.
|
|
509
|
+
"""
|
|
510
|
+
assert self.template.name
|
|
511
|
+
return self.template.name
|
|
512
|
+
|
|
513
|
+
@property
|
|
514
|
+
def iterable_parameters(self) -> dict[str, IterableParam]:
|
|
515
|
+
"""
|
|
516
|
+
The parameters that are being iterated over.
|
|
517
|
+
"""
|
|
518
|
+
return self._iterable_parameters
|
|
519
|
+
|
|
520
|
+
@property
|
|
521
|
+
def output_parameters(self) -> dict[str, int]:
|
|
522
|
+
"""
|
|
523
|
+
The parameters that are outputs of this loop, and the final task insert ID from
|
|
524
|
+
which each parameter is output.
|
|
525
|
+
"""
|
|
526
|
+
return self._output_parameters
|
|
527
|
+
|
|
528
|
+
@property
|
|
529
|
+
def num_iterations(self) -> int:
|
|
530
|
+
"""
|
|
531
|
+
The number of iterations.
|
|
532
|
+
"""
|
|
533
|
+
return self.template.num_iterations
|
|
534
|
+
|
|
535
|
+
@property
|
|
536
|
+
def downstream_tasks(self) -> Iterator[WorkflowTask]:
|
|
537
|
+
"""Tasks that are not part of the loop, and downstream from this loop."""
|
|
538
|
+
tasks = self.workflow.tasks
|
|
539
|
+
for idx in range(self.task_objects[-1].index + 1, len(tasks)):
|
|
540
|
+
yield tasks[idx]
|
|
541
|
+
|
|
542
|
+
@property
|
|
543
|
+
def upstream_tasks(self) -> Iterator[WorkflowTask]:
|
|
544
|
+
"""Tasks that are not part of the loop, and upstream from this loop."""
|
|
545
|
+
tasks = self.workflow.tasks
|
|
546
|
+
for idx in range(0, self.task_objects[0].index):
|
|
547
|
+
yield tasks[idx]
|
|
548
|
+
|
|
549
|
+
@staticmethod
|
|
550
|
+
@TimeIt.decorator
|
|
551
|
+
def _find_iterable_and_output_parameters(
|
|
552
|
+
loop_template: Loop,
|
|
553
|
+
) -> tuple[dict[str, IterableParam], dict[str, int]]:
|
|
554
|
+
all_inputs_first_idx: dict[str, int] = {}
|
|
555
|
+
all_outputs_idx: dict[str, list[int]] = defaultdict(list)
|
|
556
|
+
for task in loop_template.task_objects:
|
|
557
|
+
for typ in task.template.all_schema_input_types:
|
|
558
|
+
all_inputs_first_idx.setdefault(typ, task.insert_ID)
|
|
559
|
+
for typ in task.template.all_schema_output_types:
|
|
560
|
+
all_outputs_idx[typ].append(task.insert_ID)
|
|
561
|
+
|
|
562
|
+
# find input parameters that are also output parameters at a later/same task:
|
|
563
|
+
iterable_params: dict[str, IterableParam] = {}
|
|
564
|
+
for typ, first_idx in all_inputs_first_idx.items():
|
|
565
|
+
if typ in all_outputs_idx and first_idx <= all_outputs_idx[typ][0]:
|
|
566
|
+
iterable_params[typ] = {
|
|
567
|
+
"input_task": first_idx,
|
|
568
|
+
"output_tasks": all_outputs_idx[typ],
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
for non_iter in loop_template.non_iterable_parameters:
|
|
572
|
+
iterable_params.pop(non_iter, None)
|
|
573
|
+
|
|
574
|
+
final_out_tasks = {k: v[-1] for k, v in all_outputs_idx.items()}
|
|
575
|
+
|
|
576
|
+
return iterable_params, final_out_tasks
|
|
577
|
+
|
|
578
|
+
@classmethod
|
|
579
|
+
@TimeIt.decorator
|
|
580
|
+
def new_empty_loop(
|
|
581
|
+
cls,
|
|
582
|
+
index: int,
|
|
583
|
+
workflow: Workflow,
|
|
584
|
+
template: Loop,
|
|
585
|
+
iter_loop_idx: Sequence[Mapping[str, int]],
|
|
586
|
+
) -> WorkflowLoop:
|
|
587
|
+
"""
|
|
588
|
+
Make a new empty loop.
|
|
589
|
+
|
|
590
|
+
Parameters
|
|
591
|
+
----------
|
|
592
|
+
index: int
|
|
593
|
+
The index of the loop to create.
|
|
594
|
+
workflow: ~hpcflow.app.Workflow
|
|
595
|
+
The workflow that will contain the loop.
|
|
596
|
+
template: Loop
|
|
597
|
+
The template for the loop.
|
|
598
|
+
iter_loop_idx: list[dict]
|
|
599
|
+
Iteration information from parent loops.
|
|
600
|
+
"""
|
|
601
|
+
parent_loops = cls._get_parent_loops(index, workflow, template)
|
|
602
|
+
parent_names = [i.name for i in parent_loops]
|
|
603
|
+
num_added_iters: dict[tuple[int, ...], int] = {}
|
|
604
|
+
for i in iter_loop_idx:
|
|
605
|
+
num_added_iters[tuple([i[j] for j in parent_names])] = 1
|
|
606
|
+
|
|
607
|
+
iter_params, out_params = cls._find_iterable_and_output_parameters(template)
|
|
608
|
+
return cls(
|
|
609
|
+
index=index,
|
|
610
|
+
workflow=workflow,
|
|
611
|
+
template=template,
|
|
612
|
+
num_added_iterations=num_added_iters,
|
|
613
|
+
iterable_parameters=iter_params,
|
|
614
|
+
output_parameters=out_params,
|
|
615
|
+
parents=parent_names,
|
|
616
|
+
)
|
|
617
|
+
|
|
618
|
+
@classmethod
|
|
619
|
+
@TimeIt.decorator
|
|
620
|
+
def _get_parent_loops(
|
|
621
|
+
cls,
|
|
622
|
+
index: int,
|
|
623
|
+
workflow: Workflow,
|
|
624
|
+
template: Loop,
|
|
625
|
+
) -> list[WorkflowLoop]:
|
|
626
|
+
parents: list[WorkflowLoop] = []
|
|
627
|
+
passed_self = False
|
|
628
|
+
self_tasks = set(template.task_insert_IDs)
|
|
629
|
+
for loop_i in workflow.loops:
|
|
630
|
+
if loop_i.index == index:
|
|
631
|
+
passed_self = True
|
|
632
|
+
continue
|
|
633
|
+
other_tasks = set(loop_i.task_insert_IDs)
|
|
634
|
+
if self_tasks.issubset(other_tasks):
|
|
635
|
+
if (self_tasks == other_tasks) and not passed_self:
|
|
636
|
+
continue
|
|
637
|
+
parents.append(loop_i)
|
|
638
|
+
return parents
|
|
639
|
+
|
|
640
|
+
@TimeIt.decorator
|
|
641
|
+
def get_parent_loops(self) -> list[WorkflowLoop]:
|
|
642
|
+
"""Get loops whose task subset is a superset of this loop's task subset. If two
|
|
643
|
+
loops have identical task subsets, the first loop in the workflow loop list is
|
|
644
|
+
considered the child."""
|
|
645
|
+
return self._get_parent_loops(self.index, self.workflow, self.template)
|
|
646
|
+
|
|
647
|
+
@TimeIt.decorator
|
|
648
|
+
def get_child_loops(self) -> list[WorkflowLoop]:
|
|
649
|
+
"""Get loops whose task subset is a subset of this loop's task subset. If two
|
|
650
|
+
loops have identical task subsets, the first loop in the workflow loop list is
|
|
651
|
+
considered the child."""
|
|
652
|
+
children: list[WorkflowLoop] = []
|
|
653
|
+
passed_self = False
|
|
654
|
+
self_tasks = set(self.task_insert_IDs)
|
|
655
|
+
for loop_i in self.workflow.loops:
|
|
656
|
+
if loop_i.index == self.index:
|
|
657
|
+
passed_self = True
|
|
658
|
+
continue
|
|
659
|
+
other_tasks = set(loop_i.task_insert_IDs)
|
|
660
|
+
if self_tasks.issuperset(other_tasks):
|
|
661
|
+
if (self_tasks == other_tasks) and passed_self:
|
|
662
|
+
continue
|
|
663
|
+
children.append(loop_i)
|
|
664
|
+
|
|
665
|
+
# order by depth, so direct child is first:
|
|
666
|
+
return sorted(children, key=lambda x: len(next(iter(x.num_added_iterations))))
|
|
667
|
+
|
|
668
|
+
@TimeIt.decorator
|
|
669
|
+
def add_iteration(
|
|
670
|
+
self,
|
|
671
|
+
parent_loop_indices: Mapping[str, int] | None = None,
|
|
672
|
+
cache: LoopCache | None = None,
|
|
673
|
+
status: Status | None = None,
|
|
674
|
+
) -> None:
|
|
675
|
+
"""
|
|
676
|
+
Add an iteration to this loop.
|
|
677
|
+
|
|
678
|
+
Parameters
|
|
679
|
+
----------
|
|
680
|
+
parent_loop_indices:
|
|
681
|
+
Where have any parent loops got up to?
|
|
682
|
+
cache:
|
|
683
|
+
A cache used to make adding the iteration more efficient.
|
|
684
|
+
One will be created if it is not supplied.
|
|
685
|
+
"""
|
|
686
|
+
if not cache:
|
|
687
|
+
cache = LoopCache.build(self.workflow)
|
|
688
|
+
assert cache is not None
|
|
689
|
+
parent_loops = self.get_parent_loops()
|
|
690
|
+
child_loops = self.get_child_loops()
|
|
691
|
+
parent_loop_indices_ = parent_loop_indices or {
|
|
692
|
+
loop.name: 0 for loop in parent_loops
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
iters_key = tuple(parent_loop_indices_[p_nm] for p_nm in self.parents)
|
|
696
|
+
cur_loop_idx = self.num_added_iterations[iters_key] - 1
|
|
697
|
+
|
|
698
|
+
# keys are (task.insert_ID and element.index)
|
|
699
|
+
all_new_data_idx: dict[tuple[int, int], DataIndex] = {}
|
|
700
|
+
|
|
701
|
+
# initialise a new `num_added_iterations` key on each child loop:
|
|
702
|
+
iters_key_dct = {
|
|
703
|
+
**parent_loop_indices_,
|
|
704
|
+
self.name: cur_loop_idx + 1,
|
|
705
|
+
}
|
|
706
|
+
for child in child_loops:
|
|
707
|
+
child._initialise_pending_added_iters(
|
|
708
|
+
iters_key_dct.get(j, 0) for j in child.parents
|
|
709
|
+
)
|
|
710
|
+
|
|
711
|
+
# needed for the case where an inner loop has only one iteration, meaning
|
|
712
|
+
# `add_iteration` will not be called recursively on it:
|
|
713
|
+
self.workflow._store.update_loop_num_iters(
|
|
714
|
+
index=child.index,
|
|
715
|
+
num_added_iters=child.num_added_iterations,
|
|
716
|
+
)
|
|
717
|
+
|
|
718
|
+
for task in self.task_objects:
|
|
719
|
+
new_loop_idx = LoopIndex(iters_key_dct) + {
|
|
720
|
+
child.name: 0
|
|
721
|
+
for child in child_loops
|
|
722
|
+
if task.insert_ID in child.task_insert_IDs
|
|
723
|
+
}
|
|
724
|
+
added_iter_IDs: list[int] = []
|
|
725
|
+
for elem_idx in range(task.num_elements):
|
|
726
|
+
elem_ID = task.element_IDs[elem_idx]
|
|
727
|
+
|
|
728
|
+
new_data_idx: DataIndex = {}
|
|
729
|
+
|
|
730
|
+
# copy resources from zeroth iteration:
|
|
731
|
+
zeroth_iter_ID, zi_iter_data_idx = cache.zeroth_iters[elem_ID]
|
|
732
|
+
zi_elem_ID, zi_idx = cache.iterations[zeroth_iter_ID]
|
|
733
|
+
zi_data_idx = nth_value(cache.data_idx[zi_elem_ID], zi_idx)
|
|
734
|
+
|
|
735
|
+
for key, val in zi_data_idx.items():
|
|
736
|
+
if key.startswith("resources."):
|
|
737
|
+
new_data_idx[key] = val
|
|
738
|
+
|
|
739
|
+
for inp in task.template.all_schema_inputs:
|
|
740
|
+
is_inp_task = False
|
|
741
|
+
if iter_dat := self.iterable_parameters.get(inp.typ):
|
|
742
|
+
is_inp_task = task.insert_ID == iter_dat["input_task"]
|
|
743
|
+
|
|
744
|
+
inp_key = f"inputs.{inp.typ}"
|
|
745
|
+
|
|
746
|
+
if is_inp_task:
|
|
747
|
+
assert iter_dat is not None
|
|
748
|
+
inp_dat_idx = self.__get_looped_index(
|
|
749
|
+
task,
|
|
750
|
+
elem_ID,
|
|
751
|
+
cache,
|
|
752
|
+
iter_dat,
|
|
753
|
+
inp,
|
|
754
|
+
parent_loops,
|
|
755
|
+
parent_loop_indices_,
|
|
756
|
+
child_loops,
|
|
757
|
+
cur_loop_idx,
|
|
758
|
+
)
|
|
759
|
+
new_data_idx[inp_key] = inp_dat_idx
|
|
760
|
+
else:
|
|
761
|
+
orig_inp_src = cache.elements[elem_ID]["input_sources"][inp_key]
|
|
762
|
+
inp_dat_idx = None
|
|
763
|
+
|
|
764
|
+
if orig_inp_src.source_type is InputSourceType.LOCAL:
|
|
765
|
+
# keep locally defined inputs from original element
|
|
766
|
+
inp_dat_idx = zi_data_idx[inp_key]
|
|
767
|
+
|
|
768
|
+
elif orig_inp_src.source_type is InputSourceType.DEFAULT:
|
|
769
|
+
# keep default value from original element
|
|
770
|
+
try:
|
|
771
|
+
inp_dat_idx = zi_data_idx[inp_key]
|
|
772
|
+
except KeyError:
|
|
773
|
+
# if this input is required by a conditional action, and
|
|
774
|
+
# that condition is not met, then this input will not
|
|
775
|
+
# exist in the action-run data index, so use the initial
|
|
776
|
+
# iteration data index:
|
|
777
|
+
inp_dat_idx = zi_iter_data_idx[inp_key]
|
|
778
|
+
|
|
779
|
+
elif orig_inp_src.source_type is InputSourceType.TASK:
|
|
780
|
+
inp_dat_idx = self.__get_task_index(
|
|
781
|
+
task,
|
|
782
|
+
orig_inp_src,
|
|
783
|
+
cache,
|
|
784
|
+
elem_ID,
|
|
785
|
+
inp,
|
|
786
|
+
inp_key,
|
|
787
|
+
parent_loop_indices_,
|
|
788
|
+
all_new_data_idx,
|
|
789
|
+
)
|
|
790
|
+
|
|
791
|
+
if inp_dat_idx is None:
|
|
792
|
+
raise RuntimeError(
|
|
793
|
+
f"Could not find a source for parameter {inp.typ} "
|
|
794
|
+
f"when adding a new iteration for task {task!r}."
|
|
795
|
+
)
|
|
796
|
+
|
|
797
|
+
new_data_idx[inp_key] = inp_dat_idx
|
|
798
|
+
|
|
799
|
+
# add any locally defined sub-parameters:
|
|
800
|
+
inp_statuses = cache.elements[elem_ID]["input_statuses"]
|
|
801
|
+
inp_status_inps = set(f"inputs.{inp}" for inp in inp_statuses)
|
|
802
|
+
for sub_param_i in inp_status_inps.difference(new_data_idx):
|
|
803
|
+
sub_param_data_idx_iter_0 = zi_data_idx
|
|
804
|
+
try:
|
|
805
|
+
sub_param_data_idx = sub_param_data_idx_iter_0[sub_param_i]
|
|
806
|
+
except KeyError:
|
|
807
|
+
# as before, if this input is required by a conditional action,
|
|
808
|
+
# and that condition is not met, then this input will not exist in
|
|
809
|
+
# the action-run data index, so use the initial iteration data
|
|
810
|
+
# index:
|
|
811
|
+
sub_param_data_idx = zi_data_idx[sub_param_i]
|
|
812
|
+
|
|
813
|
+
new_data_idx[sub_param_i] = sub_param_data_idx
|
|
814
|
+
|
|
815
|
+
for out in task.template.all_schema_outputs:
|
|
816
|
+
path_i = f"outputs.{out.typ}"
|
|
817
|
+
p_src: ParamSource = {"type": "EAR_output"}
|
|
818
|
+
new_data_idx[path_i] = self.workflow._add_unset_parameter_data(p_src)
|
|
819
|
+
|
|
820
|
+
schema_params = set(i for i in new_data_idx if len(i.split(".")) == 2)
|
|
821
|
+
all_new_data_idx[task.insert_ID, elem_idx] = new_data_idx
|
|
822
|
+
|
|
823
|
+
iter_ID_i = self.workflow._store.add_element_iteration(
|
|
824
|
+
element_ID=elem_ID,
|
|
825
|
+
data_idx=new_data_idx,
|
|
826
|
+
schema_parameters=list(schema_params),
|
|
827
|
+
loop_idx=new_loop_idx,
|
|
828
|
+
)
|
|
829
|
+
if cache:
|
|
830
|
+
cache.add_iteration(
|
|
831
|
+
iter_ID=iter_ID_i,
|
|
832
|
+
task_insert_ID=task.insert_ID,
|
|
833
|
+
element_ID=elem_ID,
|
|
834
|
+
loop_idx=new_loop_idx,
|
|
835
|
+
data_idx=new_data_idx,
|
|
836
|
+
)
|
|
837
|
+
|
|
838
|
+
added_iter_IDs.append(iter_ID_i)
|
|
839
|
+
|
|
840
|
+
task.initialise_EARs(iter_IDs=added_iter_IDs)
|
|
841
|
+
|
|
842
|
+
self._increment_pending_added_iters(
|
|
843
|
+
parent_loop_indices_[p_nm] for p_nm in self.parents
|
|
844
|
+
)
|
|
845
|
+
self.workflow._store.update_loop_num_iters(
|
|
846
|
+
index=self.index,
|
|
847
|
+
num_added_iters=self.num_added_iterations,
|
|
848
|
+
)
|
|
849
|
+
|
|
850
|
+
# add iterations to fixed-number-iteration children only:
|
|
851
|
+
for child in child_loops[::-1]:
|
|
852
|
+
if child.num_iterations is not None:
|
|
853
|
+
if status:
|
|
854
|
+
status_prev = str(status.status).rstrip(".")
|
|
855
|
+
for iter_idx in range(child.num_iterations - 1):
|
|
856
|
+
if status:
|
|
857
|
+
status.update(
|
|
858
|
+
f"{status_prev} --> ({child.name!r}): iteration "
|
|
859
|
+
f"{iter_idx + 2}/{child.num_iterations}."
|
|
860
|
+
)
|
|
861
|
+
par_idx = {parent_name: 0 for parent_name in child.parents}
|
|
862
|
+
if parent_loop_indices:
|
|
863
|
+
par_idx.update(parent_loop_indices)
|
|
864
|
+
par_idx[self.name] = cur_loop_idx + 1
|
|
865
|
+
child.add_iteration(parent_loop_indices=par_idx, cache=cache)
|
|
866
|
+
|
|
867
|
+
self.__update_loop_downstream_data_idx(parent_loop_indices_)
|
|
868
|
+
|
|
869
|
+
def __get_src_ID_and_groups(
|
|
870
|
+
self,
|
|
871
|
+
elem_ID: int,
|
|
872
|
+
iter_dat: IterableParam,
|
|
873
|
+
inp: SchemaInput,
|
|
874
|
+
cache: LoopCache,
|
|
875
|
+
task: WorkflowTask,
|
|
876
|
+
) -> tuple[int, Sequence[int]]:
|
|
877
|
+
# `cache.elements` contains only elements that are part of the
|
|
878
|
+
# loop, so indexing a dependent element may raise:
|
|
879
|
+
src_elem_IDs = {}
|
|
880
|
+
for k, v in cache.element_dependents[elem_ID].items():
|
|
881
|
+
try:
|
|
882
|
+
if cache.elements[k]["task_insert_ID"] == iter_dat["output_tasks"][-1]:
|
|
883
|
+
src_elem_IDs[k] = v
|
|
884
|
+
except KeyError:
|
|
885
|
+
continue
|
|
886
|
+
|
|
887
|
+
# consider groups
|
|
888
|
+
single_data = inp.single_labelled_data
|
|
889
|
+
assert single_data is not None
|
|
890
|
+
inp_group_name = single_data.get("group")
|
|
891
|
+
grouped_elems = [
|
|
892
|
+
src_elem_j_ID
|
|
893
|
+
for src_elem_j_ID, src_elem_j_dat in src_elem_IDs.items()
|
|
894
|
+
if any(nm == inp_group_name for nm in src_elem_j_dat["group_names"])
|
|
895
|
+
]
|
|
896
|
+
|
|
897
|
+
if not grouped_elems and len(src_elem_IDs) > 1:
|
|
898
|
+
raise NotImplementedError(
|
|
899
|
+
f"Multiple elements found in the iterable parameter "
|
|
900
|
+
f"{inp!r}'s latest output task (insert ID: "
|
|
901
|
+
f"{iter_dat['output_tasks'][-1]}) that can be used "
|
|
902
|
+
f"to parametrise the next iteration of task "
|
|
903
|
+
f"{task.unique_name!r}: "
|
|
904
|
+
f"{list(src_elem_IDs)!r}."
|
|
905
|
+
)
|
|
906
|
+
|
|
907
|
+
elif not src_elem_IDs:
|
|
908
|
+
# TODO: maybe OK?
|
|
909
|
+
raise NotImplementedError(
|
|
910
|
+
f"No elements found in the iterable parameter "
|
|
911
|
+
f"{inp!r}'s latest output task (insert ID: "
|
|
912
|
+
f"{iter_dat['output_tasks'][-1]}) that can be used "
|
|
913
|
+
f"to parametrise the next iteration."
|
|
914
|
+
)
|
|
915
|
+
|
|
916
|
+
return nth_key(src_elem_IDs, 0), grouped_elems
|
|
917
|
+
|
|
918
|
+
def __get_looped_index(
|
|
919
|
+
self,
|
|
920
|
+
task: WorkflowTask,
|
|
921
|
+
elem_ID: int,
|
|
922
|
+
cache: LoopCache,
|
|
923
|
+
iter_dat: IterableParam,
|
|
924
|
+
inp: SchemaInput,
|
|
925
|
+
parent_loops: list[WorkflowLoop],
|
|
926
|
+
parent_loop_indices: Mapping[str, int],
|
|
927
|
+
child_loops: list[WorkflowLoop],
|
|
928
|
+
cur_loop_idx: int,
|
|
929
|
+
):
|
|
930
|
+
# source from final output task of previous iteration, with all parent
|
|
931
|
+
# loop indices the same as previous iteration, and all child loop indices
|
|
932
|
+
# maximised:
|
|
933
|
+
|
|
934
|
+
# identify element(s) from which this iterable input should be
|
|
935
|
+
# parametrised:
|
|
936
|
+
if task.insert_ID == iter_dat["output_tasks"][-1]:
|
|
937
|
+
# single-task loop
|
|
938
|
+
src_elem_ID = elem_ID
|
|
939
|
+
grouped_elems: Sequence[int] = []
|
|
940
|
+
else:
|
|
941
|
+
# multi-task loop
|
|
942
|
+
src_elem_ID, grouped_elems = self.__get_src_ID_and_groups(
|
|
943
|
+
elem_ID, iter_dat, inp, cache, task
|
|
944
|
+
)
|
|
945
|
+
|
|
946
|
+
child_loop_max_iters: dict[str, int] = {}
|
|
947
|
+
parent_loop_same_iters = {
|
|
948
|
+
loop.name: parent_loop_indices[loop.name] for loop in parent_loops
|
|
949
|
+
}
|
|
950
|
+
child_iter_parents = {
|
|
951
|
+
**parent_loop_same_iters,
|
|
952
|
+
self.name: cur_loop_idx,
|
|
953
|
+
}
|
|
954
|
+
for loop in child_loops:
|
|
955
|
+
if iter_dat["output_tasks"][-1] in loop.task_insert_IDs:
|
|
956
|
+
i_num_iters = loop.num_added_iterations[
|
|
957
|
+
tuple(child_iter_parents[j] for j in loop.parents)
|
|
958
|
+
]
|
|
959
|
+
i_max = i_num_iters - 1
|
|
960
|
+
child_iter_parents[loop.name] = i_max
|
|
961
|
+
child_loop_max_iters[loop.name] = i_max
|
|
962
|
+
|
|
963
|
+
loop_idx_key = LoopIndex(child_loop_max_iters)
|
|
964
|
+
loop_idx_key.update(parent_loop_same_iters)
|
|
965
|
+
loop_idx_key[self.name] = cur_loop_idx
|
|
966
|
+
|
|
967
|
+
# identify the ElementIteration from which this input should be
|
|
968
|
+
# parametrised:
|
|
969
|
+
if grouped_elems:
|
|
970
|
+
src_data_idx = [
|
|
971
|
+
cache.data_idx[src_elem_ID][loop_idx_key] for src_elem_ID in grouped_elems
|
|
972
|
+
]
|
|
973
|
+
if not src_data_idx:
|
|
974
|
+
raise RuntimeError(
|
|
975
|
+
f"Could not find a source iteration with loop_idx: "
|
|
976
|
+
f"{loop_idx_key!r}."
|
|
977
|
+
)
|
|
978
|
+
return [i[f"outputs.{inp.typ}"] for i in src_data_idx]
|
|
979
|
+
else:
|
|
980
|
+
return cache.data_idx[src_elem_ID][loop_idx_key][f"outputs.{inp.typ}"]
|
|
981
|
+
|
|
982
|
+
def __get_task_index(
|
|
983
|
+
self,
|
|
984
|
+
task: WorkflowTask,
|
|
985
|
+
orig_inp_src: InputSource,
|
|
986
|
+
cache: LoopCache,
|
|
987
|
+
elem_ID: int,
|
|
988
|
+
inp: SchemaInput,
|
|
989
|
+
inp_key: str,
|
|
990
|
+
parent_loop_indices: Mapping[str, int],
|
|
991
|
+
all_new_data_idx: Mapping[tuple[int, int], DataIndex],
|
|
992
|
+
) -> int | list[int]:
|
|
993
|
+
if orig_inp_src.task_ref not in self.task_insert_IDs:
|
|
994
|
+
# source the data_idx from the iteration with same parent
|
|
995
|
+
# loop indices as the new iteration to add:
|
|
996
|
+
src_data_idx = next(
|
|
997
|
+
di_k
|
|
998
|
+
for li_k, di_k in cache.data_idx[elem_ID].items()
|
|
999
|
+
if all(li_k.get(p_k) == p_v for p_k, p_v in parent_loop_indices.items())
|
|
1000
|
+
)
|
|
1001
|
+
|
|
1002
|
+
# could be multiple, but they should all have the same
|
|
1003
|
+
# data index for this parameter:
|
|
1004
|
+
return src_data_idx[inp_key]
|
|
1005
|
+
|
|
1006
|
+
is_group = (
|
|
1007
|
+
inp.single_labelled_data is not None
|
|
1008
|
+
and "group" in inp.single_labelled_data
|
|
1009
|
+
# this input is a group, assume for now all elements
|
|
1010
|
+
)
|
|
1011
|
+
|
|
1012
|
+
# same task/element, but update iteration to the just-added
|
|
1013
|
+
# iteration:
|
|
1014
|
+
assert orig_inp_src.task_source_type is not None
|
|
1015
|
+
key_prefix = orig_inp_src.task_source_type.name.lower()
|
|
1016
|
+
prev_dat_idx_key = f"{key_prefix}s.{inp.typ}"
|
|
1017
|
+
new_sources: list[tuple[int, int]] = []
|
|
1018
|
+
for (tiID, e_idx), _ in all_new_data_idx.items():
|
|
1019
|
+
if tiID == orig_inp_src.task_ref:
|
|
1020
|
+
# find which element in that task `element`
|
|
1021
|
+
# depends on:
|
|
1022
|
+
src_elem_IDs = cache.element_dependents[
|
|
1023
|
+
self.workflow.tasks.get(insert_ID=tiID).element_IDs[e_idx]
|
|
1024
|
+
]
|
|
1025
|
+
# `cache.elements` contains only elements that are part of the loop, so
|
|
1026
|
+
# indexing a dependent element may raise:
|
|
1027
|
+
src_elem_IDs_i = []
|
|
1028
|
+
for k, _v in src_elem_IDs.items():
|
|
1029
|
+
try:
|
|
1030
|
+
if (
|
|
1031
|
+
cache.elements[k]["task_insert_ID"] == task.insert_ID
|
|
1032
|
+
and k == elem_ID
|
|
1033
|
+
# filter src_elem_IDs_i for matching element IDs
|
|
1034
|
+
):
|
|
1035
|
+
|
|
1036
|
+
src_elem_IDs_i.append(k)
|
|
1037
|
+
except KeyError:
|
|
1038
|
+
continue
|
|
1039
|
+
|
|
1040
|
+
if len(src_elem_IDs_i) == 1:
|
|
1041
|
+
new_sources.append((tiID, e_idx))
|
|
1042
|
+
|
|
1043
|
+
if is_group:
|
|
1044
|
+
# Convert into simple list of indices
|
|
1045
|
+
return list(
|
|
1046
|
+
chain.from_iterable(
|
|
1047
|
+
self.__as_sequence(all_new_data_idx[src][prev_dat_idx_key])
|
|
1048
|
+
for src in new_sources
|
|
1049
|
+
)
|
|
1050
|
+
)
|
|
1051
|
+
else:
|
|
1052
|
+
assert len(new_sources) == 1
|
|
1053
|
+
return all_new_data_idx[new_sources[0]][prev_dat_idx_key]
|
|
1054
|
+
|
|
1055
|
+
@staticmethod
|
|
1056
|
+
def __as_sequence(seq: int | Iterable[int]) -> Iterable[int]:
|
|
1057
|
+
if isinstance(seq, int):
|
|
1058
|
+
yield seq
|
|
1059
|
+
else:
|
|
1060
|
+
yield from seq
|
|
1061
|
+
|
|
1062
|
+    def __update_loop_downstream_data_idx(
+        self,
+        parent_loop_indices: Mapping[str, int],
+    ):
+        # update data indices of loop-downstream tasks that depend on task outputs from
+        # this loop:
+
+        # keys: iter or run ID; values: dict of param type and new parameter index
+        iter_new_data_idx: dict[int, DataIndex] = defaultdict(dict)
+        run_new_data_idx: dict[int, DataIndex] = defaultdict(dict)
+
+        param_sources = self.workflow.get_all_parameter_sources()
+
+        # keys are parameter type, then task insert ID, then data index keys mapping to
+        # their updated values:
+        all_updates: dict[str, dict[int, dict[int, int]]] = defaultdict(
+            lambda: defaultdict(dict)
+        )
+
+        for task in self.downstream_tasks:
+            for elem in task.elements:
+                for param_typ, param_out_task_iID in self.output_parameters.items():
+                    if param_typ in task.template.all_schema_input_types:
+                        # this element's input *might* need updating, but only if it has
+                        # a task input source type that is this loop's output task for
+                        # this parameter:
+                        elem_src = elem.input_sources[f"inputs.{param_typ}"]
+                        if (
+                            elem_src.source_type is InputSourceType.TASK
+                            and elem_src.task_source_type is TaskSourceType.OUTPUT
+                            and elem_src.task_ref == param_out_task_iID
+                        ):
+                            for iter_i in elem.iterations:
+                                # do not modify element-iterations of previous iterations
+                                # of the current loop:
+                                skip_iter = False
+                                for k, v in parent_loop_indices.items():
+                                    if iter_i.loop_idx.get(k) != v:
+                                        skip_iter = True
+                                        break
+
+                                if skip_iter:
+                                    continue
+
+                                # update the iteration data index and any pending runs:
+                                iter_old_di = iter_i.data_idx[f"inputs.{param_typ}"]
+
+                                is_group = True
+                                if not isinstance(iter_old_di, list):
+                                    is_group = False
+                                    iter_old_di = [iter_old_di]
+
+                                iter_old_run_source = [
+                                    param_sources[i]["EAR_ID"] for i in iter_old_di
+                                ]
+                                iter_old_run_objs = self.workflow.get_EARs_from_IDs(
+                                    iter_old_run_source
+                                )  # TODO: use cache
+
+                                # need to check the run source is actually from the loop
+                                # output task (it could be from a previous iteration of a
+                                # separate loop in this task):
+                                if any(
+                                    i.task.insert_ID != param_out_task_iID
+                                    for i in iter_old_run_objs
+                                ):
+                                    continue
+
+                                iter_new_iters = [
+                                    i.element.iterations[-1] for i in iter_old_run_objs
+                                ]
+
+                                # note: we can cast to int, because output keys never
+                                # have multiple data indices (unlike input keys):
+                                iter_new_dis = [
+                                    cast("int", i.get_data_idx()[f"outputs.{param_typ}"])
+                                    for i in iter_new_iters
+                                ]
+
+                                # keep track of updates so we can also update task-input
+                                # type sources:
+                                all_updates[param_typ][task.insert_ID].update(
+                                    dict(zip(iter_old_di, iter_new_dis))
+                                )
+
+                                iter_new_data_idx[iter_i.id_][f"inputs.{param_typ}"] = (
+                                    iter_new_dis if is_group else iter_new_dis[0]
+                                )
+
+                                for run_j in iter_i.action_runs:
+                                    if run_j.status is EARStatus.pending:
+                                        try:
+                                            old_di = run_j.data_idx[f"inputs.{param_typ}"]
+                                        except KeyError:
+                                            # not all actions will include this input
+                                            continue
+
+                                        is_group = True
+                                        if not isinstance(old_di, list):
+                                            is_group = False
+                                            old_di = [old_di]
+
+                                        old_run_source = [
+                                            param_sources[i]["EAR_ID"] for i in old_di
+                                        ]
+                                        old_run_objs = self.workflow.get_EARs_from_IDs(
+                                            old_run_source
+                                        )  # TODO: use cache
+
+                                        # need to check the run source is actually from
+                                        # the loop output task (it could be from a
+                                        # previous action in this element-iteration):
+                                        if any(
+                                            i.task.insert_ID != param_out_task_iID
+                                            for i in old_run_objs
+                                        ):
+                                            continue
+
+                                        new_iters = [
+                                            i.element.iterations[-1] for i in old_run_objs
+                                        ]
+
+                                        # note: we can cast to int, because output keys
+                                        # never have multiple data indices (unlike input
+                                        # keys):
+                                        new_dis = [
+                                            cast(
+                                                "int",
+                                                i.get_data_idx()[f"outputs.{param_typ}"],
+                                            )
+                                            for i in new_iters
+                                        ]
+
+                                        run_new_data_idx[run_j.id_][
+                                            f"inputs.{param_typ}"
+                                        ] = (new_dis if is_group else new_dis[0])
+
+                        elif (
+                            elem_src.source_type is InputSourceType.TASK
+                            and elem_src.task_source_type is TaskSourceType.INPUT
+                        ):
+                            # parameters that are sourced from inputs of other tasks
+                            # might need to be updated if those other tasks have
+                            # themselves had their data indices updated:
+                            assert elem_src.task_ref
+                            ups_i = all_updates.get(param_typ, {}).get(elem_src.task_ref)
+
+                            if ups_i:
+                                # if a further-downstream task has a task-input source
+                                # that points to this task, this will also need updating:
+                                all_updates[param_typ][task.insert_ID].update(ups_i)
+
+                            else:
+                                continue
+
+                            for iter_i in elem.iterations:
+                                # update the iteration data index and any pending runs:
+                                iter_old_di = iter_i.data_idx[f"inputs.{param_typ}"]
+
+                                is_group = True
+                                if not isinstance(iter_old_di, list):
+                                    is_group = False
+                                    iter_old_di = [iter_old_di]
+
+                                iter_new_dis = [ups_i.get(i, i) for i in iter_old_di]
+
+                                if iter_new_dis != iter_old_di:
+                                    iter_new_data_idx[iter_i.id_][
+                                        f"inputs.{param_typ}"
+                                    ] = (iter_new_dis if is_group else iter_new_dis[0])
+
+                                for run_j in iter_i.action_runs:
+                                    if run_j.status is EARStatus.pending:
+                                        try:
+                                            old_di = run_j.data_idx[f"inputs.{param_typ}"]
+                                        except KeyError:
+                                            # not all actions will include this input
+                                            continue
+
+                                        is_group = True
+                                        if not isinstance(old_di, list):
+                                            is_group = False
+                                            old_di = [old_di]
+
+                                        new_dis = [ups_i.get(i, i) for i in old_di]
+
+                                        if new_dis != old_di:
+                                            run_new_data_idx[run_j.id_][
+                                                f"inputs.{param_typ}"
+                                            ] = (new_dis if is_group else new_dis[0])
+
+        # now update data indices (TODO: including in cache!)
+        if iter_new_data_idx:
+            self.workflow._store.update_iter_data_indices(iter_new_data_idx)
+
+        if run_new_data_idx:
+            self.workflow._store.update_run_data_indices(run_new_data_idx)
+
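The core of this method is the `all_updates` old-to-new index map, built from the loop's output task and then applied to downstream iterations and pending runs while preserving the scalar-vs-group shape. A toy sketch of that application step, with hypothetical values (not the hpcflow API):

```python
# old -> new data-index remapping for one parameter type, as accumulated in
# `all_updates[param_typ][task_insert_ID]` (values here are made up):
updates = {12: 45, 13: 46}

def remap(old: int | list[int], updates: dict[int, int]) -> int | list[int]:
    """Apply an old->new index map, preserving the scalar/list shape and
    leaving unknown indices untouched."""
    if isinstance(old, list):
        # grouped input: remap each member independently
        return [updates.get(i, i) for i in old]
    return updates.get(old, old)

print(remap(12, updates))            # -> 45
print(remap([12, 13, 99], updates))  # -> [45, 46, 99] (99 is not remapped)
```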
+    def test_termination(self, element_iter) -> bool:
+        """Check if a loop should terminate, given the specified completed element
+        iteration."""
+        if self.template.termination:
+            return self.template.termination.test(element_iter)
+        return False
+
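`test_termination` simply delegates to the loop template's termination condition, if one is set. As a rough illustration only, the sketch below shows the shape of an object exposing such a `test` method; the class, `get` accessor, and convergence criterion are all hypothetical, not hpcflow's actual rule API:

```python
class ToleranceTermination:
    """Hypothetical termination condition: terminate once a named output
    falls below a tolerance."""

    def __init__(self, param: str, tol: float):
        self.param = param
        self.tol = tol

    def test(self, element_iter) -> bool:
        # assume the iteration exposes output values by dotted key via `get`:
        return element_iter.get(f"outputs.{self.param}") < self.tol

class FakeIter:
    """Minimal stand-in for an element iteration, for demonstration."""
    def __init__(self, residual):
        self._r = residual
    def get(self, _key):
        return self._r

term = ToleranceTermination("residual", 1e-6)
print(term.test(FakeIter(1e-3)))  # False: keep looping
print(term.test(FakeIter(1e-9)))  # True: loop should terminate
```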
+    @TimeIt.decorator
+    def get_element_IDs(self):
+        elem_IDs = [
+            j
+            for i in self.task_insert_IDs
+            for j in self.workflow.tasks.get(insert_ID=i).element_IDs
+        ]
+        return elem_IDs
+
+    @TimeIt.decorator
+    def get_elements(self):
+        return self.workflow.get_elements_from_IDs(self.get_element_IDs())
+
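The nested comprehension in `get_element_IDs` flattens the per-task element-ID lists in task order. A small standalone equivalence check, with made-up IDs:

```python
from itertools import chain

# hypothetical element IDs per task, keyed by task insert ID:
element_IDs_by_task = {0: [3, 4], 2: [7], 5: [9, 10]}
task_insert_IDs = (0, 2, 5)

# the nested comprehension is equivalent to chain.from_iterable:
flat = [j for i in task_insert_IDs for j in element_IDs_by_task[i]]
assert flat == list(
    chain.from_iterable(element_IDs_by_task[i] for i in task_insert_IDs)
)
print(flat)  # -> [3, 4, 7, 9, 10]
```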
+    @TimeIt.decorator
+    def skip_downstream_iterations(self, elem_iter) -> list[int]:
+        """
+        Set element iterations downstream of the given completed element iteration
+        to skip.
+
+        Parameters
+        ----------
+        elem_iter
+            The element iteration whose subsequent iterations should be skipped.
+
+        Returns
+        -------
+        list[int]
+            The IDs of the runs that were set to skip.
+        """
+        current_iter_idx = elem_iter.loop_idx[self.name]
+        current_task_iID = elem_iter.task.insert_ID
+        self._app.logger.info(
+            f"setting loop {self.name!r} iterations downstream of current iteration "
+            f"index {current_iter_idx} to skip"
+        )
+        elements = self.get_elements()
+
+        # TODO: fix for multiple loop cycles
+        warn(
+            "skip downstream iterations does not work correctly for multiple loop cycles!"
+        )
+
+        to_skip = []
+        for elem in elements:
+            for iter_i in elem.iterations:
+                if iter_i.loop_idx[self.name] > current_iter_idx or (
+                    iter_i.loop_idx[self.name] == current_iter_idx
+                    and iter_i.task.insert_ID > current_task_iID
+                ):
+                    to_skip.extend(iter_i.EAR_IDs_flat)
+
+        self._app.logger.info(
+            f"{len(to_skip)} runs will be set to skip: {shorten_list_str(to_skip)}"
+        )
+        self.workflow.set_EAR_skip({k: SkipReason.LOOP_TERMINATION for k in to_skip})
+
+        return to_skip
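An iteration is skipped here when it is "after" the current one: either it belongs to a later iteration of this loop, or to the same iteration but a task further down the workflow. A standalone sketch of that ordering predicate, with hypothetical argument names:

```python
def is_downstream(
    iter_loop_idx: int,
    iter_task_insert_ID: int,
    current_iter_idx: int,
    current_task_iID: int,
) -> bool:
    """True if an element-iteration follows the current one: a later loop
    iteration, or the same iteration in a later task."""
    return iter_loop_idx > current_iter_idx or (
        iter_loop_idx == current_iter_idx
        and iter_task_insert_ID > current_task_iID
    )

# same loop iteration, later task -> skip:
print(is_downstream(2, 7, 2, 5))  # True
# earlier loop iteration -> keep:
print(is_downstream(1, 9, 2, 5))  # False
```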