hpcflow 0.1.9__py3-none-any.whl → 0.2.0a271__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (275)
  1. hpcflow/__init__.py +2 -11
  2. hpcflow/__pyinstaller/__init__.py +5 -0
  3. hpcflow/__pyinstaller/hook-hpcflow.py +40 -0
  4. hpcflow/_version.py +1 -1
  5. hpcflow/app.py +43 -0
  6. hpcflow/cli.py +2 -462
  7. hpcflow/data/demo_data_manifest/__init__.py +3 -0
  8. hpcflow/data/demo_data_manifest/demo_data_manifest.json +6 -0
  9. hpcflow/data/jinja_templates/test/test_template.txt +8 -0
  10. hpcflow/data/programs/hello_world/README.md +1 -0
  11. hpcflow/data/programs/hello_world/hello_world.c +87 -0
  12. hpcflow/data/programs/hello_world/linux/hello_world +0 -0
  13. hpcflow/data/programs/hello_world/macos/hello_world +0 -0
  14. hpcflow/data/programs/hello_world/win/hello_world.exe +0 -0
  15. hpcflow/data/scripts/__init__.py +1 -0
  16. hpcflow/data/scripts/bad_script.py +2 -0
  17. hpcflow/data/scripts/demo_task_1_generate_t1_infile_1.py +8 -0
  18. hpcflow/data/scripts/demo_task_1_generate_t1_infile_2.py +8 -0
  19. hpcflow/data/scripts/demo_task_1_parse_p3.py +7 -0
  20. hpcflow/data/scripts/do_nothing.py +2 -0
  21. hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
  22. hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
  23. hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
  24. hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
  25. hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
  26. hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
  27. hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
  28. hpcflow/data/scripts/generate_t1_file_01.py +7 -0
  29. hpcflow/data/scripts/import_future_script.py +7 -0
  30. hpcflow/data/scripts/input_file_generator_basic.py +3 -0
  31. hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
  32. hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
  33. hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
  34. hpcflow/data/scripts/main_script_test_direct_in_direct_out.py +6 -0
  35. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
  36. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
  37. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
  38. hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
  39. hpcflow/data/scripts/main_script_test_direct_in_direct_out_all_iters_test.py +15 -0
  40. hpcflow/data/scripts/main_script_test_direct_in_direct_out_env_spec.py +7 -0
  41. hpcflow/data/scripts/main_script_test_direct_in_direct_out_labels.py +8 -0
  42. hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
  43. hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
  44. hpcflow/data/scripts/main_script_test_direct_sub_param_in_direct_out.py +6 -0
  45. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +12 -0
  46. hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
  47. hpcflow/data/scripts/main_script_test_hdf5_in_obj_group.py +12 -0
  48. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +11 -0
  49. hpcflow/data/scripts/main_script_test_json_and_direct_in_json_out.py +14 -0
  50. hpcflow/data/scripts/main_script_test_json_in_json_and_direct_out.py +17 -0
  51. hpcflow/data/scripts/main_script_test_json_in_json_out.py +14 -0
  52. hpcflow/data/scripts/main_script_test_json_in_json_out_labels.py +16 -0
  53. hpcflow/data/scripts/main_script_test_json_in_obj.py +12 -0
  54. hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
  55. hpcflow/data/scripts/main_script_test_json_out_obj.py +10 -0
  56. hpcflow/data/scripts/main_script_test_json_sub_param_in_json_out_labels.py +16 -0
  57. hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
  58. hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
  59. hpcflow/data/scripts/output_file_parser_basic.py +3 -0
  60. hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
  61. hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
  62. hpcflow/data/scripts/parse_t1_file_01.py +4 -0
  63. hpcflow/data/scripts/script_exit_test.py +5 -0
  64. hpcflow/data/template_components/__init__.py +1 -0
  65. hpcflow/data/template_components/command_files.yaml +26 -0
  66. hpcflow/data/template_components/environments.yaml +13 -0
  67. hpcflow/data/template_components/parameters.yaml +14 -0
  68. hpcflow/data/template_components/task_schemas.yaml +139 -0
  69. hpcflow/data/workflows/workflow_1.yaml +5 -0
  70. hpcflow/examples.ipynb +1037 -0
  71. hpcflow/sdk/__init__.py +149 -0
  72. hpcflow/sdk/app.py +4266 -0
  73. hpcflow/sdk/cli.py +1479 -0
  74. hpcflow/sdk/cli_common.py +385 -0
  75. hpcflow/sdk/config/__init__.py +5 -0
  76. hpcflow/sdk/config/callbacks.py +246 -0
  77. hpcflow/sdk/config/cli.py +388 -0
  78. hpcflow/sdk/config/config.py +1410 -0
  79. hpcflow/sdk/config/config_file.py +501 -0
  80. hpcflow/sdk/config/errors.py +272 -0
  81. hpcflow/sdk/config/types.py +150 -0
  82. hpcflow/sdk/core/__init__.py +38 -0
  83. hpcflow/sdk/core/actions.py +3857 -0
  84. hpcflow/sdk/core/app_aware.py +25 -0
  85. hpcflow/sdk/core/cache.py +224 -0
  86. hpcflow/sdk/core/command_files.py +814 -0
  87. hpcflow/sdk/core/commands.py +424 -0
  88. hpcflow/sdk/core/element.py +2071 -0
  89. hpcflow/sdk/core/enums.py +221 -0
  90. hpcflow/sdk/core/environment.py +256 -0
  91. hpcflow/sdk/core/errors.py +1043 -0
  92. hpcflow/sdk/core/execute.py +207 -0
  93. hpcflow/sdk/core/json_like.py +809 -0
  94. hpcflow/sdk/core/loop.py +1320 -0
  95. hpcflow/sdk/core/loop_cache.py +282 -0
  96. hpcflow/sdk/core/object_list.py +933 -0
  97. hpcflow/sdk/core/parameters.py +3371 -0
  98. hpcflow/sdk/core/rule.py +196 -0
  99. hpcflow/sdk/core/run_dir_files.py +57 -0
  100. hpcflow/sdk/core/skip_reason.py +7 -0
  101. hpcflow/sdk/core/task.py +3792 -0
  102. hpcflow/sdk/core/task_schema.py +993 -0
  103. hpcflow/sdk/core/test_utils.py +538 -0
  104. hpcflow/sdk/core/types.py +447 -0
  105. hpcflow/sdk/core/utils.py +1207 -0
  106. hpcflow/sdk/core/validation.py +87 -0
  107. hpcflow/sdk/core/values.py +477 -0
  108. hpcflow/sdk/core/workflow.py +4820 -0
  109. hpcflow/sdk/core/zarr_io.py +206 -0
  110. hpcflow/sdk/data/__init__.py +13 -0
  111. hpcflow/sdk/data/config_file_schema.yaml +34 -0
  112. hpcflow/sdk/data/config_schema.yaml +260 -0
  113. hpcflow/sdk/data/environments_spec_schema.yaml +21 -0
  114. hpcflow/sdk/data/files_spec_schema.yaml +5 -0
  115. hpcflow/sdk/data/parameters_spec_schema.yaml +7 -0
  116. hpcflow/sdk/data/task_schema_spec_schema.yaml +3 -0
  117. hpcflow/sdk/data/workflow_spec_schema.yaml +22 -0
  118. hpcflow/sdk/demo/__init__.py +3 -0
  119. hpcflow/sdk/demo/cli.py +242 -0
  120. hpcflow/sdk/helper/__init__.py +3 -0
  121. hpcflow/sdk/helper/cli.py +137 -0
  122. hpcflow/sdk/helper/helper.py +300 -0
  123. hpcflow/sdk/helper/watcher.py +192 -0
  124. hpcflow/sdk/log.py +288 -0
  125. hpcflow/sdk/persistence/__init__.py +18 -0
  126. hpcflow/sdk/persistence/base.py +2817 -0
  127. hpcflow/sdk/persistence/defaults.py +6 -0
  128. hpcflow/sdk/persistence/discovery.py +39 -0
  129. hpcflow/sdk/persistence/json.py +954 -0
  130. hpcflow/sdk/persistence/pending.py +948 -0
  131. hpcflow/sdk/persistence/store_resource.py +203 -0
  132. hpcflow/sdk/persistence/types.py +309 -0
  133. hpcflow/sdk/persistence/utils.py +73 -0
  134. hpcflow/sdk/persistence/zarr.py +2388 -0
  135. hpcflow/sdk/runtime.py +320 -0
  136. hpcflow/sdk/submission/__init__.py +3 -0
  137. hpcflow/sdk/submission/enums.py +70 -0
  138. hpcflow/sdk/submission/jobscript.py +2379 -0
  139. hpcflow/sdk/submission/schedulers/__init__.py +281 -0
  140. hpcflow/sdk/submission/schedulers/direct.py +233 -0
  141. hpcflow/sdk/submission/schedulers/sge.py +376 -0
  142. hpcflow/sdk/submission/schedulers/slurm.py +598 -0
  143. hpcflow/sdk/submission/schedulers/utils.py +25 -0
  144. hpcflow/sdk/submission/shells/__init__.py +52 -0
  145. hpcflow/sdk/submission/shells/base.py +229 -0
  146. hpcflow/sdk/submission/shells/bash.py +504 -0
  147. hpcflow/sdk/submission/shells/os_version.py +115 -0
  148. hpcflow/sdk/submission/shells/powershell.py +352 -0
  149. hpcflow/sdk/submission/submission.py +1402 -0
  150. hpcflow/sdk/submission/types.py +140 -0
  151. hpcflow/sdk/typing.py +194 -0
  152. hpcflow/sdk/utils/arrays.py +69 -0
  153. hpcflow/sdk/utils/deferred_file.py +55 -0
  154. hpcflow/sdk/utils/hashing.py +16 -0
  155. hpcflow/sdk/utils/patches.py +31 -0
  156. hpcflow/sdk/utils/strings.py +69 -0
  157. hpcflow/tests/api/test_api.py +32 -0
  158. hpcflow/tests/conftest.py +123 -0
  159. hpcflow/tests/data/__init__.py +0 -0
  160. hpcflow/tests/data/benchmark_N_elements.yaml +6 -0
  161. hpcflow/tests/data/benchmark_script_runner.yaml +26 -0
  162. hpcflow/tests/data/multi_path_sequences.yaml +29 -0
  163. hpcflow/tests/data/workflow_1.json +10 -0
  164. hpcflow/tests/data/workflow_1.yaml +5 -0
  165. hpcflow/tests/data/workflow_1_slurm.yaml +8 -0
  166. hpcflow/tests/data/workflow_1_wsl.yaml +8 -0
  167. hpcflow/tests/data/workflow_test_run_abort.yaml +42 -0
  168. hpcflow/tests/jinja_templates/test_jinja_templates.py +161 -0
  169. hpcflow/tests/programs/test_programs.py +180 -0
  170. hpcflow/tests/schedulers/direct_linux/test_direct_linux_submission.py +12 -0
  171. hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
  172. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +14 -0
  173. hpcflow/tests/scripts/test_input_file_generators.py +282 -0
  174. hpcflow/tests/scripts/test_main_scripts.py +1361 -0
  175. hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
  176. hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
  177. hpcflow/tests/shells/wsl/test_wsl_submission.py +14 -0
  178. hpcflow/tests/unit/test_action.py +1066 -0
  179. hpcflow/tests/unit/test_action_rule.py +24 -0
  180. hpcflow/tests/unit/test_app.py +132 -0
  181. hpcflow/tests/unit/test_cache.py +46 -0
  182. hpcflow/tests/unit/test_cli.py +172 -0
  183. hpcflow/tests/unit/test_command.py +377 -0
  184. hpcflow/tests/unit/test_config.py +195 -0
  185. hpcflow/tests/unit/test_config_file.py +162 -0
  186. hpcflow/tests/unit/test_element.py +666 -0
  187. hpcflow/tests/unit/test_element_iteration.py +88 -0
  188. hpcflow/tests/unit/test_element_set.py +158 -0
  189. hpcflow/tests/unit/test_group.py +115 -0
  190. hpcflow/tests/unit/test_input_source.py +1479 -0
  191. hpcflow/tests/unit/test_input_value.py +398 -0
  192. hpcflow/tests/unit/test_jobscript_unit.py +757 -0
  193. hpcflow/tests/unit/test_json_like.py +1247 -0
  194. hpcflow/tests/unit/test_loop.py +2674 -0
  195. hpcflow/tests/unit/test_meta_task.py +325 -0
  196. hpcflow/tests/unit/test_multi_path_sequences.py +259 -0
  197. hpcflow/tests/unit/test_object_list.py +116 -0
  198. hpcflow/tests/unit/test_parameter.py +243 -0
  199. hpcflow/tests/unit/test_persistence.py +664 -0
  200. hpcflow/tests/unit/test_resources.py +243 -0
  201. hpcflow/tests/unit/test_run.py +286 -0
  202. hpcflow/tests/unit/test_run_directories.py +29 -0
  203. hpcflow/tests/unit/test_runtime.py +9 -0
  204. hpcflow/tests/unit/test_schema_input.py +372 -0
  205. hpcflow/tests/unit/test_shell.py +129 -0
  206. hpcflow/tests/unit/test_slurm.py +39 -0
  207. hpcflow/tests/unit/test_submission.py +502 -0
  208. hpcflow/tests/unit/test_task.py +2560 -0
  209. hpcflow/tests/unit/test_task_schema.py +182 -0
  210. hpcflow/tests/unit/test_utils.py +616 -0
  211. hpcflow/tests/unit/test_value_sequence.py +549 -0
  212. hpcflow/tests/unit/test_values.py +91 -0
  213. hpcflow/tests/unit/test_workflow.py +827 -0
  214. hpcflow/tests/unit/test_workflow_template.py +186 -0
  215. hpcflow/tests/unit/utils/test_arrays.py +40 -0
  216. hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
  217. hpcflow/tests/unit/utils/test_hashing.py +65 -0
  218. hpcflow/tests/unit/utils/test_patches.py +5 -0
  219. hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
  220. hpcflow/tests/unit/utils/test_strings.py +97 -0
  221. hpcflow/tests/workflows/__init__.py +0 -0
  222. hpcflow/tests/workflows/test_directory_structure.py +31 -0
  223. hpcflow/tests/workflows/test_jobscript.py +355 -0
  224. hpcflow/tests/workflows/test_run_status.py +198 -0
  225. hpcflow/tests/workflows/test_skip_downstream.py +696 -0
  226. hpcflow/tests/workflows/test_submission.py +140 -0
  227. hpcflow/tests/workflows/test_workflows.py +564 -0
  228. hpcflow/tests/workflows/test_zip.py +18 -0
  229. hpcflow/viz_demo.ipynb +6794 -0
  230. hpcflow-0.2.0a271.dist-info/LICENSE +375 -0
  231. hpcflow-0.2.0a271.dist-info/METADATA +65 -0
  232. hpcflow-0.2.0a271.dist-info/RECORD +237 -0
  233. {hpcflow-0.1.9.dist-info → hpcflow-0.2.0a271.dist-info}/WHEEL +4 -5
  234. hpcflow-0.2.0a271.dist-info/entry_points.txt +6 -0
  235. hpcflow/api.py +0 -458
  236. hpcflow/archive/archive.py +0 -308
  237. hpcflow/archive/cloud/cloud.py +0 -47
  238. hpcflow/archive/cloud/errors.py +0 -9
  239. hpcflow/archive/cloud/providers/dropbox.py +0 -432
  240. hpcflow/archive/errors.py +0 -5
  241. hpcflow/base_db.py +0 -4
  242. hpcflow/config.py +0 -232
  243. hpcflow/copytree.py +0 -66
  244. hpcflow/data/examples/_config.yml +0 -14
  245. hpcflow/data/examples/damask/demo/1.run.yml +0 -4
  246. hpcflow/data/examples/damask/demo/2.process.yml +0 -29
  247. hpcflow/data/examples/damask/demo/geom.geom +0 -2052
  248. hpcflow/data/examples/damask/demo/load.load +0 -1
  249. hpcflow/data/examples/damask/demo/material.config +0 -185
  250. hpcflow/data/examples/damask/inputs/geom.geom +0 -2052
  251. hpcflow/data/examples/damask/inputs/load.load +0 -1
  252. hpcflow/data/examples/damask/inputs/material.config +0 -185
  253. hpcflow/data/examples/damask/profiles/_variable_lookup.yml +0 -21
  254. hpcflow/data/examples/damask/profiles/damask.yml +0 -4
  255. hpcflow/data/examples/damask/profiles/damask_process.yml +0 -8
  256. hpcflow/data/examples/damask/profiles/damask_run.yml +0 -5
  257. hpcflow/data/examples/damask/profiles/default.yml +0 -6
  258. hpcflow/data/examples/thinking.yml +0 -177
  259. hpcflow/errors.py +0 -2
  260. hpcflow/init_db.py +0 -37
  261. hpcflow/models.py +0 -2549
  262. hpcflow/nesting.py +0 -9
  263. hpcflow/profiles.py +0 -455
  264. hpcflow/project.py +0 -81
  265. hpcflow/scheduler.py +0 -323
  266. hpcflow/utils.py +0 -103
  267. hpcflow/validation.py +0 -167
  268. hpcflow/variables.py +0 -544
  269. hpcflow-0.1.9.dist-info/METADATA +0 -168
  270. hpcflow-0.1.9.dist-info/RECORD +0 -45
  271. hpcflow-0.1.9.dist-info/entry_points.txt +0 -8
  272. hpcflow-0.1.9.dist-info/top_level.txt +0 -1
  273. /hpcflow/{archive → data/jinja_templates}/__init__.py +0 -0
  274. /hpcflow/{archive/cloud → data/programs}/__init__.py +0 -0
  275. /hpcflow/{archive/cloud/providers → data/workflows}/__init__.py +0 -0
@@ -0,0 +1,3792 @@
1
+ """
2
+ Tasks are components of workflows.
3
+ """
4
+
5
+ from __future__ import annotations
6
+ from collections import defaultdict
7
+ import copy
8
+ from dataclasses import dataclass, field
9
+ from itertools import chain
10
+ from pathlib import Path
11
+ from typing import NamedTuple, cast, overload, TYPE_CHECKING
12
+ from typing_extensions import override
13
+
14
+ from hpcflow.sdk.typing import hydrate
15
+ from hpcflow.sdk.core.object_list import AppDataList
16
+ from hpcflow.sdk.log import TimeIt
17
+ from hpcflow.sdk.core.app_aware import AppAware
18
+ from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
19
+ from hpcflow.sdk.core.element import ElementGroup
20
+ from hpcflow.sdk.core.enums import InputSourceType, TaskSourceType
21
+ from hpcflow.sdk.core.errors import (
22
+ ContainerKeyError,
23
+ ExtraInputs,
24
+ InapplicableInputSourceElementIters,
25
+ MalformedNestingOrderPath,
26
+ MayNeedObjectError,
27
+ MissingElementGroup,
28
+ MissingInputs,
29
+ NoAvailableElementSetsError,
30
+ NoCoincidentInputSources,
31
+ TaskTemplateInvalidNesting,
32
+ TaskTemplateMultipleInputValues,
33
+ TaskTemplateMultipleSchemaObjectives,
34
+ TaskTemplateUnexpectedInput,
35
+ TaskTemplateUnexpectedSequenceInput,
36
+ UnavailableInputSource,
37
+ UnknownEnvironmentPresetError,
38
+ UnrequiredInputSources,
39
+ UnsetParameterDataError,
40
+ )
41
+ from hpcflow.sdk.core.parameters import ParameterValue
42
+ from hpcflow.sdk.core.utils import (
43
+ get_duplicate_items,
44
+ get_in_container,
45
+ get_item_repeat_index,
46
+ get_relative_path,
47
+ group_by_dict_key_values,
48
+ set_in_container,
49
+ split_param_label,
50
+ )
51
+
52
+ if TYPE_CHECKING:
53
+ from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
54
+ from typing import Any, ClassVar, Literal, TypeVar
55
+ from typing_extensions import Self, TypeAlias, TypeIs
56
+ from ..typing import DataIndex, ParamSource
57
+ from .actions import Action
58
+ from .command_files import InputFile
59
+ from .element import (
60
+ Element,
61
+ ElementIteration,
62
+ ElementFilter,
63
+ ElementParameter,
64
+ _ElementPrefixedParameter as EPP,
65
+ )
66
+ from .parameters import (
67
+ InputValue,
68
+ InputSource,
69
+ ValueSequence,
70
+ MultiPathSequence,
71
+ SchemaInput,
72
+ SchemaOutput,
73
+ ParameterPath,
74
+ )
75
+ from .rule import Rule
76
+ from .task_schema import TaskObjective, TaskSchema, MetaTaskSchema
77
+ from .types import (
78
+ MultiplicityDescriptor,
79
+ RelevantData,
80
+ RelevantPath,
81
+ Resources,
82
+ RepeatsDescriptor,
83
+ )
84
+ from .workflow import Workflow, WorkflowTemplate
85
+
86
+ StrSeq = TypeVar("StrSeq", bound=Sequence[str])
87
+
88
+
89
+ INPUT_SOURCE_TYPES = ("local", "default", "task", "import")
90
+
91
+
92
+ @dataclass
93
+ class InputStatus:
94
+ """Information about a given schema input and its parametrisation within an element
95
+ set.
96
+
97
+ Parameters
98
+ ----------
99
+ has_default
100
+ True if a default value is available.
101
+ is_required
102
+ True if the input is required by one or more actions. An input may not be required
103
+ if it is only used in the generation of input files, and those input files are
104
+ passed to the element set directly.
105
+ is_provided
106
+ True if the input is locally provided in the element set.
107
+
108
+ """
109
+
110
+ #: True if a default value is available.
111
+ has_default: bool
112
+ #: True if the input is required by one or more actions. An input may not be required
113
+ #: if it is only used in the generation of input files, and those input files are
114
+ #: passed to the element set directly.
115
+ is_required: bool
116
+ #: True if the input is locally provided in the element set.
117
+ is_provided: bool
118
+
119
+ @property
120
+ def is_extra(self) -> bool:
121
+ """True if the input is provided but not required."""
122
+ return self.is_provided and not self.is_required
123
+
124
+
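# A minimal sketch of how the three flags combine in `is_extra` (standalone;
# uses only the dataclass defined above):
status = InputStatus(has_default=False, is_required=False, is_provided=True)
assert status.is_extra  # provided locally, but no action consumes it
assert not InputStatus(has_default=True, is_required=True, is_provided=True).is_extra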
125
+ class ElementSet(JSONLike):
126
+ """Class to represent a parameterisation of a new set of elements.
127
+
128
+ Parameters
129
+ ----------
130
+ inputs: list[~hpcflow.app.InputValue]
131
+ Inputs to the set of elements.
132
+ input_files: list[~hpcflow.app.InputFile]
133
+ Input files to the set of elements.
134
+ sequences: list[~hpcflow.app.ValueSequence]
135
+ Input value sequences to parameterise over.
136
+ multi_path_sequences: list[~hpcflow.app.MultiPathSequence]
137
+ Multi-path sequences to parameterise over.
138
+ resources: ~hpcflow.app.ResourceList
139
+ Resources to use for the set of elements.
140
+ repeats: list[dict]
141
+ Description of how to repeat the set of elements.
142
+ groups: list[~hpcflow.app.ElementGroup]
143
+ Groupings in the set of elements.
144
+ input_sources: dict[str, ~hpcflow.app.InputSource]
145
+ Input source descriptors.
146
+ nesting_order: dict[str, int]
147
+ How to handle nesting of iterations.
148
+ env_preset: str
149
+ Which environment preset to use. Don't use at same time as ``environments``.
150
+ environments: dict
151
+ Environment descriptors to use. Don't use at same time as ``env_preset``.
152
+ sourceable_elem_iters: list[int]
153
+ If specified, a list of global element iteration indices from which inputs for
154
+ the new elements associated with this element set may be sourced. If not
155
+ specified, all workflow element iterations are considered sourceable.
156
+ allow_non_coincident_task_sources: bool
157
+ If True, if more than one parameter is sourced from the same task, then allow
158
+ these sources to come from distinct element sub-sets. If False (default),
159
+ only the intersection of element sub-sets for all parameters is included.
160
+ is_creation: bool
161
+ If True, merge ``environments`` into ``resources`` using the "any" scope, and
162
+ merge sequences belonging to multi-path sequences into the value-sequences list.
163
+ If False, ``environments`` are ignored. This is required on first initialisation,
164
+ but not on subsequent re-initialisation from a persistent workflow.
165
+ """
166
+
167
+ _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
168
+ ChildObjectSpec(
169
+ name="inputs",
170
+ class_name="InputValue",
171
+ is_multiple=True,
172
+ dict_key_attr="parameter",
173
+ dict_val_attr="value",
174
+ parent_ref="_element_set",
175
+ ),
176
+ ChildObjectSpec(
177
+ name="input_files",
178
+ class_name="InputFile",
179
+ is_multiple=True,
180
+ dict_key_attr="file",
181
+ dict_val_attr="path",
182
+ parent_ref="_element_set",
183
+ ),
184
+ ChildObjectSpec(
185
+ name="resources",
186
+ class_name="ResourceList",
187
+ parent_ref="_element_set",
188
+ ),
189
+ ChildObjectSpec(
190
+ name="sequences",
191
+ class_name="ValueSequence",
192
+ is_multiple=True,
193
+ parent_ref="_element_set",
194
+ ),
195
+ ChildObjectSpec(
196
+ name="multi_path_sequences",
197
+ class_name="MultiPathSequence",
198
+ is_multiple=True,
199
+ parent_ref="_element_set",
200
+ ),
201
+ ChildObjectSpec(
202
+ name="input_sources",
203
+ class_name="InputSource",
204
+ is_multiple=True,
205
+ is_dict_values=True,
206
+ is_dict_values_ensure_list=True,
207
+ ),
208
+ ChildObjectSpec(
209
+ name="groups",
210
+ class_name="ElementGroup",
211
+ is_multiple=True,
212
+ ),
213
+ )
214
+
215
+ def __init__(
216
+ self,
217
+ inputs: list[InputValue] | dict[str, Any] | None = None,
218
+ input_files: list[InputFile] | None = None,
219
+ sequences: list[ValueSequence] | None = None,
220
+ multi_path_sequences: list[MultiPathSequence] | None = None,
221
+ resources: Resources = None,
222
+ repeats: list[RepeatsDescriptor] | int | None = None,
223
+ groups: list[ElementGroup] | None = None,
224
+ input_sources: dict[str, list[InputSource]] | None = None,
225
+ nesting_order: dict[str, float] | None = None,
226
+ env_preset: str | None = None,
227
+ environments: Mapping[str, Mapping[str, Any]] | None = None,
228
+ sourceable_elem_iters: list[int] | None = None,
229
+ allow_non_coincident_task_sources: bool = False,
230
+ is_creation: bool = True,
231
+ ):
232
+ #: Inputs to the set of elements.
233
+ self.inputs = self.__decode_inputs(inputs or [])
234
+ #: Input files to the set of elements.
235
+ self.input_files = input_files or []
236
+ #: Description of how to repeat the set of elements.
237
+ self.repeats = self.__decode_repeats(repeats or [])
238
+ #: Groupings in the set of elements.
239
+ self.groups = groups or []
240
+ #: Resources to use for the set of elements.
241
+ self.resources = self._app.ResourceList.normalise(resources)
242
+ #: Input value sequences to parameterise over.
243
+ self.sequences = sequences or []
244
+ #: Input value multi-path sequences to parameterise over.
245
+ self.multi_path_sequences = multi_path_sequences or []
246
+ #: Input source descriptors.
247
+ self.input_sources = input_sources or {}
248
+ #: How to handle nesting of iterations.
249
+ self.nesting_order = nesting_order or {}
250
+ #: Which environment preset to use.
251
+ self.env_preset = env_preset
252
+ #: Environment descriptors to use.
253
+ self.environments = environments
254
+ #: List of global element iteration indices from which inputs for
255
+ #: the new elements associated with this element set may be sourced.
256
+ #: If ``None``, all iterations are valid.
257
+ self.sourceable_elem_iters = sourceable_elem_iters
258
+ #: Whether to allow sources to come from distinct element sub-sets.
259
+ self.allow_non_coincident_task_sources = allow_non_coincident_task_sources
260
+ #: Whether this initialisation is the first for this data (i.e. not a
261
+ #: reconstruction from persistent workflow data), in which case, we merge
262
+ #: ``environments`` into ``resources`` using the "any" scope, and merge any multi-
263
+ #: path sequences into the sequences list.
264
+ self.is_creation = is_creation
265
+ self.original_input_sources: dict[str, list[InputSource]] | None = None
266
+ self.original_nesting_order: dict[str, float] | None = None
267
+
268
+ self._validate()
269
+ self._set_parent_refs()
270
+
271
+ # assigned by parent Task
272
+ self._task_template: Task | None = None
273
+ # assigned on _task_template assignment
274
+ self._defined_input_types: set[str] | None = None
275
+ # assigned by WorkflowTask._add_element_set
276
+ self._element_local_idx_range: list[int] | None = None
277
+
278
+ if self.is_creation:
279
+
280
+ # merge `environments` into element set resources (this mutates `resources`, and
281
+ # should only happen on creation of the element set, not re-initialisation from a
282
+ # persistent workflow):
283
+ if self.environments:
284
+ self.resources.merge_one(
285
+ self._app.ResourceSpec(scope="any", environments=self.environments)
286
+ )
287
+ # note: `env_preset` is merged into resources by the Task init.
288
+
289
+ # merge sequences belonging to multi-path sequences into the value-sequences list:
290
+ if self.multi_path_sequences:
291
+ for mp_seq in self.multi_path_sequences:
292
+ mp_seq._move_to_sequence_list(self.sequences)
293
+
294
+ self.is_creation = False
295
+
296
+ def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
297
+ dct = self.to_dict()
298
+ orig_inp = dct.pop("original_input_sources", None)
299
+ orig_nest = dct.pop("original_nesting_order", None)
300
+ elem_local_idx_range = dct.pop("_element_local_idx_range", None)
301
+ obj = self.__class__(**copy.deepcopy(dct, memo))
302
+ obj._task_template = self._task_template
303
+ obj._defined_input_types = self._defined_input_types
304
+ obj.original_input_sources = copy.deepcopy(orig_inp)
305
+ obj.original_nesting_order = copy.copy(orig_nest)
306
+ obj._element_local_idx_range = copy.copy(elem_local_idx_range)
307
+ return obj
308
+
309
+ def __eq__(self, other: Any) -> bool:
310
+ if not isinstance(other, self.__class__):
311
+ return False
312
+ return self.to_dict() == other.to_dict()
313
+
314
+ @classmethod
315
+ def _json_like_constructor(cls, json_like) -> Self:
316
+ """Invoked by `JSONLike.from_json_like` instead of `__init__`."""
317
+ orig_inp = json_like.pop("original_input_sources", None)
318
+ orig_nest = json_like.pop("original_nesting_order", None)
319
+ elem_local_idx_range = json_like.pop("_element_local_idx_range", None)
320
+ obj = cls(**json_like)
321
+ obj.original_input_sources = orig_inp
322
+ obj.original_nesting_order = orig_nest
323
+ obj._element_local_idx_range = elem_local_idx_range
324
+ return obj
325
+
326
+ def prepare_persistent_copy(self) -> Self:
327
+ """Return a copy of self, which will then be made persistent, and save copies of
328
+ attributes that may be changed during integration with the workflow."""
329
+ obj = copy.deepcopy(self)
330
+ obj.original_nesting_order = self.nesting_order
331
+ obj.original_input_sources = self.input_sources
332
+ return obj
333
+
334
+ @override
335
+ def _postprocess_to_dict(self, d: dict[str, Any]) -> dict[str, Any]:
336
+ dct = super()._postprocess_to_dict(d)
337
+ del dct["_defined_input_types"]
338
+ del dct["_task_template"]
339
+ return dct
340
+
341
+ @property
342
+ def task_template(self) -> Task:
343
+ """
344
+ The abstract task this was derived from.
345
+ """
346
+ assert self._task_template is not None
347
+ return self._task_template
348
+
349
+ @task_template.setter
350
+ def task_template(self, value: Task) -> None:
351
+ self._task_template = value
352
+ self.__validate_against_template()
353
+
354
+ @property
355
+ def input_types(self) -> list[str]:
356
+ """
357
+ The input types of the inputs to this element set.
358
+ """
359
+ return [in_.labelled_type for in_ in self.inputs]
360
+
361
+ @property
362
+ def element_local_idx_range(self) -> tuple[int, ...]:
363
+ """Indices of elements belonging to this element set."""
364
+ return tuple(self._element_local_idx_range or ())
365
+
366
+ @classmethod
367
+ def __decode_inputs(
368
+ cls, inputs: list[InputValue] | dict[str, Any]
369
+ ) -> list[InputValue]:
370
+ """support inputs passed as a dict"""
371
+ if isinstance(inputs, dict):
372
+ _inputs: list[InputValue] = []
373
+ for k, v in inputs.items():
374
+ param, label = split_param_label(k)
375
+ assert param
376
+ path = None
377
+ if "." in param:
378
+ param, path = param.split(".")
379
+ assert param is not None
380
+ _inputs.append(
381
+ cls._app.InputValue(parameter=param, label=label, path=path, value=v)
382
+ )
383
+ return _inputs
384
+ else:
385
+ return inputs
386
+
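# Sketch of the dict form accepted for `inputs` (assuming the application
# object is importable as `hf`, e.g. via `hpcflow.app`; parameter names here
# are hypothetical). Keys may carry a label in square brackets and a single
# sub-path component; the mapping form is decoded to the list form:
es_a = hf.ElementSet(inputs={"p1": 101, "p1[one]": 102, "p2.a": 103})
es_b = hf.ElementSet(
    inputs=[
        hf.InputValue(parameter="p1", value=101),
        hf.InputValue(parameter="p1", label="one", value=102),
        hf.InputValue(parameter="p2", path="a", value=103),
    ]
)
assert es_a == es_b  # ElementSet equality is `to_dict`-based, see __eq__ below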
387
+ @classmethod
388
+ def __decode_repeats(
389
+ cls, repeats: list[RepeatsDescriptor] | int
390
+ ) -> list[RepeatsDescriptor]:
391
+ # support repeats as an int:
392
+ if isinstance(repeats, int):
393
+ return [
394
+ {
395
+ "name": "",
396
+ "number": repeats,
397
+ "nesting_order": 0.0,
398
+ }
399
+ ]
400
+ else:
401
+ return repeats
402
+
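# Sketch: an integer `repeats` is normalised to a single descriptor (same
# hypothetical `hf` application object as above):
es = hf.ElementSet(repeats=3)
assert es.repeats == [{"name": "", "number": 3, "nesting_order": 0.0}]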
403
+ _ALLOWED_NESTING_PATHS: ClassVar[frozenset[str]] = frozenset(
404
+ {"inputs", "resources", "repeats"}
405
+ )
406
+
407
+ def _validate(self) -> None:
408
+ # check `nesting_order` paths:
409
+ for k in self.nesting_order:
410
+ if k.split(".")[0] not in self._ALLOWED_NESTING_PATHS:
411
+ raise MalformedNestingOrderPath(k, self._ALLOWED_NESTING_PATHS)
412
+
413
+ inp_paths = [in_.normalised_inputs_path for in_ in self.inputs]
414
+ if dup_paths := get_duplicate_items(inp_paths):
415
+ raise TaskTemplateMultipleInputValues(
416
+ f"The following inputs parameters are associated with multiple input value "
417
+ f"definitions: {dup_paths!r}."
418
+ )
419
+
420
+ inp_seq_paths = [
421
+ cast("str", seq.normalised_inputs_path)
422
+ for seq in self.sequences
423
+ if seq.input_type
424
+ ]
425
+ if dup_paths := get_duplicate_items(inp_seq_paths):
426
+ raise TaskTemplateMultipleInputValues(
427
+ f"The following input parameters are associated with multiple sequence "
428
+ f"value definitions: {dup_paths!r}."
429
+ )
430
+
431
+ if inp_and_seq := set(inp_paths).intersection(inp_seq_paths):
432
+ raise TaskTemplateMultipleInputValues(
433
+ f"The following input parameters are specified in both the `inputs` and "
434
+ f"`sequences` lists: {list(inp_and_seq)!r}, but must be specified in at "
435
+ f"most one of these."
436
+ )
437
+
438
+ for src_key, sources in self.input_sources.items():
439
+ if not sources:
440
+ raise ValueError(
441
+ f"If specified in `input_sources`, at least one input source must be "
442
+ f"provided for parameter {src_key!r}."
443
+ )
444
+
445
+ # disallow both `env_preset` and `environments` specifications:
446
+ if self.env_preset and self.environments:
447
+ raise ValueError("Specify at most one of `env_preset` and `environments`.")
448
+
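# Sketch of the `nesting_order` path check: only paths rooted at "inputs",
# "resources" or "repeats" pass validation at construction time:
hf.ElementSet(nesting_order={"inputs.p1": 0.0})   # OK: rooted at "inputs"
hf.ElementSet(nesting_order={"outputs.p1": 0.0})  # raises MalformedNestingOrderPath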
449
+ def __validate_against_template(self) -> None:
450
+ expected_types = self.task_template.all_schema_input_types
451
+ if unexpected_types := set(self.input_types) - expected_types:
452
+ raise TaskTemplateUnexpectedInput(unexpected_types)
453
+
454
+ defined_inp_types = set(self.input_types)
455
+ for seq in self.sequences:
456
+ if inp_type := seq.labelled_type:
457
+ if inp_type not in expected_types:
458
+ raise TaskTemplateUnexpectedSequenceInput(
459
+ inp_type, expected_types, seq
460
+ )
461
+ defined_inp_types.add(inp_type)
462
+ if seq.path not in self.nesting_order and seq.nesting_order is not None:
463
+ self.nesting_order[seq.path] = seq.nesting_order
464
+
465
+ for rep_spec in self.repeats:
466
+ if (reps_path_i := f'repeats.{rep_spec["name"]}') not in self.nesting_order:
467
+ self.nesting_order[reps_path_i] = rep_spec["nesting_order"]
468
+
469
+ for k, v in self.nesting_order.items():
470
+ if v < 0:
471
+ raise TaskTemplateInvalidNesting(k, v)
472
+
473
+ self._defined_input_types = defined_inp_types
474
+
475
+ @classmethod
476
+ def ensure_element_sets(
477
+ cls,
478
+ inputs: list[InputValue] | dict[str, Any] | None = None,
479
+ input_files: list[InputFile] | None = None,
480
+ sequences: list[ValueSequence] | None = None,
481
+ multi_path_sequences: list[MultiPathSequence] | None = None,
482
+ resources: Resources = None,
483
+ repeats: list[RepeatsDescriptor] | int | None = None,
484
+ groups: list[ElementGroup] | None = None,
485
+ input_sources: dict[str, list[InputSource]] | None = None,
486
+ nesting_order: dict[str, float] | None = None,
487
+ env_preset: str | None = None,
488
+ environments: Mapping[str, Mapping[str, Any]] | None = None,
489
+ allow_non_coincident_task_sources: bool = False,
490
+ element_sets: list[Self] | None = None,
491
+ sourceable_elem_iters: list[int] | None = None,
492
+ ) -> list[Self]:
493
+ """
494
+ Build a list of element sets after validating the argument combinations.
495
+ """
496
+ args = (
497
+ inputs,
498
+ input_files,
499
+ sequences,
500
+ multi_path_sequences,
501
+ resources,
502
+ repeats,
503
+ groups,
504
+ input_sources,
505
+ nesting_order,
506
+ env_preset,
507
+ environments,
508
+ )
509
+
510
+ if any(arg is not None for arg in args):
511
+ if element_sets is not None:
512
+ raise ValueError(
513
+ "If providing an `element_set`, no other arguments are allowed."
514
+ )
515
+ element_sets = [
516
+ cls(
517
+ *args,
518
+ sourceable_elem_iters=sourceable_elem_iters,
519
+ allow_non_coincident_task_sources=allow_non_coincident_task_sources,
520
+ )
521
+ ]
522
+ else:
523
+ if element_sets is None:
524
+ element_sets = [
525
+ cls(
526
+ *args,
527
+ sourceable_elem_iters=sourceable_elem_iters,
528
+ allow_non_coincident_task_sources=allow_non_coincident_task_sources,
529
+ )
530
+ ]
531
+
532
+ return element_sets
533
+
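# Sketch of the validation above: direct parametrisation arguments and
# `element_sets` are mutually exclusive; with neither, one default set results:
sets = hf.ElementSet.ensure_element_sets(inputs={"p1": 101})  # one new set
sets = hf.ElementSet.ensure_element_sets()                    # one empty default set
sets = hf.ElementSet.ensure_element_sets(
    inputs={"p1": 101},
    element_sets=[hf.ElementSet()],
)  # raises ValueError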
534
+ @property
535
+ def defined_input_types(self) -> set[str]:
536
+ """
537
+ The input types to this element set.
538
+ """
539
+ assert self._defined_input_types is not None
540
+ return self._defined_input_types
541
+
542
+ @property
543
+ def undefined_input_types(self) -> set[str]:
544
+ """
545
+ The input types to the abstract task that aren't related to this element set.
546
+ """
547
+ return self.task_template.all_schema_input_types - self.defined_input_types
548
+
549
+ def get_sequence_from_path(self, sequence_path: str) -> ValueSequence | None:
550
+ """
551
+ Get the value sequence for the given path, if it exists.
552
+ """
553
+ return next((seq for seq in self.sequences if seq.path == sequence_path), None)
554
+
555
+ def get_defined_parameter_types(self) -> list[str]:
556
+ """
557
+ Get the parameter types of this element set.
558
+ """
559
+ out: list[str] = []
560
+ for inp in self.inputs:
561
+ if not inp.is_sub_value:
562
+ out.append(inp.normalised_inputs_path)
563
+ for seq in self.sequences:
564
+ if seq.parameter and not seq.is_sub_value: # ignore resource sequences
565
+ assert seq.normalised_inputs_path is not None
566
+ out.append(seq.normalised_inputs_path)
567
+ return out
568
+
569
+ def get_defined_sub_parameter_types(self) -> list[str]:
570
+ """
571
+ Get the sub-parameter types of this element set.
572
+ """
573
+ out: list[str] = []
574
+ for inp in self.inputs:
575
+ if inp.is_sub_value:
576
+ out.append(inp.normalised_inputs_path)
577
+ for seq in self.sequences:
578
+ if seq.parameter and seq.is_sub_value: # ignore resource sequences
579
+ assert seq.normalised_inputs_path is not None
580
+ out.append(seq.normalised_inputs_path)
581
+ return out
582
+
583
+ def get_locally_defined_inputs(self) -> list[str]:
584
+ """
585
+ Get the input types that this element set defines.
586
+ """
587
+ return self.get_defined_parameter_types() + self.get_defined_sub_parameter_types()
588
+
589
+ @property
590
+ def index(self) -> int | None:
591
+ """
592
+ The index of this element set in its template task's collection of sets.
593
+ """
594
+ return next(
595
+ (
596
+ idx
597
+ for idx, element_set in enumerate(self.task_template.element_sets)
598
+ if element_set is self
599
+ ),
600
+ None,
601
+ )
602
+
603
+ @property
604
+ def task(self) -> WorkflowTask:
605
+ """
606
+ The concrete task corresponding to this element set.
607
+ """
608
+ t = self.task_template.workflow_template
609
+ assert t
610
+ w = t.workflow
611
+ assert w
612
+ i = self.task_template.index
613
+ assert i is not None
614
+ return w.tasks[i]
615
+
616
+ @property
617
+ def elements(self) -> list[Element]:
618
+ """
619
+ The elements in this element set.
620
+ """
621
+ return self.task.elements[slice(*self.element_local_idx_range)]
622
+
623
+ @property
624
+ def element_iterations(self) -> list[ElementIteration]:
625
+ """
626
+ The iterations in this element set.
627
+ """
628
+ return list(chain.from_iterable(elem.iterations for elem in self.elements))
629
+
630
+ @property
631
+ def elem_iter_IDs(self) -> list[int]:
632
+ """
633
+ The IDs of the iterations in this element set.
634
+ """
635
+ return [it.id_ for it in self.element_iterations]
636
+
637
+ @overload
638
+ def get_task_dependencies(self, as_objects: Literal[False] = False) -> set[int]: ...
639
+
640
+ @overload
641
+ def get_task_dependencies(self, as_objects: Literal[True]) -> list[WorkflowTask]: ...
642
+
643
+ def get_task_dependencies(
644
+ self, as_objects: bool = False
645
+ ) -> list[WorkflowTask] | set[int]:
646
+ """Get upstream tasks that this element set depends on."""
647
+ deps: set[int] = set()
648
+ for element in self.elements:
649
+ deps.update(element.get_task_dependencies())
650
+ if as_objects:
651
+ return [self.task.workflow.tasks.get(insert_ID=id_) for id_ in sorted(deps)]
652
+ return deps
653
+
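# Sketch (given an `element_set` of a task within a persistent workflow):
dep_ids = element_set.get_task_dependencies()  # set of task insert IDs
dep_tasks = element_set.get_task_dependencies(as_objects=True)
# WorkflowTask objects, ordered by insert ID per the sorted() call above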
654
+ def is_input_type_provided(self, labelled_path: str) -> bool:
655
+ """Check if an input is provided locally as an InputValue or a ValueSequence."""
656
+ return any(
657
+ labelled_path == inp.normalised_inputs_path for inp in self.inputs
658
+ ) or any(
659
+ seq.parameter
660
+ # i.e. not a resource:
661
+ and labelled_path == seq.normalised_inputs_path
662
+ for seq in self.sequences
663
+ )
664
+
665
+
666
+ @hydrate
667
+ class OutputLabel(JSONLike):
668
+ """
669
+ Schema input labels that should be applied to a subset of task outputs.
670
+
671
+ Parameters
672
+ ----------
673
+ parameter:
674
+ Name of a parameter.
675
+ label:
676
+ Label to apply to the parameter.
677
+ where: ~hpcflow.app.ElementFilter
678
+ Optional filtering rule.
679
+ """
680
+
681
+ _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
682
+ ChildObjectSpec(
683
+ name="where",
684
+ class_name="ElementFilter",
685
+ ),
686
+ )
687
+
688
+ def __init__(
689
+ self,
690
+ parameter: str,
691
+ label: str,
692
+ where: Rule | None = None,
693
+ ) -> None:
694
+ #: Name of a parameter.
695
+ self.parameter = parameter
696
+ #: Label to apply to the parameter.
697
+ self.label = label
698
+ #: Filtering rule.
699
+ self.where = where
700
+
701
+
702
+ @hydrate
703
+ class Task(JSONLike):
704
+ """
705
+ Parametrisation of an isolated task for which a subset of input values is given
706
+ "locally". The remaining input values are expected to be satisfied by other
707
+ tasks/imports in the workflow.
708
+
709
+ Parameters
710
+ ----------
711
+ schema: ~hpcflow.app.TaskSchema | list[~hpcflow.app.TaskSchema]
712
+ A (list of) `TaskSchema` object(s) and/or a (list of) strings that are task
713
+ schema names that uniquely identify a task schema. If strings are provided,
714
+ the `TaskSchema` object will be fetched from the known task schemas loaded by
715
+ the app configuration.
716
+ repeats: list[dict]
717
+ groups: list[~hpcflow.app.ElementGroup]
718
+ resources: dict
719
+ inputs: list[~hpcflow.app.InputValue]
720
+ A list of `InputValue` objects.
721
+ input_files: list[~hpcflow.app.InputFile]
722
+ sequences: list[~hpcflow.app.ValueSequence]
723
+ Input value sequences to parameterise over.
724
+ multi_path_sequences: list[~hpcflow.app.MultiPathSequence]
725
+ Multi-path sequences to parameterise over.
726
+ input_sources: dict[str, ~hpcflow.app.InputSource]
727
+ nesting_order: list
728
+ env_preset: str
729
+ environments: dict[str, dict]
730
+ allow_non_coincident_task_sources: bool
731
+ If True, if more than one parameter is sourced from the same task, then allow
732
+ these sources to come from distinct element sub-sets. If False (default),
733
+ only the intersection of element sub-sets for all parameters is included.
734
+ element_sets: list[ElementSet]
735
+ output_labels: list[OutputLabel]
736
+ sourceable_elem_iters: list[int]
737
+ merge_envs: bool
738
+ If True, merge environment presets (set via the element set `env_preset` key)
739
+ into `resources` using the "any" scope. If False, these presets are ignored.
740
+ This is required on first initialisation, but not on subsequent
741
+ re-initialisation from a persistent workflow.
742
+ """
743
+
744
+ _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
745
+ ChildObjectSpec(
746
+ name="schema",
747
+ class_name="TaskSchema",
748
+ is_multiple=True,
749
+ shared_data_name="task_schemas",
750
+ shared_data_primary_key="name",
751
+ parent_ref="_task_template",
752
+ ),
753
+ ChildObjectSpec(
754
+ name="element_sets",
755
+ class_name="ElementSet",
756
+ is_multiple=True,
757
+ parent_ref="task_template",
758
+ ),
759
+ ChildObjectSpec(
760
+ name="output_labels",
761
+ class_name="OutputLabel",
762
+ is_multiple=True,
763
+ ),
764
+ )
765
+
766
+ @classmethod
767
+ def __is_TaskSchema(cls, value) -> TypeIs[TaskSchema]:
768
+ return isinstance(value, cls._app.TaskSchema)
769
+
770
+ def __init__(
771
+ self,
772
+ schema: TaskSchema | str | list[TaskSchema] | list[str],
773
+ repeats: list[RepeatsDescriptor] | int | None = None,
774
+ groups: list[ElementGroup] | None = None,
775
+ resources: Resources = None,
776
+ inputs: list[InputValue] | dict[str, Any] | None = None,
777
+ input_files: list[InputFile] | None = None,
778
+ sequences: list[ValueSequence] | None = None,
779
+ multi_path_sequences: list[MultiPathSequence] | None = None,
780
+ input_sources: dict[str, list[InputSource]] | None = None,
781
+ nesting_order: dict[str, float] | None = None,
782
+ env_preset: str | None = None,
783
+ environments: Mapping[str, Mapping[str, Any]] | None = None,
784
+ allow_non_coincident_task_sources: bool = False,
785
+ element_sets: list[ElementSet] | None = None,
786
+ output_labels: list[OutputLabel] | None = None,
787
+ sourceable_elem_iters: list[int] | None = None,
788
+ merge_envs: bool = True,
789
+ ):
790
+ # TODO: allow init via specifying objective and/or method and/or implementation
791
+ # (lists of) strs e.g.: Task(
792
+ # objective='simulate_VE_loading',
793
+ # method=['CP_FFT', 'taylor'],
794
+ # implementation=['damask', 'damask']
795
+ # )
796
+ # where method and impl must be single strings or lists of the same length
797
+ # and method/impl are optional/required only if necessary to disambiguate
798
+ #
799
+ # this would be like Task(schemas=[
800
+ # 'simulate_VE_loading_CP_FFT_damask',
801
+ # 'simulate_VE_loading_taylor_damask'
802
+ # ])
803
+
804
+ _schemas: list[TaskSchema] = []
805
+ for item in schema if isinstance(schema, list) else [schema]:
806
+ if isinstance(item, str):
807
+ try:
808
+ _schemas.append(
809
+ self._app.TaskSchema.get_by_key(item)
810
+ ) # TODO: document that we need to use the actual app instance here?
811
+ continue
812
+ except KeyError:
813
+ raise KeyError(f"TaskSchema {item!r} not found.")
814
+ elif self.__is_TaskSchema(item):
815
+ _schemas.append(item)
816
+ else:
817
+ raise TypeError(f"Not a TaskSchema object: {item!r}")
818
+
819
+ self._schemas = _schemas
820
+
821
+ self._element_sets = self._app.ElementSet.ensure_element_sets(
822
+ inputs=inputs,
823
+ input_files=input_files,
824
+ sequences=sequences,
825
+ multi_path_sequences=multi_path_sequences,
826
+ resources=resources,
827
+ repeats=repeats,
828
+ groups=groups,
829
+ input_sources=input_sources,
830
+ nesting_order=nesting_order,
831
+ env_preset=env_preset,
832
+ environments=environments,
833
+ element_sets=element_sets,
834
+ allow_non_coincident_task_sources=allow_non_coincident_task_sources,
835
+ sourceable_elem_iters=sourceable_elem_iters,
836
+ )
837
+ self._output_labels = output_labels or []
838
+ #: Whether to merge ``environments`` into ``resources`` using the "any" scope
839
+ #: on first initialisation.
840
+ self.merge_envs = merge_envs
841
+ self.__groups: AppDataList[ElementGroup] = AppDataList(
842
+ groups or [], access_attribute="name"
843
+ )
844
+
845
+ # appended to when new element sets are added and reset on dump to disk:
846
+ self._pending_element_sets: list[ElementSet] = []
847
+
848
+ self._validate()
849
+ self._name = self.__get_name()
850
+
851
+ #: The template workflow that this task is within.
852
+ self.workflow_template: WorkflowTemplate | None = (
853
+ None # assigned by parent WorkflowTemplate
854
+ )
855
+ self._insert_ID: int | None = None
856
+ self._dir_name: str | None = None
857
+
858
+ if self.merge_envs:
859
+ self.__merge_envs_into_resources()
860
+
861
+ # TODO: consider adding a new element_set; will need to merge new environments?
862
+
863
+ self._set_parent_refs({"schema": "schemas"})
864
+
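# Sketch of the accepted `schema` forms ("my_schema" and "my_other_schema" are
# hypothetical names registered with the app's known task schemas):
t = hf.Task(schema="my_schema", inputs={"p1": 101})   # looked up by name
t = hf.Task(schema=["my_schema", "my_other_schema"])  # multiple schemas; they
# must share an objective, else TaskTemplateMultipleSchemaObjectives is raised;
# an unknown name raises KeyError via TaskSchema.get_by_key above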
865
+ def __merge_envs_into_resources(self) -> None:
866
+ # for each element set, merge `env_preset` into `resources` (this mutates
867
+ # `resources`, and should only happen on creation of the task, not
868
+ # re-initialisation from a persistent workflow):
869
+ self.merge_envs = False
870
+
871
+ # TODO: required so we don't raise below; can be removed once we consider multiple
872
+ # schemas:
873
+ for es in self.element_sets:
874
+ if es.env_preset or any(seq.path == "env_preset" for seq in es.sequences):
875
+ break
876
+ else:
877
+ # No presets
878
+ return
879
+
880
+ try:
881
+ env_presets = self.schema.environment_presets
882
+ except ValueError as e:
883
+ # TODO: consider multiple schemas
884
+ raise NotImplementedError(
885
+ "Cannot merge environment presets into a task with multiple schemas."
886
+ ) from e
887
+
888
+ for es in self.element_sets:
889
+ if es.env_preset:
890
+ # retrieve env specifiers from presets defined in the schema:
891
+ try:
892
+ env_specs = env_presets[es.env_preset] # type: ignore[index]
893
+ except (TypeError, KeyError):
894
+ raise UnknownEnvironmentPresetError(es.env_preset, self.schema.name)
895
+ es.resources.merge_one(
896
+ self._app.ResourceSpec(scope="any", environments=env_specs)
897
+ )
898
+
899
+ for seq in es.sequences:
900
+ if seq.path == "env_preset":
901
+ # change to a resources path:
902
+ seq.path = "resources.any.environments"
903
+ _values = []
904
+ for val in seq.values or ():
905
+ try:
906
+ _values.append(env_presets[val]) # type: ignore[index]
907
+ except (TypeError, KeyError) as e:
908
+ raise UnknownEnvironmentPresetError(
909
+ val, self.schema.name
910
+ ) from e
911
+ seq._values = _values
912
+
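# Sketch of the preset merge (hypothetical schema defining
# environment_presets = {"fast": {...}}; `presets` below stands for that dict):
t = hf.Task(schema="my_schema", env_preset="fast")
# after first initialisation this is roughly equivalent to specifying:
#     resources=[hf.ResourceSpec(scope="any", environments=presets["fast"])]
# and a sequence with path "env_preset" is likewise rewritten to path
# "resources.any.environments", its values mapped through the presets.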
913
+ def _reset_pending_element_sets(self) -> None:
914
+ self._pending_element_sets = []
915
+
916
+ def _accept_pending_element_sets(self) -> None:
917
+ self._element_sets += self._pending_element_sets
918
+ self._reset_pending_element_sets()
919
+
920
+ def __eq__(self, other: Any) -> bool:
921
+ if not isinstance(other, self.__class__):
922
+ return False
923
+ return self.to_dict() == other.to_dict()
924
+
925
+ def _add_element_set(self, element_set: ElementSet):
926
+ """Invoked by WorkflowTask._add_element_set."""
927
+ self._pending_element_sets.append(element_set)
928
+ wt = self.workflow_template
929
+ assert wt
930
+ w = wt.workflow
931
+ assert w
932
+ w._store.add_element_set(
933
+ self.insert_ID, cast("Mapping", element_set.to_json_like()[0])
934
+ )
935
+
936
+ @classmethod
937
+ def _json_like_constructor(cls, json_like: dict) -> Self:
938
+ """Invoked by `JSONLike.from_json_like` instead of `__init__`."""
939
+ insert_ID = json_like.pop("insert_ID", None)
940
+ dir_name = json_like.pop("dir_name", None)
941
+ obj = cls(**json_like)
942
+ obj._insert_ID = insert_ID
943
+ obj._dir_name = dir_name
944
+ return obj
945
+
946
+ def __repr__(self) -> str:
947
+ return f"{self.__class__.__name__}(name={self.name!r})"
948
+
949
+ def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
950
+ kwargs = self.to_dict()
951
+ _insert_ID = kwargs.pop("insert_ID")
952
+ _dir_name = kwargs.pop("dir_name")
953
+ # _pending_element_sets = kwargs.pop("pending_element_sets")
954
+ obj = self.__class__(**copy.deepcopy(kwargs, memo))
955
+ obj._insert_ID = _insert_ID
956
+ obj._dir_name = _dir_name
957
+ obj._name = self._name
958
+ obj.workflow_template = self.workflow_template
959
+ obj._pending_element_sets = self._pending_element_sets
960
+ return obj
961
+
962
+ def to_persistent(
963
+ self, workflow: Workflow, insert_ID: int
964
+ ) -> tuple[Self, list[int | list[int]]]:
965
+ """Return a copy where any schema input defaults are saved to a persistent
966
+ workflow. Element set data is not made persistent."""
967
+
968
+ obj = copy.deepcopy(self)
969
+ source: ParamSource = {"type": "default_input", "task_insert_ID": insert_ID}
970
+ new_refs = list(
971
+ chain.from_iterable(
972
+ schema.make_persistent(workflow, source) for schema in obj.schemas
973
+ )
974
+ )
975
+
976
+ return obj, new_refs
977
+
978
+ @override
979
+ def _postprocess_to_dict(self, d: dict[str, Any]) -> dict[str, Any]:
980
+ out = super()._postprocess_to_dict(d)
981
+ out["_schema"] = out.pop("_schemas")
982
+ res = {
983
+ k.lstrip("_"): v
984
+ for k, v in out.items()
985
+ if k not in ("_name", "_pending_element_sets", "_Task__groups")
986
+ }
987
+ return res
988
+
989
+ def set_sequence_parameters(self, element_set: ElementSet) -> None:
990
+ """
991
+ Set the ``Parameter`` objects on the given element set's value sequences.
992
+ """
993
+ # set ValueSequence Parameter objects:
994
+ for seq in element_set.sequences:
995
+ if seq.input_type:
996
+ for schema_i in self.schemas:
997
+ for inp_j in schema_i.inputs:
998
+ if inp_j.typ == seq.input_type:
999
+ seq._parameter = inp_j.parameter
1000
+
1001
+ def _validate(self) -> None:
1002
+ # TODO: check a nesting order specified for each sequence?
1003
+
1004
+ if len(names := set(schema.objective.name for schema in self.schemas)) > 1:
1005
+ raise TaskTemplateMultipleSchemaObjectives(names)
1006
+
1007
+ def __get_name(self) -> str:
1008
+ out = self.objective.name
1009
+ for idx, schema_i in enumerate(self.schemas, start=1):
1010
+ need_and = idx < len(self.schemas) and (
1011
+ self.schemas[idx].method or self.schemas[idx].implementation
1012
+ )
1013
+ out += (
1014
+ f"{f'_{schema_i.method}' if schema_i.method else ''}"
1015
+ f"{f'_{schema_i.implementation}' if schema_i.implementation else ''}"
1016
+ f"{f'_and' if need_and else ''}"
1017
+ )
1018
+ return out
1019
+
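# Worked example of the name construction: two schemas sharing the objective
# "simulate_VE_loading", with (method, implementation) pairs
# ("CP_FFT", "damask") and ("taylor", "damask") (names borrowed from the TODO
# above), yield:
#     "simulate_VE_loading_CP_FFT_damask_and_taylor_damask"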
1020
+ @staticmethod
1021
+ def get_task_unique_names(tasks: list[Task]) -> Sequence[str]:
1022
+ """Get the unique name of each in a list of tasks.
1023
+
1024
+ Returns
1025
+ -------
1026
+ list of str
1027
+ """
1028
+
1029
+ task_name_rep_idx = get_item_repeat_index(
1030
+ tasks,
1031
+ item_callable=lambda x: x.name,
1032
+ distinguish_singular=True,
1033
+ )
1034
+
1035
+ return [
1036
+ (
1037
+ f"{task.name}_{task_name_rep_idx[idx]}"
1038
+ if task_name_rep_idx[idx] > 0
1039
+ else task.name
1040
+ )
1041
+ for idx, task in enumerate(tasks)
1042
+ ]
1043
+
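# Sketch: repeated names gain a 1-based suffix while unique names are left
# alone (assuming `get_item_repeat_index` numbers repeats from 1 and marks
# singular items with 0, per `distinguish_singular=True`):
#     ["t1", "t2", "t1"]  ->  ["t1_1", "t2", "t1_2"]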
1044
+ @TimeIt.decorator
1045
+ def _prepare_persistent_outputs(
1046
+ self, workflow: Workflow, local_element_idx_range: Sequence[int]
1047
+ ) -> Mapping[str, Sequence[int]]:
1048
+ # TODO: check that schema is present when adding task? (should this be here?)
1049
+
1050
+ # allocate schema-level output parameter; precise EAR index will not be known
1051
+ # until we initialise EARs:
1052
+ output_data_indices: dict[str, list[int]] = {}
1053
+ for schema in self.schemas:
1054
+ for output in schema.outputs:
1055
+ # TODO: consider multiple schemas in action index?
1056
+
1057
+ path = f"outputs.{output.typ}"
1058
+ output_data_indices[path] = [
1059
+ # iteration_idx, action_idx, and EAR_idx are not known until
1060
+ # `initialise_EARs`:
1061
+ workflow._add_unset_parameter_data(
1062
+ {
1063
+ "type": "EAR_output",
1064
+ # "task_insert_ID": self.insert_ID,
1065
+ # "element_idx": idx,
1066
+ # "run_idx": 0,
1067
+ }
1068
+ )
1069
+ for idx in range(*local_element_idx_range)
1070
+ ]
1071
+
1072
+ return output_data_indices
1073
+
1074
+ def prepare_element_resolution(
1075
+ self, element_set: ElementSet, input_data_indices: Mapping[str, Sequence]
1076
+ ) -> list[MultiplicityDescriptor]:
1077
+ """
1078
+ Set up the resolution of details of elements
1079
+ (especially multiplicities and how iterations are nested)
1080
+ within an element set.
1081
+ """
1082
+ multiplicities: list[MultiplicityDescriptor] = [
1083
+ {
1084
+ "multiplicity": len(inp_idx_i),
1085
+ "nesting_order": element_set.nesting_order.get(path_i, -1.0),
1086
+ "path": path_i,
1087
+ }
1088
+ for path_i, inp_idx_i in input_data_indices.items()
1089
+ ]
1090
+
1091
+ # if all inputs with non-unit multiplicity have the same multiplicity and a
1092
+ # default nesting order of -1 or 0 (which will have probably been set by a
1093
+ # `ValueSequence` default), set the non-unit multiplicity inputs to a nesting
1094
+ # order of zero:
1095
+ non_unit_multis: dict[int, int] = {}
1096
+ unit_multis: list[int] = []
1097
+ change = True
1098
+ for idx, descriptor in enumerate(multiplicities):
1099
+ if descriptor["multiplicity"] == 1:
1100
+ unit_multis.append(idx)
1101
+ elif descriptor["nesting_order"] in (-1.0, 0.0):
1102
+ non_unit_multis[idx] = descriptor["multiplicity"]
1103
+ else:
1104
+ change = False
1105
+ break
1106
+
1107
+ if change and len(set(non_unit_multis.values())) == 1:
1108
+ for i_idx in non_unit_multis:
1109
+ multiplicities[i_idx]["nesting_order"] = 0
1110
+
1111
+ return multiplicities
1112
+
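# Worked example: two length-3 sequences with the default (unset) nesting
# order of -1 plus one single-value input give
#     [{"path": "inputs.p1", "multiplicity": 3, "nesting_order": -1.0},
#      {"path": "inputs.p2", "multiplicity": 3, "nesting_order": -1.0},
#      {"path": "inputs.p3", "multiplicity": 1, "nesting_order": -1.0}]
# and, since all non-unit multiplicities agree, the first two are reset to
# nesting order 0, so the sequences are combined element-wise (3 elements)
# rather than nested (9 elements).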
1113
+ @property
1114
+ def index(self) -> int | None:
1115
+ """
1116
+ The index of this task within the workflow's tasks.
1117
+ """
1118
+ if self.workflow_template:
1119
+ return self.workflow_template.tasks.index(self)
1120
+ else:
1121
+ return None
1122
+
1123
+ @property
1124
+ def output_labels(self) -> Sequence[OutputLabel]:
1125
+ """
1126
+ The labels on the outputs of the task.
1127
+ """
1128
+ return self._output_labels
1129
+
1130
+ @property
1131
+ def _element_indices(self) -> list[int] | None:
1132
+ if (
1133
+ self.workflow_template
1134
+ and self.workflow_template.workflow
1135
+ and self.index is not None
1136
+ ):
1137
+ task = self.workflow_template.workflow.tasks[self.index]
1138
+ return [element._index for element in task.elements]
1139
+ return None
1140
+
1141
+ def __get_task_source_element_iters(
1142
+ self, in_or_out: str, src_task: Task, labelled_path: str, element_set: ElementSet
1143
+ ) -> list[int]:
1144
+ """Get a sorted list of element iteration IDs that provide either inputs or
1145
+ outputs from the provided source task."""
1146
+
1147
+ if in_or_out == "input":
1148
+ # input parameter might not be provided e.g. if it is only used
1149
+ # to generate an input file, and that input file is passed
1150
+ # directly, so consider only source task element sets that
1151
+ # provide the input locally:
1152
+ es_idx = src_task.get_param_provided_element_sets(labelled_path)
1153
+ for es_i in src_task.element_sets:
1154
+ # add any element set that has task sources for this parameter
1155
+ es_i_idx = es_i.index
1156
+ if (
1157
+ es_i_idx is not None
1158
+ and es_i_idx not in es_idx
1159
+ and any(
1160
+ inp_src_i.source_type is InputSourceType.TASK
1161
+ for inp_src_i in es_i.input_sources.get(labelled_path, ())
1162
+ )
1163
+ ):
1164
+ es_idx.append(es_i_idx)
1165
+ else:
1166
+ # outputs are always available, so consider all source task
1167
+ # element sets:
1168
+ es_idx = list(range(src_task.num_element_sets))
1169
+
1170
+ if not es_idx:
1171
+ raise NoAvailableElementSetsError()
1172
+
1173
+ src_elem_iters: list[int] = []
1174
+ for es_i_idx in es_idx:
1175
+ es_i = src_task.element_sets[es_i_idx]
1176
+ src_elem_iters.extend(es_i.elem_iter_IDs) # should be sorted already
1177
+
1178
+ if element_set.sourceable_elem_iters is not None:
1179
+ # can only use a subset of element iterations (this is the
1180
+ # case where this element set is generated from an upstream
1181
+ # element set, in which case we only want to consider newly
1182
+ # added upstream elements when adding elements from this
1183
+ # element set):
1184
+ src_elem_iters = sorted(
1185
+ set(element_set.sourceable_elem_iters).intersection(src_elem_iters)
1186
+ )
1187
+
1188
+ return src_elem_iters
1189
+
1190
+ @staticmethod
1191
+ def __get_common_path(labelled_path: str, inputs_path: str) -> str | None:
1192
+ lab_s = labelled_path.split(".")
1193
+ inp_s = inputs_path.split(".")
1194
+ try:
1195
+ get_relative_path(lab_s, inp_s)
1196
+ return labelled_path
1197
+ except ValueError:
1198
+ pass
1199
+ try:
1200
+ get_relative_path(inp_s, lab_s)
1201
+ return inputs_path
1202
+ except ValueError:
1203
+ # no intersection between paths
1204
+ return None
1205
+
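The path-intersection test above can be pictured in isolation. This sketch assumes, based on its use here, that `get_relative_path(a, b)` succeeds exactly when `b` is a prefix of `a`:

```python
def common_path(labelled_path, inputs_path):
    a, b = labelled_path.split("."), inputs_path.split(".")
    if a[: len(b)] == b:
        return labelled_path  # inputs_path is a prefix; keep the deeper path
    if b[: len(a)] == a:
        return inputs_path  # labelled_path is a prefix; keep the deeper path
    return None  # no intersection between the paths

assert common_path("p1.sub", "p1") == "p1.sub"
assert common_path("p1", "p1.sub") == "p1.sub"
assert common_path("p1", "p2") is None
```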
1206
+ @staticmethod
1207
+ def __filtered_iters(wk_task: WorkflowTask, where: Rule) -> list[int]:
1208
+ param_path = cast("str", where.path)
1209
+ param_prefix, param_name, *param_tail = param_path.split(".")
1210
+ src_elem_iters: list[int] = []
1211
+
1212
+ for elem in wk_task.elements:
1213
+ params: EPP = getattr(elem, param_prefix)
1214
+ param: ElementParameter = getattr(params, param_name)
1215
+ param_dat = param.value
1216
+
1217
+ # for the remaining path components, try both getattr and
1218
+ # getitem:
1219
+ for path_k in param_tail:
1220
+ try:
1221
+ param_dat = param_dat[path_k]
1222
+ except TypeError:
1223
+ param_dat = getattr(param_dat, path_k)
1224
+
1225
+ if where._valida_check(param_dat):
1226
+ src_elem_iters.append(elem.iterations[0].id_)
1227
+
1228
+ return src_elem_iters
1229
+
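The tail traversal above is a generic dotted-path lookup over mixed mappings and objects; a minimal sketch of the same pattern:

```python
def resolve_tail(obj, tail):
    # try item access first; a TypeError signals the object does not
    # support subscripting, so fall back to attribute access:
    for part in tail:
        try:
            obj = obj[part]
        except TypeError:
            obj = getattr(obj, part)
    return obj

assert resolve_tail({"a": {"b": 1}}, ["a", "b"]) == 1
```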
1230
+ def __get_task_out_available_src_insert_idx(
1231
+ self,
1232
+ param_typ: str,
1233
+ existing_sources: list[InputSource],
1234
+ source_tasks: Sequence[WorkflowTask] = (),
1235
+ ) -> int:
1236
+ """Decide where to place a new task-output input source in the list of available
1237
+ sources.
1238
+
1239
+ Available sources should be ordered by precedence. In general, task-type input
1240
+ sources from tasks closer to the new task should take precedence, and for
1241
+ task-type input sources from a given task, task-output sources should take
1242
+ precedence over task-input sources. However, task-output sources that are further
1243
+ away (more upstream) should also take precedence over task-input sources that are
1244
+ closer (more downstream) in the case where the task-input sources do not point to
1245
+ any local inputs.
1246
+
1247
+ For example, consider finding the list of available input sources for task t3,
1248
+ given these tasks and input/output parameters:
1249
+
1250
+ t1: p1 -> p2
1251
+ t2: p2 -> p3
1252
+ t3: p2 -> p4
1253
+
1254
+ There are two possible sources for p2 in task t3: t1(output) and t2(input). If the
1255
+ input source of p2 within task t2 is t1(output) as well, it makes sense for the
1256
+ t1(output) source for p2 in task t3 to take precedence over the t2(input) source.
1257
+ For example, we might define a sequence on task t2, which would mean the input
1258
+ source t2(input) for p2 in t3 would then have multiple values, which could be
1259
+ unexpected, given that only one value was generated by t1. On the other hand, if
1260
+ the user includes some local input source for p2 within task t2 (a single value or
1261
+ a sequence, perhaps combined with other sources), then it arguably makes more
1262
+ sense for the t2(input) source to take precedence, because it is explicitly
1263
+ defined.
1264
+
1265
+ """
1266
+ src_tasks = {task.insert_ID: task for task in source_tasks}
1267
+ num_existing = len(existing_sources)
1268
+ new_idx = num_existing
1269
+ for rev_idx, ex_src in enumerate(existing_sources[::-1]):
1270
+
1271
+ # stop once we reach another task-output source or a local source:
1272
+ if (
1273
+ ex_src.task_source_type == TaskSourceType.OUTPUT
1274
+ or ex_src.source_type == InputSourceType.LOCAL
1275
+ ):
1276
+ return num_existing - rev_idx
1277
+
1278
+ elif ex_src.task_source_type == TaskSourceType.INPUT:
1279
+ assert ex_src.task_ref is not None
1280
+ has_local = self._app.InputSource.local() in [
1281
+ src
1282
+ for es in src_tasks[ex_src.task_ref].template.element_sets
1283
+ for src in es.input_sources.get(param_typ, [])
1284
+ ]
1285
+ if has_local:
1286
+ return num_existing - rev_idx
1287
+ else:
1288
+ # new task-output source should be inserted before this task-input
1289
+ # source, even though the task-input source is closer to the new task:
1290
+ new_idx = num_existing - rev_idx - 1
1291
+
1292
+ return new_idx
1293
+
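The two orderings described in the docstring fall out of the reverse walk above; a toy rendering (editor's sketch: strings stand in for `InputSource` objects, and the helper name is hypothetical):

```python
def insert_task_output_source(existing, new_src, nearer_task_has_local):
    new_idx = len(existing)
    # walk existing sources from the end, as the method above does:
    for rev_idx, src in enumerate(reversed(existing)):
        if src.endswith("(output)") or src == "local":
            new_idx = len(existing) - rev_idx  # insert just after it
            break
        if src.endswith("(input)"):
            if nearer_task_has_local:
                new_idx = len(existing) - rev_idx  # defer to explicit input
                break
            new_idx = len(existing) - rev_idx - 1  # jump ahead of it
    existing.insert(new_idx, new_src)
    return existing

# t2 has no local p2 value: the further t1(output) takes precedence...
assert insert_task_output_source(["t2(input)"], "t1(output)", False) == [
    "t1(output)", "t2(input)"]
# ...but an explicitly defined local p2 on t2 wins:
assert insert_task_output_source(["t2(input)"], "t1(output)", True) == [
    "t2(input)", "t1(output)"]
```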
1294
+ @TimeIt.decorator
1295
+ def get_available_task_input_sources(
1296
+ self,
1297
+ element_set: ElementSet,
1298
+ source_tasks: Sequence[WorkflowTask] = (),
1299
+ ) -> Mapping[str, Sequence[InputSource]]:
1300
+ """For each input parameter of this task, generate a list of possible input sources
1301
+ that derive from inputs or outputs of this and other provided tasks.
1302
+
1303
+ Note this only produces a subset of available input sources for each input
1304
+ parameter; other available input sources may exist from workflow imports."""
1305
+
1306
+ # ensure parameters provided by later tasks are added to the available sources
1307
+ # list first, meaning they take precedence when choosing an input source:
1308
+ source_tasks = sorted(source_tasks, key=lambda x: x.index, reverse=True)
1309
+
1310
+ available: dict[str, list[InputSource]] = {}
1311
+ for inputs_path, inp_status in self.get_input_statuses(element_set).items():
1312
+ # local specification takes precedence:
1313
+ if inputs_path in element_set.get_locally_defined_inputs():
1314
+ available.setdefault(inputs_path, []).append(
1315
+ self._app.InputSource.local()
1316
+ )
1317
+
1318
+ # search for task sources:
1319
+ for src_wk_task_i in source_tasks:
1320
+ # ensure we process output types before input types, so they appear in the
1321
+ # available sources list first, meaning they take precedence when choosing
1322
+ # an input source:
1323
+ src_task_i = src_wk_task_i.template
1324
+ for in_or_out, labelled_path in sorted(
1325
+ src_task_i.provides_parameters(),
1326
+ key=lambda x: x[0],
1327
+ reverse=True,
1328
+ ):
1329
+ src_elem_iters: list[int] = []
1330
+ common = self.__get_common_path(labelled_path, inputs_path)
1331
+ if common is not None:
1332
+ avail_src_path = common
1333
+ else:
1334
+ # no intersection between paths
1335
1336
+ out_label = None
1337
+ unlabelled, inputs_path_label = split_param_label(inputs_path)
1338
+ if unlabelled is None:
1339
+ continue
1340
+ try:
1341
+ get_relative_path(
1342
+ unlabelled.split("."), labelled_path.split(".")
1343
+ )
1344
+ avail_src_path = inputs_path
1345
+ except ValueError:
1346
+ continue
1347
+ if not inputs_path_label:
1348
+ continue
1349
+ for out_lab_i in src_task_i.output_labels:
1350
+ if out_lab_i.label == inputs_path_label:
1351
+ out_label = out_lab_i
1352
+
1353
+ # consider output labels
1354
+ if out_label and in_or_out == "output":
1355
+ # find element iteration IDs that match the output label
1356
+ # filter:
1357
+ if out_label.where:
1358
+ src_elem_iters = self.__filtered_iters(
1359
+ src_wk_task_i, out_label.where
1360
+ )
1361
+ else:
1362
+ src_elem_iters = [
1363
+ elem_i.iterations[0].id_
1364
+ for elem_i in src_wk_task_i.elements
1365
+ ]
1366
+
1367
+ if not src_elem_iters:
1368
+ try:
1369
+ src_elem_iters = self.__get_task_source_element_iters(
1370
+ in_or_out=in_or_out,
1371
+ src_task=src_task_i,
1372
+ labelled_path=labelled_path,
1373
+ element_set=element_set,
1374
+ )
1375
+ except NoAvailableElementSetsError:
1376
+ continue
1377
+ if not src_elem_iters:
1378
+ continue
1379
+
1380
+ if in_or_out == "output":
1381
+ insert_idx = self.__get_task_out_available_src_insert_idx(
1382
+ existing_sources=available.get(avail_src_path, []),
1383
+ source_tasks=source_tasks,
1384
+ param_typ=avail_src_path,
1385
+ )
1386
+ else:
1387
+ insert_idx = len(available.get(avail_src_path, []))
1388
+
1389
+ available.setdefault(avail_src_path, []).insert(
1390
+ insert_idx,
1391
+ self._app.InputSource.task(
1392
+ task_ref=src_task_i.insert_ID,
1393
+ task_source_type=in_or_out,
1394
+ element_iters=src_elem_iters,
1395
+ ),
1396
+ )
1397
+
1398
+ if inp_status.has_default:
1399
+ available.setdefault(inputs_path, []).append(
1400
+ self._app.InputSource.default()
1401
+ )
1402
+ return available
1403
+
1404
+ @property
1405
+ def schemas(self) -> list[TaskSchema]:
1406
+ """
1407
+ All the task schemas.
1408
+ """
1409
+ return self._schemas
1410
+
1411
+ @property
1412
+ def schema(self) -> TaskSchema:
1413
+ """The single task schema, if only one, else raises."""
1414
+ if len(self._schemas) == 1:
1415
+ return self._schemas[0]
1416
+ else:
1417
+ raise ValueError(
1418
+ "Multiple task schemas are associated with this task. Access the list "
1419
+ "via the `schemas` property."
1420
+ )
1421
+
1422
+ @property
1423
+ def element_sets(self) -> list[ElementSet]:
1424
+ """
1425
+ The element sets.
1426
+ """
1427
+ return self._element_sets + self._pending_element_sets
1428
+
1429
+ @property
1430
+ def num_element_sets(self) -> int:
1431
+ """
1432
+ The number of element sets.
1433
+ """
1434
+ return len(self._element_sets) + len(self._pending_element_sets)
1435
+
1436
+ @property
1437
+ def insert_ID(self) -> int:
1438
+ """
1439
+ Insertion ID.
1440
+ """
1441
+ assert self._insert_ID is not None
1442
+ return self._insert_ID
1443
+
1444
+ @property
1445
+ def dir_name(self) -> str:
1446
+ """
1447
+ Artefact directory name.
1448
+ """
1449
+ assert self._dir_name is not None
1450
+ return self._dir_name
1451
+
1452
+ @property
1453
+ def name(self) -> str:
1454
+ """
1455
+ Task name.
1456
+ """
1457
+ return self._name
1458
+
1459
+ @property
1460
+ def objective(self) -> TaskObjective:
1461
+ """
1462
+ The goal of this task.
1463
+ """
1464
+ obj = self.schemas[0].objective
1465
+ return obj
1466
+
1467
+ @property
1468
+ def all_schema_inputs(self) -> tuple[SchemaInput, ...]:
1469
+ """
1470
+ The inputs to this task's schemas.
1471
+ """
1472
+ return tuple(inp_j for schema_i in self.schemas for inp_j in schema_i.inputs)
1473
+
1474
+ @property
1475
+ def all_schema_outputs(self) -> tuple[SchemaOutput, ...]:
1476
+ """
1477
+ The outputs from this task's schemas.
1478
+ """
1479
+ return tuple(inp_j for schema_i in self.schemas for inp_j in schema_i.outputs)
1480
+
1481
+ @property
1482
+ def all_schema_input_types(self) -> set[str]:
1483
+ """
1484
+ The set of all schema input types (over all specified schemas).
1485
+ """
1486
+ return {inp_j for schema_i in self.schemas for inp_j in schema_i.input_types}
1487
+
1488
+ @property
1489
+ def all_schema_input_normalised_paths(self) -> set[str]:
1490
+ """
1491
+ Normalised paths for all schema input types.
1492
+ """
1493
+ return {f"inputs.{typ}" for typ in self.all_schema_input_types}
1494
+
1495
+ @property
1496
+ def all_schema_output_types(self) -> set[str]:
1497
+ """
1498
+ The set of all schema output types (over all specified schemas).
1499
+ """
1500
+ return {out_j for schema_i in self.schemas for out_j in schema_i.output_types}
1501
+
1502
+ def get_schema_action(self, idx: int) -> Action:
1503
+ """
1504
+ Get the schema action at the given index.
1505
+ """
1506
+ _idx = 0
1507
+ for schema in self.schemas:
1508
+ for action in schema.actions:
1509
+ if _idx == idx:
1510
+ return action
1511
+ _idx += 1
1512
+ raise ValueError(f"No action in task {self.name!r} with index {idx!r}.")
1513
+
1514
+ def all_schema_actions(self) -> Iterator[tuple[int, Action]]:
1515
+ """
1516
+ Get all the schema actions and their indices.
1517
+ """
1518
+ idx = 0
1519
+ for schema in self.schemas:
1520
+ for action in schema.actions:
1521
+ yield (idx, action)
1522
+ idx += 1
1523
+
1524
+ @property
1525
+ def num_all_schema_actions(self) -> int:
1526
+ """
1527
+ The total number of schema actions.
1528
+ """
1529
+ return sum(len(schema.actions) for schema in self.schemas)
1530
+
1531
+ @property
1532
+ def all_sourced_normalised_paths(self) -> set[str]:
1533
+ """
1534
+ All the sourced normalised paths, including of sub-values.
1535
+ """
1536
+ sourced_input_types: set[str] = set()
1537
+ for elem_set in self.element_sets:
1538
+ sourced_input_types.update(
1539
+ inp.normalised_path for inp in elem_set.inputs if inp.is_sub_value
1540
+ )
1541
+ sourced_input_types.update(
1542
+ seq.normalised_path for seq in elem_set.sequences if seq.is_sub_value
1543
+ )
1544
+ return sourced_input_types | self.all_schema_input_normalised_paths
1545
+
1546
+ def is_input_type_required(self, typ: str, element_set: ElementSet) -> bool:
1547
+ """Check if an given input type must be specified in the parametrisation of this
1548
+ element set.
1549
+
1550
+ A schema input need not be specified if it is only required to generate an input
1551
+ file, and that input file is passed directly."""
1552
+
1553
+ provided_files = {in_file.file for in_file in element_set.input_files}
1554
+ for schema in self.schemas:
1555
+ if not schema.actions:
1556
+ return True # for empty tasks that are used merely for defining inputs
1557
+ if any(
1558
+ act.is_input_type_required(typ, provided_files) for act in schema.actions
1559
+ ):
1560
+ return True
1561
+
1562
+ return False
1563
+
1564
+ def get_param_provided_element_sets(self, labelled_path: str) -> list[int]:
1565
+ """Get the element set indices of this task for which a specified parameter type
1566
+ is locally provided.
1567
+
1568
+ Note
1569
+ ----
1570
+ Caller may freely modify this result.
1571
+ """
1572
+ return [
1573
+ idx
1574
+ for idx, src_es in enumerate(self.element_sets)
1575
+ if src_es.is_input_type_provided(labelled_path)
1576
+ ]
1577
+
1578
+ def get_input_statuses(self, elem_set: ElementSet) -> Mapping[str, InputStatus]:
1579
+ """Get a dict whose keys are normalised input paths (without the "inputs" prefix),
1580
+ and whose values are InputStatus objects.
1581
+
1582
+ Parameters
1583
+ ----------
1584
+ elem_set
1585
+ The element set for which input statuses should be returned.
1586
+ """
1587
+
1588
+ status: dict[str, InputStatus] = {}
1589
+ for schema_input in self.all_schema_inputs:
1590
+ for lab_info in schema_input.labelled_info():
1591
+ labelled_type = lab_info["labelled_type"]
1592
+ status[labelled_type] = InputStatus(
1593
+ has_default="default_value" in lab_info,
1594
+ is_provided=elem_set.is_input_type_provided(labelled_type),
1595
+ is_required=self.is_input_type_required(labelled_type, elem_set),
1596
+ )
1597
+
1598
+ for inp_path in elem_set.get_defined_sub_parameter_types():
1599
+ root_param = inp_path.split(".")[0]
1600
+ # If the root parameter is required then the sub-parameter should also be
1601
+ # required, otherwise there would be no point in specifying it:
1602
+ status[inp_path] = InputStatus(
1603
+ has_default=False,
1604
+ is_provided=True,
1605
+ is_required=status[root_param].is_required,
1606
+ )
1607
+
1608
+ return status
1609
+
1610
+ @property
1611
+ def universal_input_types(self) -> set[str]:
1612
+ """Get input types that are associated with all schemas"""
1613
+ raise NotImplementedError()
1614
+
1615
+ @property
1616
+ def non_universal_input_types(self) -> set[str]:
1617
+ """Get input types for each schema that are non-universal."""
1618
+ raise NotImplementedError()
1619
+
1620
+ @property
1621
+ def defined_input_types(self) -> set[str]:
1622
+ """
1623
+ The input types defined by this task, being the input types defined by any of
1624
+ its element sets.
1625
+ """
1626
+ dit: set[str] = set()
1627
+ for es in self.element_sets:
1628
+ dit.update(es.defined_input_types)
1629
+ return dit
1630
+ # TODO: Is this right?
1631
+
1632
+ @property
1633
+ def undefined_input_types(self) -> set[str]:
1634
+ """
1635
+ The schema's input types that this task doesn't define.
1636
+ """
1637
+ return self.all_schema_input_types - self.defined_input_types
1638
+
1639
+ @property
1640
+ def undefined_inputs(self) -> list[SchemaInput]:
1641
+ """
1642
+ The task's inputs that are undefined.
1643
+ """
1644
+ return [
1645
+ inp_j
1646
+ for schema_i in self.schemas
1647
+ for inp_j in schema_i.inputs
1648
+ if inp_j.typ in self.undefined_input_types
1649
+ ]
1650
+
1651
+ def provides_parameters(self) -> tuple[tuple[str, str], ...]:
1652
+ """Get all provided parameter labelled types and whether they are inputs and
1653
+ outputs, considering all element sets.
1654
+
1655
+ """
1656
+ out: dict[tuple[str, str], None] = {}
1657
+ for schema in self.schemas:
1658
+ out.update(dict.fromkeys(schema.provides_parameters))
1659
+
1660
+ # add sub-parameter input values and sequences:
1661
+ for es_i in self.element_sets:
1662
+ for inp_j in es_i.inputs:
1663
+ if inp_j.is_sub_value:
1664
+ out["input", inp_j.normalised_inputs_path] = None
1665
+ for seq_j in es_i.sequences:
1666
+ if seq_j.is_sub_value and (path := seq_j.normalised_inputs_path):
1667
+ out["input", path] = None
1668
+
1669
+ return tuple(out)
1670
+
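The `dict.fromkeys` pattern above de-duplicates the (in-or-out, path) pairs while preserving first-seen order, since Python dicts are insertion-ordered:

```python
pairs = [("output", "p2"), ("input", "p1"), ("output", "p2")]
assert tuple(dict.fromkeys(pairs)) == (("output", "p2"), ("input", "p1"))
```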
1671
+ def add_group(
1672
+ self, name: str, where: ElementFilter, group_by_distinct: ParameterPath
1673
+ ):
1674
+ """
1675
+ Add an element group to this task.
1676
+ """
1677
+ group = ElementGroup(name=name, where=where, group_by_distinct=group_by_distinct)
1678
+ self.__groups.add_object(group)
1679
+
1680
+ def _get_single_label_lookup(self, prefix: str = "") -> Mapping[str, str]:
1681
+ """Get a mapping between schema input types that have a single label (i.e.
1682
+ labelled but with `multiple=False`) and the non-labelled type string.
1683
+
1684
+ For example, if a task schema has a schema input like:
1685
+ `SchemaInput(parameter="p1", labels={"one": {}}, multiple=False)`, this method
1686
+ would return a dict that includes: `{"p1[one]": "p1"}`. If the `prefix` argument
1687
+ is provided, this will be added to map key and value (and a terminating period
1688
+ will be added to the end of the prefix if it does not already end in one). For
1689
+ example, with `prefix="inputs"`, this method might return:
1690
+ `{"inputs.p1[one]": "inputs.p1"}`.
1691
+
1692
+ """
1693
+ lookup: dict[str, str] = {}
1694
+ for schema in self.schemas:
1695
+ lookup.update(schema._get_single_label_lookup(prefix=prefix))
1696
+ return lookup
1697
+
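A sketch of the prefix handling described in the docstring (hypothetical helper, not the package API):

```python
def with_prefix(lookup, prefix=""):
    # ensure a terminating period, then prepend to both keys and values:
    if prefix and not prefix.endswith("."):
        prefix += "."
    return {prefix + k: prefix + v for k, v in lookup.items()}

assert with_prefix({"p1[one]": "p1"}, "inputs") == {"inputs.p1[one]": "inputs.p1"}
```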
1698
+
1699
+ class _ESIdx(NamedTuple):
1700
+ ordered: list[int]
1701
+ uniq: frozenset[int]
1702
+
1703
+
1704
+ class WorkflowTask(AppAware):
1705
+ """
1706
+ Represents a :py:class:`Task` that is bound to a :py:class:`Workflow`.
1707
+
1708
+ Parameters
1709
+ ----------
1710
+ workflow:
1711
+ The workflow that the task is bound to.
1712
+ template:
1713
+ The task template that this binds.
1714
+ index:
1715
+ Where in the workflow's list of tasks is this one.
1716
+ element_IDs:
1717
+ The IDs of the elements of this task.
1718
+ """
1719
+
1720
+ def __init__(
1721
+ self,
1722
+ workflow: Workflow,
1723
+ template: Task,
1724
+ index: int,
1725
+ element_IDs: list[int],
1726
+ ):
1727
+ self._workflow = workflow
1728
+ self._template = template
1729
+ self._index = index
1730
+ self._element_IDs = element_IDs
1731
+
1732
+ # appended to when new elements are added and reset on dump to disk:
1733
+ self._pending_element_IDs: list[int] = []
1734
+
1735
+ self._elements: Elements | None = None # assigned on `elements` first access
1736
+
1737
+ def __repr__(self) -> str:
1738
+ return f"{self.__class__.__name__}(name={self.unique_name!r})"
1739
+
1740
+ def _reset_pending_element_IDs(self):
1741
+ self._pending_element_IDs = []
1742
+
1743
+ def _accept_pending_element_IDs(self):
1744
+ self._element_IDs += self._pending_element_IDs
1745
+ self._reset_pending_element_IDs()
1746
+
1747
+ @classmethod
1748
+ def new_empty_task(cls, workflow: Workflow, template: Task, index: int) -> Self:
1749
+ """
1750
+ Make a new instance without any elements set up yet.
1751
+
1752
+ Parameters
1753
+ ----------
1754
+ workflow:
1755
+ The workflow that the task is bound to.
1756
+ template:
1757
+ The task template that this binds.
1758
+ index:
1759
+ Where in the workflow's list of tasks is this one.
1760
+ """
1761
+ return cls(
1762
+ workflow=workflow,
1763
+ template=template,
1764
+ index=index,
1765
+ element_IDs=[],
1766
+ )
1767
+
1768
+ @property
1769
+ def workflow(self) -> Workflow:
1770
+ """
1771
+ The workflow this task is bound to.
1772
+ """
1773
+ return self._workflow
1774
+
1775
+ @property
1776
+ def template(self) -> Task:
1777
+ """
1778
+ The template for this task.
1779
+ """
1780
+ return self._template
1781
+
1782
+ @property
1783
+ def index(self) -> int:
1784
+ """
1785
+ The index of this task within its workflow.
1786
+ """
1787
+ return self._index
1788
+
1789
+ @property
1790
+ def element_IDs(self) -> list[int]:
1791
+ """
1792
+ The IDs of elements associated with this task.
1793
+ """
1794
+ return self._element_IDs + self._pending_element_IDs
1795
+
1796
+ @property
1797
+ @TimeIt.decorator
1798
+ def num_elements(self) -> int:
1799
+ """
1800
+ The number of elements associated with this task.
1801
+ """
1802
+ return len(self._element_IDs) + len(self._pending_element_IDs)
1803
+
1804
+ @property
1805
+ def num_actions(self) -> int:
1806
+ """
1807
+ The number of actions in this task.
1808
+ """
1809
+ return self.template.num_all_schema_actions
1810
+
1811
+ @property
1812
+ def name(self) -> str:
1813
+ """
1814
+ The name of this task based on its template.
1815
+ """
1816
+ return self.template.name
1817
+
1818
+ @property
1819
+ def unique_name(self) -> str:
1820
+ """
1821
+ The unique name for this task specifically.
1822
+ """
1823
+ return self.workflow.get_task_unique_names()[self.index]
1824
+
1825
+ @property
1826
+ def insert_ID(self) -> int:
1827
+ """
1828
+ The insertion ID of the template task.
1829
+ """
1830
+ return self.template.insert_ID
1831
+
1832
+ @property
1833
+ def dir_name(self) -> str:
1834
+ """
1835
+ The name of the directory for the task's temporary files.
1836
+ """
1837
+ dn = self.template.dir_name
1838
+ assert dn is not None
1839
+ return dn
1840
+
1841
+ @property
1842
+ def num_element_sets(self) -> int:
1843
+ """
1844
+ The number of element sets associated with this task.
1845
+ """
1846
+ return self.template.num_element_sets
1847
+
1848
+ @property
1849
+ @TimeIt.decorator
1850
+ def elements(self) -> Elements:
1851
+ """
1852
+ The elements associated with this task.
1853
+ """
1854
+ if self._elements is None:
1855
+ self._elements = Elements(self)
1856
+ return self._elements
1857
+
1858
+ def get_dir_name(self, loop_idx: Mapping[str, int] | None = None) -> str:
1859
+ """
1860
+ Get the directory name for a particular iteration.
1861
+ """
1862
+ if not loop_idx:
1863
+ return self.dir_name
1864
+ return self.dir_name + "_" + "_".join((f"{k}-{v}" for k, v in loop_idx.items()))
1865
+
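The iteration directory name is the base name plus a `key-value` suffix per loop; a standalone sketch of the rule:

```python
def sketch_dir_name(dir_name, loop_idx=None):
    # mirrors the formatting above: no loop index means the bare dir name
    if not loop_idx:
        return dir_name
    return dir_name + "_" + "_".join(f"{k}-{v}" for k, v in loop_idx.items())

assert sketch_dir_name("task_1", {"outer": 2, "inner": 0}) == "task_1_outer-2_inner-0"
```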
1866
+ def get_all_element_iterations(self) -> Mapping[int, ElementIteration]:
1867
+ """
1868
+ Get the iterations known by the task's elements.
1869
+ """
1870
+ return {itr.id_: itr for elem in self.elements for itr in elem.iterations}
1871
+
1872
+ @staticmethod
1873
+ @TimeIt.decorator
1874
+ def __get_src_elem_iters(
1875
+ src_task: WorkflowTask, inp_src: InputSource
1876
+ ) -> tuple[Iterable[ElementIteration], list[int]]:
1877
+ src_iters = src_task.get_all_element_iterations()
1878
+
1879
+ if inp_src.element_iters:
1880
+ # only include "sourceable" element iterations:
1881
+ src_iters_list = [src_iters[itr_id] for itr_id in inp_src.element_iters]
1882
+ set_indices = [el.element.element_set_idx for el in src_iters_list]
1883
+ return src_iters_list, set_indices
1884
+ return src_iters.values(), []
1885
+
1886
+ @TimeIt.decorator
1887
+ def __get_task_group_index(
1888
+ self,
1889
+ labelled_path_i: str,
1890
+ inp_src: InputSource,
1891
+ padded_elem_iters: Mapping[str, Sequence[int]],
1892
+ inp_group_name: str | None,
1893
+ ) -> None | Sequence[int | list[int]]:
1894
+ src_task = inp_src.get_task(self.workflow)
1895
+ assert src_task
1896
+ src_elem_iters, src_elem_set_idx = self.__get_src_elem_iters(src_task, inp_src)
1897
+
1898
+ if not src_elem_iters:
1899
+ return None
1900
+
1901
+ task_source_type = inp_src.task_source_type
1902
+ assert task_source_type is not None
1903
+ if task_source_type == TaskSourceType.OUTPUT and "[" in labelled_path_i:
1904
+ src_key = f"{task_source_type.name.lower()}s.{labelled_path_i.split('[')[0]}"
1905
+ else:
1906
+ src_key = f"{task_source_type.name.lower()}s.{labelled_path_i}"
1907
+
1908
+ padded_iters = padded_elem_iters.get(labelled_path_i, [])
1909
+ grp_idx = [
1910
+ (itr.get_data_idx()[src_key] if itr_idx not in padded_iters else -1)
1911
+ for itr_idx, itr in enumerate(src_elem_iters)
1912
+ ]
1913
+
1914
+ if not inp_group_name:
1915
+ return grp_idx
1916
+
1917
+ group_dat_idx: list[int | list[int]] = []
1918
+ element_sets = src_task.template.element_sets
1919
+ for dat_idx, src_set_idx, src_iter in zip(
1920
+ grp_idx, src_elem_set_idx, src_elem_iters
1921
+ ):
1922
+ src_es = element_sets[src_set_idx]
1923
+ if any(inp_group_name == grp.name for grp in src_es.groups):
1924
+ group_dat_idx.append(dat_idx)
1925
+ continue
1926
+ # if this group is defined for any recursive iteration
1927
+ # dependency, assign:
1928
+ src_iter_deps = self.workflow.get_element_iterations_from_IDs(
1929
+ src_iter.get_element_iteration_dependencies(),
1930
+ )
1931
+
1932
+ if any(
1933
+ inp_group_name == grp.name
1934
+ for el_iter in src_iter_deps
1935
+ for grp in el_iter.element.element_set.groups
1936
+ ):
1937
+ group_dat_idx.append(dat_idx)
1938
+ continue
1939
+
1940
+ # also check input dependencies
1941
+ for p_src in src_iter.element.get_input_dependencies().values():
1942
+ k_es = self.workflow.tasks.get(
1943
+ insert_ID=p_src["task_insert_ID"]
1944
+ ).template.element_sets[p_src["element_set_idx"]]
1945
+ if any(inp_group_name == grp.name for grp in k_es.groups):
1946
+ group_dat_idx.append(dat_idx)
1947
+ break
1948
+
1949
+ # TODO: this only goes to one level of dependency
1950
+
1951
+ if not group_dat_idx:
1952
+ raise MissingElementGroup(self.unique_name, inp_group_name, labelled_path_i)
1953
+
1954
+ return [cast("int", group_dat_idx)] # TODO: generalise to multiple groups
1955
+
1956
+ @TimeIt.decorator
1957
+ def __make_new_elements_persistent(
1958
+ self,
1959
+ element_set: ElementSet,
1960
+ element_set_idx: int,
1961
+ padded_elem_iters: Mapping[str, Sequence[int]],
1962
+ ) -> tuple[
1963
+ dict[str, list[int | list[int]]], dict[str, Sequence[int]], dict[str, list[int]]
1964
+ ]:
1965
+ """Save parameter data to the persistent workflow."""
1966
+
1967
+ # TODO: rewrite. This method is a little hard to follow and results in somewhat
1968
+ # unexpected behaviour: if a local source and task source are requested for a
1969
+ # given input, the local source element(s) will always come first, regardless of
1970
+ # the ordering in element_set.input_sources.
1971
+
1972
+ input_data_idx: dict[str, list[int | list[int]]] = {}
1973
+ sequence_idx: dict[str, Sequence[int]] = {}
1974
+ source_idx: dict[str, list[int]] = {}
1975
+
1976
+ # Assign first assuming all locally defined values are to be used:
1977
+ param_src: ParamSource = {
1978
+ "type": "local_input",
1979
+ "task_insert_ID": self.insert_ID,
1980
+ "element_set_idx": element_set_idx,
1981
+ }
1982
+ loc_inp_src = self._app.InputSource.local()
1983
+ for res_i in element_set.resources:
1984
+ key, dat_ref, _ = res_i.make_persistent(self.workflow, param_src)
1985
+ input_data_idx[key] = list(dat_ref)
1986
+
1987
+ for inp_i in element_set.inputs:
1988
+ key, dat_ref, _ = inp_i.make_persistent(self.workflow, param_src)
1989
+ input_data_idx[key] = list(dat_ref)
1990
+ key_ = key.removeprefix("inputs.")
1991
+ try:
1992
+ # TODO: wouldn't need to do this if we raise when an InputValue is
1993
+ # provided for a parameter whose input sources do not include the local
1994
+ # value.
1995
+ source_idx[key] = [element_set.input_sources[key_].index(loc_inp_src)]
1996
+ except ValueError:
1997
+ pass
1998
+
1999
+ for inp_file_i in element_set.input_files:
2000
+ key, input_dat_ref, _ = inp_file_i.make_persistent(self.workflow, param_src)
2001
+ input_data_idx[key] = list(input_dat_ref)
2002
+
2003
+ for seq_i in element_set.sequences:
2004
+ key, seq_dat_ref, _ = seq_i.make_persistent(self.workflow, param_src)
2005
+ input_data_idx[key] = list(seq_dat_ref)
2006
+ sequence_idx[key] = list(range(len(seq_dat_ref)))
2007
+ try:
2008
+ key_ = key.split("inputs.")[1]
2009
+ except IndexError:
2010
+ # e.g. "resources."
2011
+ key_ = ""
2012
+ try:
2013
+ # TODO: wouldn't need to do this if we raise when a ValueSequence is
2014
+ # provided for a parameter whose input sources do not include the local
2015
+ # value.
2016
+ if key_:
2017
+ source_idx[key] = [
2018
+ element_set.input_sources[key_].index(loc_inp_src)
2019
+ ] * len(seq_dat_ref)
2020
+ except ValueError:
2021
+ pass
2022
+
2023
+ for rep_spec in element_set.repeats:
2024
+ seq_key = f"repeats.{rep_spec['name']}"
2025
+ num_range = range(rep_spec["number"])
2026
+ input_data_idx[seq_key] = list(num_range)
2027
+ sequence_idx[seq_key] = num_range
2028
+
2029
+ # Now check for task- and default-sources and overwrite or append to local sources:
2030
+ inp_stats = self.template.get_input_statuses(element_set)
2031
+ for labelled_path_i, sources_i in element_set.input_sources.items():
2032
+ if len(path_i_split := labelled_path_i.split(".")) > 1:
2033
+ path_i_root = path_i_split[0]
2034
+ else:
2035
+ path_i_root = labelled_path_i
2036
+ if not inp_stats[path_i_root].is_required:
2037
+ continue
2038
+
2039
+ inp_group_name, def_val = None, None
2040
+ for schema_input in self.template.all_schema_inputs:
2041
+ for lab_info in schema_input.labelled_info():
2042
+ if lab_info["labelled_type"] == path_i_root:
2043
+ inp_group_name = lab_info["group"]
2044
+ if "default_value" in lab_info:
2045
+ def_val = lab_info["default_value"]
2046
+ break
2047
+
2048
+ key = f"inputs.{labelled_path_i}"
2049
+
2050
+ for inp_src_idx, inp_src in enumerate(sources_i):
2051
+ if inp_src.source_type is InputSourceType.TASK:
2052
+ grp_idx = self.__get_task_group_index(
2053
+ labelled_path_i, inp_src, padded_elem_iters, inp_group_name
2054
+ )
2055
+ if grp_idx is None:
2056
+ continue
2057
+
2058
+ if self._app.InputSource.local() in sources_i:
2059
+ # add task source to existing local source:
2060
+ input_data_idx[key].extend(grp_idx)
2061
+ source_idx[key].extend([inp_src_idx] * len(grp_idx))
2062
+
2063
+ else: # BUG: doesn't work for multiple task-input sources
2064
+ # overwrite existing local source (if it exists):
2065
+ input_data_idx[key] = list(grp_idx)
2066
+ source_idx[key] = [inp_src_idx] * len(grp_idx)
2067
+ if key in sequence_idx:
2068
+ sequence_idx.pop(key)
2069
+ # TODO: Use the value retrieved below?
2070
+ _ = element_set.get_sequence_from_path(key)
2071
+
2072
+ elif inp_src.source_type is InputSourceType.DEFAULT:
2073
+ assert def_val is not None
2074
+ assert def_val._value_group_idx is not None
2075
+ grp_idx_ = def_val._value_group_idx
2076
+ if self._app.InputSource.local() in sources_i:
2077
+ input_data_idx[key].append(grp_idx_)
2078
+ source_idx[key].append(inp_src_idx)
2079
+ else:
2080
+ input_data_idx[key] = [grp_idx_]
2081
+ source_idx[key] = [inp_src_idx]
2082
+
2083
+ # sort smallest to largest path, so more-specific items overwrite less-specific
2084
+ # items in parameter retrieval in `WorkflowTask._get_merged_parameter_data`:
2085
+ input_data_idx = dict(sorted(input_data_idx.items()))
2086
+
2087
+ return (input_data_idx, sequence_idx, source_idx)
2088
+
2089
+ @TimeIt.decorator
2090
+ def ensure_input_sources(
2091
+ self, element_set: ElementSet
2092
+ ) -> Mapping[str, Sequence[int]]:
2093
+ """Check valid input sources are specified for a new task to be added to the
2094
+ workflow in a given position. If none are specified, set them according to the
2095
+ default behaviour.
2096
+
2097
+ This method mutates `element_set.input_sources`.
2098
+
2099
+ """
2100
+
2101
+ # this depends on this schema, other task schemas and inputs/sequences:
2102
+ available_sources = self.template.get_available_task_input_sources(
2103
+ element_set=element_set,
2104
+ source_tasks=self.workflow.tasks[: self.index],
2105
+ )
2106
+
2107
+ if unreq := set(element_set.input_sources).difference(available_sources):
2108
+ raise UnrequiredInputSources(unreq)
2109
+
2110
+ # TODO: get available input sources from workflow imports
2111
+
2112
+ all_stats = self.template.get_input_statuses(element_set)
2113
+
2114
+ # an input is not required if it is only used to generate an input file that is
2115
+ # passed directly:
2116
+ req_types = set(k for k, v in all_stats.items() if v.is_required)
2117
+
2118
+ # check any specified sources are valid, and replace them with those computed in
2119
+ # `available_sources` since these will have `element_iters` assigned:
2120
+ for path_i, avail_i in available_sources.items():
2121
+ # for each sub-path in available sources, if the "root-path" source is
2122
+ # required, then add the sub-path source to `req_types` as well:
2123
+ if len(path_i_split := path_i.split(".")) > 1:
2124
+ if path_i_split[0] in req_types:
2125
+ req_types.add(path_i)
2126
+
2127
+ for s_idx, specified_source in enumerate(
2128
+ element_set.input_sources.get(path_i, [])
2129
+ ):
2130
+ self.workflow._resolve_input_source_task_reference(
2131
+ specified_source, self.unique_name
2132
+ )
2133
+ avail_idx = specified_source.is_in(avail_i)
2134
+ if avail_idx is None:
2135
+ raise UnavailableInputSource(specified_source, path_i, avail_i)
2136
+ available_source: InputSource
2137
+ try:
2138
+ available_source = avail_i[avail_idx]
2139
+ except TypeError:
2140
+ raise UnavailableInputSource(
2141
+ specified_source, path_i, avail_i
2142
+ ) from None
2143
+
2144
+ elem_iters_IDs = available_source.element_iters
2145
+ if specified_source.element_iters:
2146
+ # user-specified iter IDs; these must be a subset of available
2147
+ # element_iters:
2148
+ if not set(specified_source.element_iters).issubset(
2149
+ elem_iters_IDs or ()
2150
+ ):
2151
+ raise InapplicableInputSourceElementIters(
2152
+ specified_source, elem_iters_IDs
2153
+ )
2154
+ elem_iters_IDs = specified_source.element_iters
2155
+
2156
+ if specified_source.where:
2157
+ # filter iter IDs by user-specified rules, maintaining order:
2158
+ elem_iters = self.workflow.get_element_iterations_from_IDs(
2159
+ elem_iters_IDs or ()
2160
+ )
2161
+ elem_iters_IDs = [
2162
+ ei.id_ for ei in specified_source.where.filter(elem_iters)
2163
+ ]
2164
+
2165
+ available_source.element_iters = elem_iters_IDs
2166
+ element_set.input_sources[path_i][s_idx] = available_source
2167
+
2168
+ # sorting ensures that root parameters come before sub-parameters, which is
2169
+ # necessary when considering if we want to include a sub-parameter, when setting
2170
+ # missing sources below:
2171
+ unsourced_inputs = sorted(req_types.difference(element_set.input_sources))
2172
+
2173
+ if extra_types := {k for k, v in all_stats.items() if v.is_extra}:
2174
+ raise ExtraInputs(extra_types)
2175
+
2176
+ # set source for any unsourced inputs:
2177
+ missing: list[str] = []
2178
+ # track which root params we have set according to default behaviour (not
2179
+ # specified by user):
2180
+ set_root_params: set[str] = set()
2181
+ for input_type in unsourced_inputs:
2182
+ input_split = input_type.split(".")
2183
+ has_root_param = input_split[0] if len(input_split) > 1 else None
2184
+ inp_i_sources = available_sources.get(input_type, [])
2185
+
2186
+ source = None
2187
+ try:
2188
+ # first element is defined by default to take precedence in
2189
+ # `get_available_task_input_sources`:
2190
+ source = inp_i_sources[0]
2191
+ except IndexError:
2192
+ missing.append(input_type)
2193
+
2194
+ if source is not None:
2195
+ if has_root_param and has_root_param in set_root_params:
2196
+ # this is a sub-parameter, and the associated root parameter was not
2197
+ # specified by the user either, so we previously set it according to
2198
+ # default behaviour
2199
+ root_src = element_set.input_sources[has_root_param][0]
2200
+ # do not set a default task-input type source for this sub-parameter
2201
+ # if the associated root parameter has a default-set task-output
2202
+ # source from the same task:
2203
+ if (
2204
+ source.source_type is InputSourceType.TASK
2205
+ and source.task_source_type is TaskSourceType.INPUT
2206
+ and root_src.source_type is InputSourceType.TASK
2207
+ and root_src.task_source_type is TaskSourceType.OUTPUT
2208
+ and source.task_ref == root_src.task_ref
2209
+ ):
2210
+ continue
2211
+
2212
+ element_set.input_sources[input_type] = [source]
2213
+ if not has_root_param:
2214
+ set_root_params.add(input_type)
2215
+
2216
+ # for task sources that span multiple element sets, pad out sub-parameter
2217
+ # `element_iters` to include the element iterations from other element sets in
2218
+ # which the "root" parameter is defined:
2219
+ sources_by_task: dict[int, dict[str, InputSource]] = defaultdict(dict)
2220
+ elem_iter_by_task: dict[int, dict[str, list[int]]] = defaultdict(dict)
2221
+ all_elem_iters: set[int] = set()
2222
+ for inp_type, sources in element_set.input_sources.items():
2223
+ source = sources[0]
2224
+ if source.source_type is InputSourceType.TASK:
2225
+ assert source.task_ref is not None
2226
+ assert source.element_iters is not None
2227
+ sources_by_task[source.task_ref][inp_type] = source
2228
+ all_elem_iters.update(source.element_iters)
2229
+ elem_iter_by_task[source.task_ref][inp_type] = source.element_iters
2230
+
2231
+ all_elem_iters_by_ID = {
2232
+ el_iter.id_: el_iter
2233
+ for el_iter in self.workflow.get_element_iterations_from_IDs(all_elem_iters)
2234
+ }
2235
+
2236
+ # element set indices:
2237
+ padded_elem_iters = defaultdict(list)
2238
+ es_idx_by_task: dict[int, dict[str, _ESIdx]] = defaultdict(dict)
2239
+ for task_ref, task_iters in elem_iter_by_task.items():
2240
+ for inp_type, inp_iters in task_iters.items():
2241
+ es_indices = [
2242
+ all_elem_iters_by_ID[id_].element.element_set_idx for id_ in inp_iters
2243
+ ]
2244
+ es_idx_by_task[task_ref][inp_type] = _ESIdx(
2245
+ es_indices, frozenset(es_indices)
2246
+ )
2247
+ for root_param in {k for k in task_iters if "." not in k}:
2248
+ rp_nesting = element_set.nesting_order.get(f"inputs.{root_param}", None)
2249
+ rp_elem_sets, rp_elem_sets_uniq = es_idx_by_task[task_ref][root_param]
2250
+
2251
+ root_param_prefix = f"{root_param}."
2252
+ for sub_param_j in {
2253
+ k for k in task_iters if k.startswith(root_param_prefix)
2254
+ }:
2255
+ sub_param_nesting = element_set.nesting_order.get(
2256
+ f"inputs.{sub_param_j}", None
2257
+ )
2258
+ if sub_param_nesting == rp_nesting:
2259
+ sp_elem_sets_uniq = es_idx_by_task[task_ref][sub_param_j].uniq
2260
+
2261
+ if sp_elem_sets_uniq != rp_elem_sets_uniq:
2262
+ # replace elem_iters in sub-param sequence with those from the
2263
+ # root parameter, but re-order the elem iters to match their
2264
+ # original order:
2265
+ iters = elem_iter_by_task[task_ref][root_param]
2266
+
2267
+ # "mask" iter IDs corresponding to the sub-parameter's element
2268
+ # sets, and keep track of the extra indices so they can be
2269
+ # ignored later:
2270
+ sp_iters_new: list[int | None] = []
2271
+ for idx, (it_id, es_idx) in enumerate(
2272
+ zip(iters, rp_elem_sets)
2273
+ ):
2274
+ if es_idx in sp_elem_sets_uniq:
2275
+ sp_iters_new.append(None)
2276
+ else:
2277
+ sp_iters_new.append(it_id)
2278
+ padded_elem_iters[sub_param_j].append(idx)
2279
+
2280
+ # update sub-parameter element iters:
2281
+ for src in element_set.input_sources[sub_param_j]:
2282
+ if src.source_type is InputSourceType.TASK:
2283
+ # fill in sub-param elem_iters in their specified order
2284
+ sub_iters_it = iter(
2285
+ elem_iter_by_task[task_ref][sub_param_j]
2286
+ )
2287
+ src.element_iters = [
2288
+ it_id if it_id is not None else next(sub_iters_it)
2289
+ for it_id in sp_iters_new
2290
+ ]
2291
+ # assumes only a single task-type source for this
2292
+ # parameter
2293
+ break
2294
+
2295
+ # TODO: collate all input sources separately, then can fall back to a different
2296
+ # input source (if it was not specified manually) and if the "top" input source
2297
+ # results in no available elements due to `allow_non_coincident_task_sources`.
2298
+
2299
+ if not element_set.allow_non_coincident_task_sources:
2300
+ self.__enforce_some_sanity(sources_by_task, element_set)
2301
+
2302
+ if missing:
2303
+ raise MissingInputs(self.template, missing)
2304
+ return padded_elem_iters
2305
+
2306
+ @TimeIt.decorator
2307
+ def __enforce_some_sanity(
2308
+ self, sources_by_task: dict[int, dict[str, InputSource]], element_set: ElementSet
2309
+ ) -> None:
2310
+ """
2311
+ If multiple parameters are sourced from the same upstream task, only use
2312
+ element iterations for which all parameters are available (the set
2313
+ intersection).
2314
+ """
2315
+ for task_ref, sources in sources_by_task.items():
2316
+ # if a parameter has multiple labels, disregard from this by removing all
2317
+ # parameters:
2318
+ seen_labelled: dict[str, int] = defaultdict(int)
2319
+ for src_i in sources:
2320
+ if "[" in src_i:
2321
+ unlabelled, _ = split_param_label(src_i)
2322
+ assert unlabelled is not None
2323
+ seen_labelled[unlabelled] += 1
2324
+
2325
+ for prefix, count in seen_labelled.items():
2326
+ if count > 1:
2327
+ # remove:
2328
+ sources = {
2329
+ k: v for k, v in sources.items() if not k.startswith(prefix)
2330
+ }
2331
+
2332
+ if len(sources) < 2:
2333
+ continue
2334
+
2335
+ first_src = next(iter(sources.values()))
2336
+ intersect_task_i = set(first_src.element_iters or ())
2337
+ for inp_src in sources.values():
2338
+ intersect_task_i.intersection_update(inp_src.element_iters or ())
2339
+ if not intersect_task_i:
2340
+ raise NoCoincidentInputSources(self.name, task_ref)
2341
+
2342
+ # now change elements for the affected input sources.
2343
+ # sort by original order of first_src.element_iters
2344
+ int_task_i_lst = [
2345
+ i for i in first_src.element_iters or () if i in intersect_task_i
2346
+ ]
2347
+ for inp_type in sources:
2348
+ element_set.input_sources[inp_type][0].element_iters = int_task_i_lst
2349
+
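The order-preserving intersection applied above, shown in isolation (editor's sketch over plain ID lists):

```python
def coincident(iter_lists):
    # intersect all sources' element-iteration IDs...
    common = set(iter_lists[0]).intersection(*iter_lists[1:])
    # ...then keep the original ordering of the first source:
    return [i for i in iter_lists[0] if i in common]

assert coincident([[4, 2, 7, 9], [9, 2, 5], [2, 9, 1]]) == [2, 9]
```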
2350
+ @TimeIt.decorator
2351
+ def generate_new_elements(
2352
+ self,
2353
+ input_data_indices: Mapping[str, Sequence[int | list[int]]],
2354
+ output_data_indices: Mapping[str, Sequence[int]],
2355
+ element_data_indices: Sequence[Mapping[str, int]],
2356
+ sequence_indices: Mapping[str, Sequence[int]],
2357
+ source_indices: Mapping[str, Sequence[int]],
2358
+ ) -> tuple[
2359
+ Sequence[DataIndex], Mapping[str, Sequence[int]], Mapping[str, Sequence[int]]
2360
+ ]:
2361
+ """
2362
+ Create information about new elements in this task.
2363
+ """
2364
+ new_elements: list[DataIndex] = []
2365
+ element_sequence_indices: dict[str, list[int]] = {}
2366
+ element_src_indices: dict[str, list[int]] = {}
2367
+ for i_idx, data_idx in enumerate(element_data_indices):
2368
+ elem_i = {
2369
+ k: input_data_indices[k][v]
2370
+ for k, v in data_idx.items()
2371
+ if input_data_indices[k][v] != -1
2372
+ }
2373
+ elem_i.update((k, v2[i_idx]) for k, v2 in output_data_indices.items())
2374
+ new_elements.append(elem_i)
2375
+
2376
+ for k, v3 in data_idx.items():
2377
+ # track which sequence value indices (if any) are used for each new
2378
+ # element:
2379
+ if k in sequence_indices:
2380
+ element_sequence_indices.setdefault(k, []).append(
2381
+ sequence_indices[k][v3]
2382
+ )
2383
+
2384
+ # track original InputSource associated with each new element:
2385
+ if k in source_indices:
2386
+ if input_data_indices[k][v3] != -1:
2387
+ src_idx_k = source_indices[k][v3]
2388
+ else:
2389
+ src_idx_k = -1
2390
+ element_src_indices.setdefault(k, []).append(src_idx_k)
2391
+
2392
+ return new_elements, element_sequence_indices, element_src_indices
2393
+
2394
+ @property
2395
+ def upstream_tasks(self) -> Iterator[WorkflowTask]:
2396
+ """All workflow tasks that are upstream from this task."""
2397
+ tasks = self.workflow.tasks
2398
+ for idx in range(0, self.index):
2399
+ yield tasks[idx]
2400
+
2401
+ @property
2402
+ def downstream_tasks(self) -> Iterator[WorkflowTask]:
2403
+ """All workflow tasks that are downstream from this task."""
2404
+ tasks = self.workflow.tasks
2405
+ for idx in range(self.index + 1, len(tasks)):
2406
+ yield tasks[idx]
2407
+
2408
+ @staticmethod
2409
+ @TimeIt.decorator
2410
+ def resolve_element_data_indices(
2411
+ multiplicities: list[MultiplicityDescriptor],
2412
+ ) -> Sequence[Mapping[str, int]]:
2413
+ """Find the index of the parameter group index list corresponding to each
2414
+ input data for all elements.
2415
+
2416
+ Parameters
2417
+ ----------
2418
+ multiplicities : list of MultiplicityDescriptor
2419
+ Each list item represents a sequence of values with keys:
2420
+ multiplicity: int
2421
+ nesting_order: float
2422
+ path: str
2423
+
2424
+ Returns
2425
+ -------
2426
+ element_dat_idx : list of dict
2427
+ Each list item is a dict representing a single task element, whose keys are
2428
+ input data paths and whose values are indices that index the values of the
2429
+ dict returned by the `task.make_persistent` method.
2430
+
2431
+ Note
2432
+ ----
2433
+ Non-integer nesting orders result in doing the dot product of that sequence with
2434
+ all the current sequences instead of just with the other sequences at the same
2435
+ nesting order (or instead of a cross product with sequences at other nesting orders).
2436
+ """
2437
+
2438
+ # order by nesting order (lower nesting orders will be slowest-varying):
2439
+ multi_srt = sorted(multiplicities, key=lambda x: x["nesting_order"])
2440
+ multi_srt_grp = group_by_dict_key_values(multi_srt, "nesting_order")
2441
+
2442
+ element_dat_idx: list[dict[str, int]] = [{}]
2443
+ last_nest_ord: int | None = None
2444
+ for para_sequences in multi_srt_grp:
2445
+ # check all equivalent nesting_orders have equivalent multiplicities
2446
+ all_multis = {md["multiplicity"] for md in para_sequences}
2447
+ if len(all_multis) > 1:
2448
+ raise ValueError(
2449
+ f"All inputs with the same `nesting_order` must have the same "
2450
+ f"multiplicity, but for paths "
2451
+ f"{[md['path'] for md in para_sequences]} with "
2452
+ f"`nesting_order` {para_sequences[0]['nesting_order']} found "
2453
+ f"multiplicities {[md['multiplicity'] for md in para_sequences]}."
2454
+ )
2455
+
2456
+ cur_nest_ord = int(para_sequences[0]["nesting_order"])
2457
+ new_elements: list[dict[str, int]] = []
2458
+ for elem_idx, element in enumerate(element_dat_idx):
2459
+ if last_nest_ord is not None and cur_nest_ord == last_nest_ord:
2460
+ # merge in parallel with existing elements:
2461
+ new_elements.append(
2462
+ {
2463
+ **element,
2464
+ **{md["path"]: elem_idx for md in para_sequences},
2465
+ }
2466
+ )
2467
+ else:
2468
+ for val_idx in range(para_sequences[0]["multiplicity"]):
2469
+ # nest with existing elements:
2470
+ new_elements.append(
2471
+ {
2472
+ **element,
2473
+ **{md["path"]: val_idx for md in para_sequences},
2474
+ }
2475
+ )
2476
+ element_dat_idx = new_elements
2477
+ last_nest_ord = cur_nest_ord
2478
+
2479
+ return element_dat_idx
2480
+
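A worked example of the resolution rules documented above (plain dicts in place of `MultiplicityDescriptor`; calling the static method directly is assumed to be possible for illustration):

```python
multiplicities = [
    {"path": "inputs.p1", "multiplicity": 2, "nesting_order": 0},
    {"path": "inputs.p2", "multiplicity": 2, "nesting_order": 0},  # zipped with p1
    {"path": "inputs.p3", "multiplicity": 3, "nesting_order": 1},  # nested within
]
# WorkflowTask.resolve_element_data_indices(multiplicities) would yield six
# elements: p1/p2 share a nesting order so vary together (slowest, order 0),
# while p3 is crossed against them (fastest, order 1):
# [{'inputs.p1': 0, 'inputs.p2': 0, 'inputs.p3': 0},
#  {'inputs.p1': 0, 'inputs.p2': 0, 'inputs.p3': 1},
#  {'inputs.p1': 0, 'inputs.p2': 0, 'inputs.p3': 2},
#  {'inputs.p1': 1, 'inputs.p2': 1, 'inputs.p3': 0},
#  {'inputs.p1': 1, 'inputs.p2': 1, 'inputs.p3': 1},
#  {'inputs.p1': 1, 'inputs.p2': 1, 'inputs.p3': 2}]
```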
2481
+ @TimeIt.decorator
2482
2483
+ def initialise_EARs(self, iter_IDs: list[int] | None = None) -> Sequence[int]:
2484
+ """Try to initialise any uninitialised EARs of this task."""
2485
+ if iter_IDs:
2486
+ iters = self.workflow.get_element_iterations_from_IDs(iter_IDs)
2487
+ else:
2488
+ iters = []
2489
+ for element in self.elements:
2490
+ # We don't yet cache Element objects, so `element`, and also its
2491
+ # `ElementIteration`s, are transient. So there is no reason to update these
2492
+ # objects in memory to account for the new EARs. Subsequent calls to
2493
+ # `WorkflowTask.elements` will retrieve correct element data from the
2494
+ # store. This might need changing once/if we start caching Element
2495
+ # objects.
2496
+ iters.extend(element.iterations)
2497
+
2498
+ initialised: list[int] = []
2499
+ for iter_i in iters:
2500
+ if not iter_i.EARs_initialised:
2501
+ try:
2502
+ self.__initialise_element_iter_EARs(iter_i)
2503
+ initialised.append(iter_i.id_)
2504
+ except UnsetParameterDataError:
2505
+ # raised by `Action.test_rules`; cannot yet initialise EARs
2506
+ self._app.logger.debug(
2507
+ "UnsetParameterDataError raised: cannot yet initialise runs."
2508
+ )
2509
2510
+ else:
2511
+ iter_i._EARs_initialised = True
2512
+ self.workflow.set_EARs_initialised(iter_i.id_)
2513
+ return initialised
2514
+
2515
+ @TimeIt.decorator
2516
+ def __initialise_element_iter_EARs(self, element_iter: ElementIteration) -> None:
2517
+ # keys are (act_idx, EAR_idx):
2518
+ all_data_idx: dict[tuple[int, int], DataIndex] = {}
2519
+ action_runs: dict[tuple[int, int], dict[str, Any]] = {}
2520
+
2521
+ # keys are parameter indices, values are EAR_IDs to update those sources to
2522
+ param_src_updates: dict[int, ParamSource] = {}
2523
+
2524
+ count = 0
2525
+ for act_idx, action in self.template.all_schema_actions():
2526
+ log_common = (
2527
+ f"for action {act_idx} of element iteration {element_iter.index} of "
2528
+ f"element {element_iter.element.index} of task {self.unique_name!r}."
2529
+ )
2530
+ # TODO: when we support adding new runs, we will probably pass additional
2531
+ # run-specific data index to `test_rules` and `generate_data_index`
2532
+ # (e.g. if we wanted to increase the memory requirements of an action because
2533
+ # it previously failed)
2534
+ act_valid, cmds_idx = action.test_rules(element_iter=element_iter)
2535
+ if act_valid:
2536
+ self._app.logger.info(f"All action rules evaluated to true {log_common}")
2537
+ EAR_ID = self.workflow.num_EARs + count
2538
+ param_source: ParamSource = {
2539
+ "type": "EAR_output",
2540
+ "EAR_ID": EAR_ID,
2541
+ }
2542
+ psrc_update = (
2543
+ action.generate_data_index( # adds an item to `all_data_idx`
2544
+ act_idx=act_idx,
2545
+ EAR_ID=EAR_ID,
2546
+ schema_data_idx=element_iter.data_idx,
2547
+ all_data_idx=all_data_idx,
2548
+ workflow=self.workflow,
2549
+ param_source=param_source,
2550
+ )
2551
+ )
2552
+ # with EARs initialised, we can update the pre-allocated schema-level
2553
+ # parameters with the correct EAR reference:
2554
+ for i in psrc_update:
2555
+ param_src_updates[cast("int", i)] = {"EAR_ID": EAR_ID}
2556
+ run_0 = {
2557
+ "elem_iter_ID": element_iter.id_,
2558
+ "action_idx": act_idx,
2559
+ "commands_idx": cmds_idx,
2560
+ "metadata": {},
2561
+ }
2562
+ action_runs[act_idx, EAR_ID] = run_0
2563
+ count += 1
2564
+ else:
2565
+ self._app.logger.info(
2566
+ f"Some action rules evaluated to false {log_common}"
2567
+ )
2568
+
2569
+ # `generate_data_index` can modify data index for previous actions, so only assign
2570
+ # this at the end:
2571
+ for (act_idx, EAR_ID_i), run in action_runs.items():
2572
+ self.workflow._store.add_EAR(
2573
+ elem_iter_ID=element_iter.id_,
2574
+ action_idx=act_idx,
2575
+ commands_idx=run["commands_idx"],
2576
+ data_idx=all_data_idx[act_idx, EAR_ID_i],
2577
+ )
2578
+
2579
+ self.workflow._store.update_param_source(param_src_updates)
2580
+
2581
+ @TimeIt.decorator
2582
+ def _add_element_set(self, element_set: ElementSet) -> list[int]:
2583
+ """
2584
+ Returns
2585
+ -------
2586
+ element_indices : list of int
2587
+ Global indices of newly added elements.
2588
+
2589
+ """
2590
+
2591
+ self.template.set_sequence_parameters(element_set)
2592
+
2593
+ # may modify element_set.input_sources:
2594
+ padded_elem_iters = self.ensure_input_sources(element_set)
2595
+
2596
+ (input_data_idx, seq_idx, src_idx) = self.__make_new_elements_persistent(
2597
+ element_set=element_set,
2598
+ element_set_idx=self.num_element_sets,
2599
+ padded_elem_iters=padded_elem_iters,
2600
+ )
2601
+ element_set.task_template = self.template # may modify element_set.nesting_order
2602
+
2603
+ multiplicities = self.template.prepare_element_resolution(
2604
+ element_set, input_data_idx
2605
+ )
2606
+
2607
+ element_inp_data_idx = self.resolve_element_data_indices(multiplicities)
2608
+
2609
+ local_element_idx_range = [
2610
+ self.num_elements,
2611
+ self.num_elements + len(element_inp_data_idx),
2612
+ ]
2613
+
2614
+ element_set._element_local_idx_range = local_element_idx_range
2615
+ self.template._add_element_set(element_set)
2616
+
2617
+ output_data_idx = self.template._prepare_persistent_outputs(
2618
+ workflow=self.workflow,
2619
+ local_element_idx_range=local_element_idx_range,
2620
+ )
2621
+
2622
+ (element_data_idx, element_seq_idx, element_src_idx) = self.generate_new_elements(
2623
+ input_data_idx,
2624
+ output_data_idx,
2625
+ element_inp_data_idx,
2626
+ seq_idx,
2627
+ src_idx,
2628
+ )
2629
+
2630
+ iter_IDs: list[int] = []
2631
+ elem_IDs: list[int] = []
2632
+ for elem_idx, data_idx in enumerate(element_data_idx):
2633
+ schema_params = set(i for i in data_idx if len(i.split(".")) == 2)
2634
+ elem_ID_i = self.workflow._store.add_element(
2635
+ task_ID=self.insert_ID,
2636
+ es_idx=self.num_element_sets - 1,
2637
+ seq_idx={k: v[elem_idx] for k, v in element_seq_idx.items()},
2638
+ src_idx={k: v[elem_idx] for k, v in element_src_idx.items() if v[elem_idx] != -1},
2639
+ )
2640
+ iter_ID_i = self.workflow._store.add_element_iteration(
2641
+ element_ID=elem_ID_i,
2642
+ data_idx=data_idx,
2643
+ schema_parameters=list(schema_params),
2644
+ )
2645
+ iter_IDs.append(iter_ID_i)
2646
+ elem_IDs.append(elem_ID_i)
2647
+
2648
+ self._pending_element_IDs += elem_IDs
2649
+ self.initialise_EARs()
2650
+
2651
+ return iter_IDs
2652
+
2653
+ @overload
2654
+ def add_elements(
2655
+ self,
2656
+ *,
2657
+ base_element: Element | None = None,
2658
+ inputs: list[InputValue] | dict[str, Any] | None = None,
2659
+ input_files: list[InputFile] | None = None,
2660
+ sequences: list[ValueSequence] | None = None,
2661
+ resources: Resources = None,
2662
+ repeats: list[RepeatsDescriptor] | int | None = None,
2663
+ input_sources: dict[str, list[InputSource]] | None = None,
2664
+ nesting_order: dict[str, float] | None = None,
2665
+ element_sets: list[ElementSet] | None = None,
2666
+ sourceable_elem_iters: list[int] | None = None,
2667
+ propagate_to: (
2668
+ list[ElementPropagation]
2669
+ | Mapping[str, ElementPropagation | Mapping[str, Any]]
2670
+ | None
2671
+ ) = None,
2672
+ return_indices: Literal[True],
2673
+ ) -> list[int]: ...
2674
+
2675
+ @overload
2676
+ def add_elements(
2677
+ self,
2678
+ *,
2679
+ base_element: Element | None = None,
2680
+ inputs: list[InputValue] | dict[str, Any] | None = None,
2681
+ input_files: list[InputFile] | None = None,
2682
+ sequences: list[ValueSequence] | None = None,
2683
+ resources: Resources = None,
2684
+ repeats: list[RepeatsDescriptor] | int | None = None,
2685
+ input_sources: dict[str, list[InputSource]] | None = None,
2686
+ nesting_order: dict[str, float] | None = None,
2687
+ element_sets: list[ElementSet] | None = None,
2688
+ sourceable_elem_iters: list[int] | None = None,
2689
+ propagate_to: (
2690
+ list[ElementPropagation]
2691
+ | Mapping[str, ElementPropagation | Mapping[str, Any]]
2692
+ | None
2693
+ ) = None,
2694
+ return_indices: Literal[False] = False,
2695
+ ) -> None: ...
2696
+
2697
+ def add_elements(
2698
+ self,
2699
+ *,
2700
+ base_element: Element | None = None,
2701
+ inputs: list[InputValue] | dict[str, Any] | None = None,
2702
+ input_files: list[InputFile] | None = None,
2703
+ sequences: list[ValueSequence] | None = None,
2704
+ resources: Resources = None,
2705
+ repeats: list[RepeatsDescriptor] | int | None = None,
2706
+ input_sources: dict[str, list[InputSource]] | None = None,
2707
+ nesting_order: dict[str, float] | None = None,
2708
+ element_sets: list[ElementSet] | None = None,
2709
+ sourceable_elem_iters: list[int] | None = None,
2710
+ propagate_to: (
2711
+ list[ElementPropagation]
2712
+ | Mapping[str, ElementPropagation | Mapping[str, Any]]
2713
+ | None
2714
+ ) = None,
2715
+ return_indices=False,
2716
+ ) -> list[int] | None:
2717
+ """
2718
+ Add elements to this task.
2719
+
2720
+ Parameters
2721
+ ----------
2722
+ sourceable_elem_iters : list of int, optional
2723
+ If specified, a list of global element iteration indices from which inputs
2724
+ may be sourced. If not specified, all workflow element iterations are
2725
+ considered sourceable.
2726
+ propagate_to : dict[str, ElementPropagation]
2727
+ Propagate the new elements downstream to the specified tasks.
2728
+ return_indices : bool
2729
+ If True, return the list of indices of the newly added elements. False by
2730
+ default.
2731
+
2732
+ """
2733
+ real_propagate_to = self._app.ElementPropagation._prepare_propagate_to_dict(
2734
+ propagate_to, self.workflow
2735
+ )
2736
+ with self.workflow.batch_update():
2737
+ indices = self._add_elements(
2738
+ base_element=base_element,
2739
+ inputs=inputs,
2740
+ input_files=input_files,
2741
+ sequences=sequences,
2742
+ resources=resources,
2743
+ repeats=repeats,
2744
+ input_sources=input_sources,
2745
+ nesting_order=nesting_order,
2746
+ element_sets=element_sets,
2747
+ sourceable_elem_iters=sourceable_elem_iters,
2748
+ propagate_to=real_propagate_to,
2749
+ )
2750
+ return indices if return_indices else None
2751
+
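A minimal usage sketch of the `add_elements` API above, assuming an existing `Workflow` object `wk` containing a task `t1` with a schema input `p1` (all names here are hypothetical):

```python
import hpcflow.app as hf  # assumed app entry point; adjust to your application

task = wk.tasks.t1  # an existing WorkflowTask

# Add one new element per sequence value; `return_indices=True` selects the
# overload that returns the indices of the newly added elements:
new_idx = task.add_elements(
    sequences=[
        hf.ValueSequence(path="inputs.p1", values=[4, 5, 6], nesting_order=0)
    ],
    return_indices=True,
)
print(new_idx)  # e.g. [3, 4, 5]
```
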
+    @TimeIt.decorator
+    def _add_elements(
+        self,
+        *,
+        base_element: Element | None = None,
+        inputs: list[InputValue] | dict[str, Any] | None = None,
+        input_files: list[InputFile] | None = None,
+        sequences: list[ValueSequence] | None = None,
+        resources: Resources = None,
+        repeats: list[RepeatsDescriptor] | int | None = None,
+        input_sources: dict[str, list[InputSource]] | None = None,
+        nesting_order: dict[str, float] | None = None,
+        element_sets: list[ElementSet] | None = None,
+        sourceable_elem_iters: list[int] | None = None,
+        propagate_to: dict[str, ElementPropagation],
+    ) -> list[int] | None:
+        """Add more elements to this task.
+
+        Parameters
+        ----------
+        sourceable_elem_iters : list[int]
+            If specified, a list of global element iteration indices from which inputs
+            may be sourced. If not specified, all workflow element iterations are
+            considered sourceable.
+        propagate_to : dict[str, ElementPropagation]
+            Propagate the new elements downstream to the specified tasks.
+        """
+
+        if base_element is not None:
+            if base_element.task is not self:
+                raise ValueError("If specified, `base_element` must belong to this task.")
+            b_inputs, b_resources = base_element.to_element_set_data()
+            inputs = inputs or b_inputs
+            resources = resources or b_resources
+
+        element_sets = self._app.ElementSet.ensure_element_sets(
+            inputs=inputs,
+            input_files=input_files,
+            sequences=sequences,
+            resources=resources,
+            repeats=repeats,
+            input_sources=input_sources,
+            nesting_order=nesting_order,
+            element_sets=element_sets,
+            sourceable_elem_iters=sourceable_elem_iters,
+        )
+
+        elem_idx: list[int] = []
+        for elem_set_i in element_sets:
+            # copy and add the new element set:
+            elem_idx.extend(self._add_element_set(elem_set_i.prepare_persistent_copy()))
+
+        if not propagate_to:
+            return elem_idx
+
+        for task in self.get_dependent_tasks(as_objects=True):
+            if (elem_prop := propagate_to.get(task.unique_name)) is None:
+                continue
+
+            if all(
+                self.unique_name != dep_task.unique_name
+                for dep_task in elem_prop.element_set.get_task_dependencies(
+                    as_objects=True
+                )
+            ):
+                # TODO: why can't we just do
+                # `if self not in elem_prop.element_set.task_dependencies:`?
+                continue
+
+            # TODO: generate a new ElementSet for this task; assume for now we use a
+            # single base element set. Later, allow combining multiple element sets.
+            src_elem_iters = elem_idx + [
+                j for el_set in element_sets for j in el_set.sourceable_elem_iters or ()
+            ]
+
+            # note: we must pass `resources` as a list since it is already persistent:
+            elem_set_i = self._app.ElementSet(
+                inputs=elem_prop.element_set.inputs,
+                input_files=elem_prop.element_set.input_files,
+                sequences=elem_prop.element_set.sequences,
+                resources=elem_prop.element_set.resources[:],
+                repeats=elem_prop.element_set.repeats,
+                nesting_order=elem_prop.nesting_order,
+                input_sources=elem_prop.input_sources,
+                sourceable_elem_iters=src_elem_iters,
+            )
+
+            del propagate_to[task.unique_name]
+            prop_elem_idx = task._add_elements(
+                element_sets=[elem_set_i],
+                propagate_to=propagate_to,
+            )
+            elem_idx.extend(prop_elem_idx or ())
+
+        return elem_idx
+
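`_add_elements` consumes entries from `propagate_to` as it recurses into dependent tasks, so each downstream task receives at most one propagated element set. A hedged sketch of how a caller might request propagation (task and parameter names hypothetical; dict values are normalised into `ElementPropagation` objects by `_prepare_propagate_to_dict`):

```python
wk.tasks.t1.add_elements(
    inputs=[hf.InputValue("p1", value=7)],
    # propagate the new element(s) into the dependent task "t2", overriding
    # its nesting order for one of its inputs:
    propagate_to={"t2": {"nesting_order": {"inputs.p2": 1.0}}},
)
```
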
+    @overload
+    def get_element_dependencies(
+        self,
+        as_objects: Literal[False] = False,
+    ) -> set[int]: ...
+
+    @overload
+    def get_element_dependencies(
+        self,
+        as_objects: Literal[True],
+    ) -> list[Element]: ...
+
+    def get_element_dependencies(
+        self,
+        as_objects: bool = False,
+    ) -> set[int] | list[Element]:
+        """Get elements from upstream tasks that this task depends on."""
+
+        deps: set[int] = set()
+        for element in self.elements:
+            for iter_i in element.iterations:
+                deps.update(
+                    dep_elem_i.id_
+                    for dep_elem_i in iter_i.get_element_dependencies(as_objects=True)
+                    if dep_elem_i.task.insert_ID != self.insert_ID
+                )
+
+        if as_objects:
+            return self.workflow.get_elements_from_IDs(sorted(deps))
+        return deps
+
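A short sketch of querying upstream element dependencies (assuming a workflow `wk` where task `t2` sources an input from task `t1`):

```python
dep_ids = wk.tasks.t2.get_element_dependencies()  # set of element IDs
# or as Element objects, sorted by ID:
dep_elems = wk.tasks.t2.get_element_dependencies(as_objects=True)
```
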
+    @overload
+    def get_task_dependencies(self, as_objects: Literal[False] = False) -> set[int]: ...
+
+    @overload
+    def get_task_dependencies(self, as_objects: Literal[True]) -> list[WorkflowTask]: ...
+
+    def get_task_dependencies(
+        self,
+        as_objects: bool = False,
+    ) -> set[int] | list[WorkflowTask]:
+        """Get tasks (insert IDs or WorkflowTask objects) that this task depends on.
+
+        Dependencies may arise either from elements of upstream tasks, or from locally
+        defined inputs/sequences/defaults that reference upstream tasks."""
+
+        # TODO: this method might become insufficient if/when we start considering a
+        # new "task_iteration" input source type, which may take precedence over any
+        # other input source types.
+
+        deps: set[int] = set()
+        for element_set in self.template.element_sets:
+            for sources in element_set.input_sources.values():
+                deps.update(
+                    src.task_ref
+                    for src in sources
+                    if (
+                        src.source_type is InputSourceType.TASK
+                        and src.task_ref is not None
+                    )
+                )
+
+        if as_objects:
+            return [self.workflow.tasks.get(insert_ID=id_) for id_ in sorted(deps)]
+        return deps
+
+    @overload
+    def get_dependent_elements(
+        self,
+        as_objects: Literal[False] = False,
+    ) -> set[int]: ...
+
+    @overload
+    def get_dependent_elements(self, as_objects: Literal[True]) -> list[Element]: ...
+
+    def get_dependent_elements(
+        self,
+        as_objects: bool = False,
+    ) -> set[int] | list[Element]:
+        """Get elements from downstream tasks that depend on this task."""
+        deps: set[int] = set()
+        for task in self.downstream_tasks:
+            deps.update(
+                element.id_
+                for element in task.elements
+                if any(
+                    self.insert_ID in iter_i.get_task_dependencies()
+                    for iter_i in element.iterations
+                )
+            )
+
+        if as_objects:
+            return self.workflow.get_elements_from_IDs(sorted(deps))
+        return deps
+
+    @overload
+    def get_dependent_tasks(self, as_objects: Literal[False] = False) -> set[int]: ...
+
+    @overload
+    def get_dependent_tasks(self, as_objects: Literal[True]) -> list[WorkflowTask]: ...
+
+    @TimeIt.decorator
+    def get_dependent_tasks(
+        self,
+        as_objects: bool = False,
+    ) -> set[int] | list[WorkflowTask]:
+        """Get tasks (insert IDs or WorkflowTask objects) that depend on this task."""
+
+        # TODO: this method might become insufficient if/when we start considering a
+        # new "task_iteration" input source type, which may take precedence over any
+        # other input source types.
+
+        deps: set[int] = set()
+        for task in self.downstream_tasks:
+            if task.insert_ID not in deps and any(
+                src.source_type is InputSourceType.TASK and src.task_ref == self.insert_ID
+                for element_set in task.template.element_sets
+                for sources in element_set.input_sources.values()
+                for src in sources
+            ):
+                deps.add(task.insert_ID)
+        if as_objects:
+            return [self.workflow.tasks.get(insert_ID=id_) for id_ in sorted(deps)]
+        return deps
+
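The task-level queries mirror the element-level ones; together the four methods let you walk the dependency graph in both directions. A hedged sketch under the same two-task assumption:

```python
upstream_ids = wk.tasks.t2.get_task_dependencies()           # insert IDs
downstream = wk.tasks.t1.get_dependent_tasks(as_objects=True)
print([t.unique_name for t in downstream])                   # e.g. ['t2']
dependent_elems = wk.tasks.t1.get_dependent_elements()       # element IDs in t2
```
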
+    @property
+    def inputs(self) -> TaskInputParameters:
+        """
+        Inputs to this task.
+        """
+        return self._app.TaskInputParameters(self)
+
+    @property
+    def outputs(self) -> TaskOutputParameters:
+        """
+        Outputs from this task.
+        """
+        return self._app.TaskOutputParameters(self)
+
+    def get(
+        self, path: str, *, raise_on_missing=False, default: Any | None = None
+    ) -> Parameters:
+        """
+        Get a parameter known to this task by its path.
+        """
+        return self._app.Parameters(
+            self,
+            path=path,
+            return_element_parameters=False,
+            raise_on_missing=raise_on_missing,
+            default=default,
+        )
+
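Given the accessors above, per-element parameter data can be pulled across the whole task; note that `task.inputs.<name>` yields `ElementParameter` objects, while `task.get(...)` yields merged values (names hypothetical):

```python
p1_params = wk.tasks.t1.inputs.p1[:]         # list[ElementParameter]
p1_values = wk.tasks.t1.get("inputs.p1")[:]  # list of raw merged values
first_val = wk.tasks.t1.get("inputs.p1")[0]  # value for the first element
```
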
+    def _paths_to_PV_classes(self, *paths: str | None) -> dict[str, type[ParameterValue]]:
+        """Return a dict mapping dot-delimited string input paths to `ParameterValue`
+        classes."""
+
+        params: dict[str, type[ParameterValue]] = {}
+        for path in paths:
+            if not path:
+                # skip None/empty paths:
+                continue
+            path_split = path.split(".")
+            if len(path_split) == 1 or path_split[0] not in ("inputs", "outputs"):
+                continue
+
+            # a top-level parameter can be found via the task schema:
+            key_0 = ".".join(path_split[:2])
+
+            if key_0 not in params:
+                if path_split[0] == "inputs":
+                    # remove the label, if present:
+                    path_1, _ = split_param_label(path_split[1])
+                    for schema in self.template.schemas:
+                        for inp in schema.inputs:
+                            if inp.parameter.typ == path_1 and inp.parameter._value_class:
+                                params[key_0] = inp.parameter._value_class
+
+                elif path_split[0] == "outputs":
+                    for schema in self.template.schemas:
+                        for out in schema.outputs:
+                            if (
+                                out.parameter.typ == path_split[1]
+                                and out.parameter._value_class
+                            ):
+                                params[key_0] = out.parameter._value_class
+
+            if path_split[2:]:
+                pv_classes = {cls._typ: cls for cls in ParameterValue.__subclasses__()}
+
+                # now proceed by searching for sub-parameters in each `ParameterValue`
+                # sub-class:
+                for idx, part_i in enumerate(path_split[2:], start=2):
+                    parent = path_split[:idx]  # e.g. ["inputs", "p1"]
+                    child = path_split[: idx + 1]  # e.g. ["inputs", "p1", "sub_param"]
+                    key_i = ".".join(child)
+                    if key_i in params:
+                        continue
+                    if parent_param := params.get(".".join(parent)):
+                        for attr_name, sub_type in parent_param._sub_parameters.items():
+                            if part_i == attr_name:
+                                # find the class with this `typ` attribute:
+                                if cls := pv_classes.get(sub_type):
+                                    params[key_i] = cls
+
+        return params
+
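To illustrate the mapping this method builds, a hypothetical `ParameterValue` hierarchy (classes and `_typ`/`_sub_parameters` values invented for the example):

```python
class P1Sub(ParameterValue):
    _typ = "p1_sub"

class P1(ParameterValue):
    _typ = "p1"
    # sub-parameter attribute name -> the sub-parameter's `typ`:
    _sub_parameters = {"sub": "p1_sub"}

# For a task whose schema declares an input of type "p1":
#   task._paths_to_PV_classes("inputs.p1.sub.x")
#   -> {"inputs.p1": P1, "inputs.p1.sub": P1Sub}
```
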
+    @staticmethod
+    def _get_relevant_paths(
+        data_index: Mapping[str, Any], path: list[str], children_of: str | None = None
+    ) -> Mapping[str, RelevantPath]:
+        relevant_paths: dict[str, RelevantPath] = {}
+        # first extract out the relevant paths in `data_index`:
+        for path_i in data_index:
+            path_i_split = path_i.split(".")
+            try:
+                rel_path = get_relative_path(path, path_i_split)
+                relevant_paths[path_i] = {"type": "parent", "relative_path": rel_path}
+            except ValueError:
+                try:
+                    update_path = get_relative_path(path_i_split, path)
+                    relevant_paths[path_i] = {
+                        "type": "update",
+                        "update_path": update_path,
+                    }
+                except ValueError:
+                    # no intersection between the paths:
+                    if children_of and path_i.startswith(children_of):
+                        relevant_paths[path_i] = {"type": "sibling"}
+                    continue
+
+        return relevant_paths
+
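For orientation, how the three path categories fall out of this classification for a requested path `["inputs", "p1", "a"]` (paths hypothetical):

```python
# "inputs.p1"      -> {"type": "parent", "relative_path": ["a"]}
#                     (descend into the stored parent value)
# "inputs.p1.a.b"  -> {"type": "update", "update_path": ["b"]}
#                     (overlay this deeper value onto the result)
# "inputs.p2"      -> omitted (no intersection), unless it is a child of
#                     `children_of`, in which case {"type": "sibling"}
```
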
+    def __get_relevant_data_item(
+        self,
+        path: str | None,
+        path_i: str,
+        data_idx_ij: int,
+        raise_on_unset: bool,
+        len_dat_idx: int = 1,
+    ) -> tuple[Any, bool, str | None]:
+        if path_i.startswith("repeats."):
+            # data is an integer repeats index, rather than a parameter ID:
+            return data_idx_ij, True, None
+
+        meth_i: str | None = None
+        data_j: Any
+        param_j = self.workflow.get_parameter(data_idx_ij)
+        is_set_i = param_j.is_set
+        if param_j.file:
+            if param_j.file["store_contents"]:
+                file_j = Path(self.workflow.path) / param_j.file["path"]
+            else:
+                file_j = Path(param_j.file["path"])
+            data_j = file_j.as_posix()
+        else:
+            meth_i = param_j.source.get("value_class_method")
+            if param_j.is_pending:
+                # if pending, we need to convert `ParameterValue` objects to their dict
+                # representation, so they can be merged with other data:
+                try:
+                    data_j = cast("ParameterValue", param_j.data).to_dict()
+                except AttributeError:
+                    data_j = param_j.data
+            else:
+                # if not pending, data will be the result of an encode-decode cycle,
+                # and it will not be initialised as an object if the parameter is
+                # associated with a `ParameterValue` class.
+                data_j = param_j.data
+        if raise_on_unset and not is_set_i:
+            raise UnsetParameterDataError(path, path_i)
+        if not is_set_i and self.workflow._is_tracking_unset:
+            src_run_id = param_j.source.get("EAR_ID")
+            unset_trackers = self.workflow._tracked_unset
+            assert src_run_id is not None
+            assert unset_trackers is not None
+            unset_trackers[path_i].run_ids.add(src_run_id)
+            unset_trackers[path_i].group_size = len_dat_idx
+        return data_j, is_set_i, meth_i
+
+    def __get_relevant_data(
+        self,
+        relevant_data_idx: Mapping[str, list[int] | int],
+        raise_on_unset: bool,
+        path: str | None,
+    ) -> Mapping[str, RelevantData]:
+        relevant_data: dict[str, RelevantData] = {}
+        for path_i, data_idx_i in relevant_data_idx.items():
+            if not isinstance(data_idx_i, list):
+                data, is_set, meth = self.__get_relevant_data_item(
+                    path, path_i, data_idx_i, raise_on_unset
+                )
+                relevant_data[path_i] = {
+                    "data": data,
+                    "value_class_method": meth,
+                    "is_set": is_set,
+                    "is_multi": False,
+                }
+                continue
+
+            data_i: list[Any] = []
+            methods_i: list[str | None] = []
+            is_param_set_i: list[bool] = []
+            for data_idx_ij in data_idx_i:
+                data_j, is_set_i, meth_i = self.__get_relevant_data_item(
+                    path, path_i, data_idx_ij, raise_on_unset, len_dat_idx=len(data_idx_i)
+                )
+                data_i.append(data_j)
+                methods_i.append(meth_i)
+                is_param_set_i.append(is_set_i)
+
+            relevant_data[path_i] = {
+                "data": data_i,
+                "value_class_method": methods_i,
+                "is_set": is_param_set_i,
+                "is_multi": True,
+            }
+
+        if not raise_on_unset:
+            to_remove: set[str] = set()
+            for key, dat_info in relevant_data.items():
+                if not dat_info["is_set"] and (not path or path in key):
+                    # remove sub-paths, as they cannot be merged with this parent:
+                    prefix = f"{key}."
+                    to_remove.update(k for k in relevant_data if k.startswith(prefix))
+            for key in to_remove:
+                relevant_data.pop(key, None)
+
+        return relevant_data
+
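The returned mapping has one entry per relevant path; grouped ("multi") paths carry parallel lists. A sketch of the shape, with illustrative values:

```python
# {
#     "inputs.p1": {
#         "data": [{"a": 1}, {"a": 2}],   # one item per group member
#         "value_class_method": [None, None],
#         "is_set": [True, True],
#         "is_multi": True,
#     },
# }
```
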
+    @classmethod
+    def __merge_relevant_data(
+        cls,
+        relevant_data: Mapping[str, RelevantData],
+        relevant_paths: Mapping[str, RelevantPath],
+        PV_classes: Mapping[str, type[ParameterValue]],
+        path: str | None,
+        raise_on_missing: bool,
+    ) -> tuple[Any, int | None]:
+        current_val: list | dict | Any | None = None
+        assigned_from_parent = False
+        val_cls_method: str | None | list[str | None] = None
+        path_is_multi = False
+        path_is_set: bool | list[bool] = False
+        all_multi_len: int | None = None
+        for path_i, data_info_i in relevant_data.items():
+            data_i = data_info_i["data"]
+            if path_i == path:
+                val_cls_method = data_info_i["value_class_method"]
+                path_is_multi = data_info_i["is_multi"]
+                path_is_set = data_info_i["is_set"]
+
+            if data_info_i["is_multi"]:
+                if all_multi_len:
+                    if len(data_i) != all_multi_len:
+                        raise RuntimeError(
+                            "Cannot merge group values of different lengths."
+                        )
+                else:
+                    # keep track of group lengths; only equal-length groups are merged:
+                    all_multi_len = len(data_i)
+
+            path_info = relevant_paths[path_i]
+            if path_info["type"] == "parent":
+                try:
+                    if data_info_i["is_multi"]:
+                        current_val = [
+                            get_in_container(
+                                item,
+                                path_info["relative_path"],
+                                cast_indices=True,
+                            )
+                            for item in data_i
+                        ]
+                        path_is_multi = True
+                        path_is_set = data_info_i["is_set"]
+                        val_cls_method = data_info_i["value_class_method"]
+                    else:
+                        current_val = get_in_container(
+                            data_i,
+                            path_info["relative_path"],
+                            cast_indices=True,
+                        )
+                except ContainerKeyError as err:
+                    if path_i in PV_classes:
+                        raise MayNeedObjectError(path=".".join([path_i, *err.path[:-1]]))
+                    continue
+                except (IndexError, ValueError) as err:
+                    if raise_on_missing:
+                        raise err
+                    continue
+                else:
+                    assigned_from_parent = True
+            elif path_info["type"] == "update":
+                current_val = current_val or {}
+                if all_multi_len:
+                    if len(path_i.split(".")) == 2:
+                        # groups can only be "created" at the parameter level:
+                        set_in_container(
+                            cont=current_val,
+                            path=path_info["update_path"],
+                            value=data_i,
+                            ensure_path=True,
+                            cast_indices=True,
+                        )
+                    else:
+                        # update the group:
+                        update_path = path_info["update_path"]
+                        if len(update_path) > 1:
+                            for idx, j in enumerate(data_i):
+                                set_in_container(
+                                    cont=current_val,
+                                    path=[*update_path[:1], idx, *update_path[1:]],
+                                    value=j,
+                                    ensure_path=True,
+                                    cast_indices=True,
+                                )
+                        else:
+                            for i, j in zip(current_val, data_i):
+                                set_in_container(
+                                    cont=i,
+                                    path=update_path,
+                                    value=j,
+                                    ensure_path=True,
+                                    cast_indices=True,
+                                )
+
+                else:
+                    set_in_container(
+                        current_val,
+                        path_info["update_path"],
+                        data_i,
+                        ensure_path=True,
+                        cast_indices=True,
+                    )
+        if path in PV_classes:
+            if path not in relevant_data:
+                # the requested data must be a sub-path of the relevant data, so we can
+                # assume the path is set (if the parent was not set, the sub-paths would
+                # have been removed in `__get_relevant_data`):
+                path_is_set = path_is_set or True
+
+                if not assigned_from_parent:
+                    # search for unset parents in `relevant_data`:
+                    assert path is not None
+                    for parent_i_span in range(
+                        len(path_split := path.split(".")) - 1, 1, -1
+                    ):
+                        parent_path_i = ".".join(path_split[:parent_i_span])
+                        if not (relevant_par := relevant_data.get(parent_path_i)):
+                            continue
+                        if not (par_is_set := relevant_par["is_set"]) or not all(
+                            cast("list", par_is_set)
+                        ):
+                            val_cls_method = relevant_par["value_class_method"]
+                            path_is_multi = relevant_par["is_multi"]
+                            path_is_set = relevant_par["is_set"]
+                            current_val = relevant_par["data"]
+                            break
+
+            # initialise objects:
+            PV_cls = PV_classes[path]
+            if path_is_multi:
+                current_val = [
+                    (
+                        cls.__map_parameter_value(PV_cls, meth_i, val_i)
+                        if set_i and isinstance(val_i, dict)
+                        else None
+                    )
+                    for set_i, meth_i, val_i in zip(
+                        cast("list[bool]", path_is_set),
+                        cast("list[str | None]", val_cls_method),
+                        cast("list[Any]", current_val),
+                    )
+                ]
+            elif path_is_set and isinstance(current_val, dict):
+                assert not isinstance(val_cls_method, list)
+                current_val = cls.__map_parameter_value(
+                    PV_cls, val_cls_method, current_val
+                )
+
+        return current_val, all_multi_len
+
+    @staticmethod
+    def __map_parameter_value(
+        PV_cls: type[ParameterValue], meth: str | None, val: dict
+    ) -> Any | ParameterValue:
+        if meth:
+            # call an alternative (class-method) constructor named by the parameter
+            # source:
+            method: Callable = getattr(PV_cls, meth)
+            return method(**val)
+        else:
+            return PV_cls(**val)
+
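`__map_parameter_value` is the final step that rehydrates dicts into `ParameterValue` objects, optionally via a named alternative constructor. A hedged, self-contained sketch (class and method invented for the example):

```python
class Orientation(ParameterValue):
    _typ = "orientation"

    def __init__(self, angles):
        self.angles = angles

    @classmethod
    def from_degrees(cls, angles):
        # alternative constructor, selectable via `value_class_method`:
        return cls([a * 3.141592653589793 / 180 for a in angles])

# with meth=None, the class itself is called:
#   __map_parameter_value(Orientation, None, {"angles": [1.57]})
# with meth="from_degrees", the classmethod is used:
#   __map_parameter_value(Orientation, "from_degrees", {"angles": [90]})
```
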
+    @TimeIt.decorator
+    def _get_merged_parameter_data(
+        self,
+        data_index: Mapping[str, list[int] | int],
+        path: str | None = None,
+        *,
+        raise_on_missing: bool = False,
+        raise_on_unset: bool = False,
+        default: Any | None = None,
+    ):
+        """Get element data from the persistent store."""
+        path_split = [] if not path else path.split(".")
+
+        if not (relevant_paths := self._get_relevant_paths(data_index, path_split)):
+            if raise_on_missing:
+                # TODO: custom exception?
+                raise ValueError(f"Path {path!r} does not exist in the element data.")
+            return default
+
+        relevant_data_idx = {k: v for k, v in data_index.items() if k in relevant_paths}
+
+        cache = self.workflow._merged_parameters_cache
+        use_cache = (
+            self.workflow._use_merged_parameters_cache
+            and raise_on_missing is False
+            and raise_on_unset is False
+            and default is None  # cannot cache on the default value; it may not be hashable
+        )
+        add_to_cache = False
+        if use_cache:
+            # generate the key:
+            dat_idx_cache: list[tuple[str, tuple[int, ...] | int]] = []
+            for k, v in sorted(relevant_data_idx.items()):
+                dat_idx_cache.append((k, tuple(v) if isinstance(v, list) else v))
+            cache_key = (path, tuple(dat_idx_cache))
+
+            # check for a cache hit:
+            if cache_key in cache:
+                self._app.logger.debug(
+                    f"_get_merged_parameter_data: cache hit with key: {cache_key}"
+                )
+                return cache[cache_key]
+            else:
+                add_to_cache = True
+
+        PV_classes = self._paths_to_PV_classes(*relevant_paths, path)
+        relevant_data = self.__get_relevant_data(relevant_data_idx, raise_on_unset, path)
+
+        current_val = None
+        is_assigned = False
+        try:
+            current_val, _ = self.__merge_relevant_data(
+                relevant_data, relevant_paths, PV_classes, path, raise_on_missing
+            )
+        except MayNeedObjectError as err:
+            path_to_init = err.path
+            path_to_init_split = path_to_init.split(".")
+            relevant_paths = self._get_relevant_paths(data_index, path_to_init_split)
+            PV_classes = self._paths_to_PV_classes(*relevant_paths, path_to_init)
+            relevant_data_idx = {
+                k: v for k, v in data_index.items() if k in relevant_paths
+            }
+            relevant_data = self.__get_relevant_data(
+                relevant_data_idx, raise_on_unset, path
+            )
+            # merge the parent data:
+            current_val, group_len = self.__merge_relevant_data(
+                relevant_data, relevant_paths, PV_classes, path_to_init, raise_on_missing
+            )
+            # try to retrieve attributes via the initialised object:
+            rel_path_split = get_relative_path(path_split, path_to_init_split)
+            try:
+                if group_len:
+                    current_val = [
+                        get_in_container(
+                            cont=item,
+                            path=rel_path_split,
+                            cast_indices=True,
+                            allow_getattr=True,
+                        )
+                        for item in current_val
+                    ]
+                else:
+                    current_val = get_in_container(
+                        cont=current_val,
+                        path=rel_path_split,
+                        cast_indices=True,
+                        allow_getattr=True,
+                    )
+            except (KeyError, IndexError, ValueError):
+                pass
+            else:
+                is_assigned = True
+
+        except (KeyError, IndexError, ValueError):
+            pass
+        else:
+            is_assigned = True
+
+        if not is_assigned:
+            if raise_on_missing:
+                # TODO: custom exception?
+                raise ValueError(f"Path {path!r} does not exist in the element data.")
+            current_val = default
+
+        if add_to_cache:
+            self._app.logger.debug(
+                f"_get_merged_parameter_data: adding to cache with key: {cache_key!r}"
+            )
+            # cache_key: tuple[str | None, tuple[tuple[str, tuple[int, ...] | int], ...]]
+            cache[cache_key] = current_val
+
+        return current_val
+
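The cache key pairs the requested path with the sorted, hashable form of the relevant data indices, for example:

```python
# _get_merged_parameter_data({"inputs.p1": [2, 3]}, path="inputs.p1")
# builds: cache_key = ("inputs.p1", (("inputs.p1", (2, 3)),))
# Caching is bypassed when raise_on_missing/raise_on_unset/default are given,
# since those change the result for an otherwise identical key.
```
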
+
+class Elements:
+    """
+    The elements of a task. Iterable.
+
+    Parameters
+    ----------
+    task:
+        The task this will be the elements of.
+    """
+
+    __slots__ = ("_task",)
+
+    def __init__(self, task: WorkflowTask):
+        self._task = task
+
+    # TODO: cache Element objects
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(task={self.task.unique_name!r}, "
+            f"num_elements={self.task.num_elements})"
+        )
+
+    @property
+    def task(self) -> WorkflowTask:
+        """
+        The task this is the elements of.
+        """
+        return self._task
+
+    @TimeIt.decorator
+    def __get_selection(self, selection: int | slice | list[int]) -> list[int]:
+        """Normalise an element selection into a list of element indices."""
+        if isinstance(selection, int):
+            return [selection]
+        elif isinstance(selection, slice):
+            return list(range(*selection.indices(self.task.num_elements)))
+        elif isinstance(selection, list):
+            return selection
+        else:
+            raise RuntimeError(
+                f"{self.__class__.__name__} selection must be an `int`, `slice` object, "
+                f"or list of `int`s, but received type {type(selection)}."
+            )
+
+    def __len__(self) -> int:
+        return self.task.num_elements
+
+    def __iter__(self) -> Iterator[Element]:
+        yield from self.task.workflow.get_task_elements(self.task)
+
+    @overload
+    def __getitem__(
+        self,
+        selection: int,
+    ) -> Element: ...
+
+    @overload
+    def __getitem__(
+        self,
+        selection: slice | list[int],
+    ) -> list[Element]: ...
+
+    @TimeIt.decorator
+    def __getitem__(
+        self,
+        selection: int | slice | list[int],
+    ) -> Element | list[Element]:
+        elements = self.task.workflow.get_task_elements(
+            self.task, self.__get_selection(selection)
+        )
+
+        if isinstance(selection, int):
+            return elements[0]
+        else:
+            return elements
+
+
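Element access on a task therefore supports `int`, `slice`, and `list` selections (sketch; `task` is an existing `WorkflowTask`):

```python
first = task.elements[0]        # a single Element
window = task.elements[1:4]     # list[Element]
picked = task.elements[[0, 2]]  # list[Element]
count = len(task.elements)      # == task.num_elements
```
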
+@dataclass
+@hydrate
+class Parameters(AppAware):
+    """
+    The parameters of a (workflow-bound) task. Iterable.
+
+    Parameters
+    ----------
+    task: WorkflowTask
+        The task these are the parameters of.
+    path: str
+        The path to the parameter or parameters.
+    return_element_parameters: bool
+        Whether to return element parameters.
+    raise_on_missing: bool
+        Whether to raise an exception on a missing parameter.
+    raise_on_unset: bool
+        Whether to raise an exception on an unset parameter.
+    default:
+        A default value to use when the parameter is absent.
+    """
+
+    #: The task these are the parameters of.
+    task: WorkflowTask
+    #: The path to the parameter or parameters.
+    path: str
+    #: Whether to return element parameters.
+    return_element_parameters: bool
+    #: Whether to raise an exception on a missing parameter.
+    raise_on_missing: bool = False
+    #: Whether to raise an exception on an unset parameter.
+    raise_on_unset: bool = False
+    #: A default value to use when the parameter is absent.
+    default: Any | None = None
+
+    @TimeIt.decorator
+    def __get_selection(
+        self, selection: int | slice | list[int] | tuple[int, ...]
+    ) -> list[int]:
+        """Normalise an element selection into a list of element indices."""
+        if isinstance(selection, int):
+            return [selection]
+        elif isinstance(selection, slice):
+            return list(range(*selection.indices(self.task.num_elements)))
+        elif isinstance(selection, list):
+            return selection
+        elif isinstance(selection, tuple):
+            return list(selection)
+        else:
+            raise RuntimeError(
+                f"{self.__class__.__name__} selection must be an `int`, a `slice` "
+                f"object, or a tuple or list of `int`s, but received type "
+                f"{type(selection)}."
+            )
+
+    def __iter__(self) -> Iterator[Any | ElementParameter]:
+        yield from self.__getitem__(slice(None))
+
+    @overload
+    def __getitem__(self, selection: int) -> Any | ElementParameter: ...
+
+    @overload
+    def __getitem__(
+        self, selection: slice | list[int]
+    ) -> list[Any | ElementParameter]: ...
+
+    def __getitem__(
+        self,
+        selection: int | slice | list[int],
+    ) -> Any | ElementParameter | list[Any | ElementParameter]:
+        idx_lst = self.__get_selection(selection)
+        elements = self.task.workflow.get_task_elements(self.task, idx_lst)
+        if self.return_element_parameters:
+            params = (
+                self._app.ElementParameter(
+                    task=self.task,
+                    path=self.path,
+                    parent=elem,
+                    element=elem,
+                )
+                for elem in elements
+            )
+        else:
+            params = (
+                elem.get(
+                    path=self.path,
+                    raise_on_missing=self.raise_on_missing,
+                    raise_on_unset=self.raise_on_unset,
+                    default=self.default,
+                )
+                for elem in elements
+            )
+
+        if isinstance(selection, int):
+            return next(iter(params))
+        else:
+            return list(params)
+
+
+@dataclass
+@hydrate
+class TaskInputParameters(AppAware):
+    """
+    For retrieving schema input parameters across all elements.
+    Treat as an unmodifiable namespace.
+
+    Parameters
+    ----------
+    task:
+        The task that this represents the input parameters of.
+    """
+
+    #: The task that this represents the input parameters of.
+    task: WorkflowTask
+    __input_names: frozenset[str] | None = field(default=None, init=False, compare=False)
+
+    def __getattr__(self, name: str) -> Parameters:
+        if name not in self.__get_input_names():
+            raise ValueError(f"No input named {name!r}.")
+        return self._app.Parameters(self.task, f"inputs.{name}", True)
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}("
+            f"{', '.join(f'{name!r}' for name in sorted(self.__get_input_names()))})"
+        )
+
+    def __dir__(self) -> Iterator[str]:
+        yield from super().__dir__()
+        yield from sorted(self.__get_input_names())
+
+    def __get_input_names(self) -> frozenset[str]:
+        if self.__input_names is None:
+            self.__input_names = frozenset(self.task.template.all_schema_input_types)
+        return self.__input_names
+
+
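Both namespaces reject unknown names and advertise the valid ones through `__dir__`, so tab-completion works in interactive sessions (names hypothetical):

```python
wk.tasks.t1.inputs.p1    # Parameters over "inputs.p1"
wk.tasks.t1.outputs.p2   # Parameters over "outputs.p2"
wk.tasks.t1.inputs.nope  # raises ValueError: No input named 'nope'.
```
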
+@dataclass
+@hydrate
+class TaskOutputParameters(AppAware):
+    """
+    For retrieving schema output parameters across all elements.
+    Treat as an unmodifiable namespace.
+
+    Parameters
+    ----------
+    task:
+        The task that this represents the output parameters of.
+    """
+
+    #: The task that this represents the output parameters of.
+    task: WorkflowTask
+    __output_names: frozenset[str] | None = field(default=None, init=False, compare=False)
+
+    def __getattr__(self, name: str) -> Parameters:
+        if name not in self.__get_output_names():
+            raise ValueError(f"No output named {name!r}.")
+        return self._app.Parameters(self.task, f"outputs.{name}", True)
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}("
+            f"{', '.join(map(repr, sorted(self.__get_output_names())))})"
+        )
+
+    def __dir__(self) -> Iterator[str]:
+        yield from super().__dir__()
+        yield from sorted(self.__get_output_names())
+
+    def __get_output_names(self) -> frozenset[str]:
+        if self.__output_names is None:
+            self.__output_names = frozenset(self.task.template.all_schema_output_types)
+        return self.__output_names
+
+
+@dataclass
+@hydrate
+class ElementPropagation(AppAware):
+    """
+    Class to represent how a newly added element set should propagate to a given
+    downstream task.
+
+    Parameters
+    ----------
+    task:
+        The task this is propagating to.
+    nesting_order:
+        The nesting order information.
+    input_sources:
+        The input source information.
+    """
+
+    #: The task this is propagating to.
+    task: WorkflowTask
+    #: The nesting order information.
+    nesting_order: dict[str, float] | None = None
+    #: The input source information.
+    input_sources: dict[str, list[InputSource]] | None = None
+
+    @property
+    def element_set(self) -> ElementSet:
+        """
+        The element set that this propagates from.
+
+        Note
+        ----
+        Temporary property. May be moved or reinterpreted.
+        """
+        # TEMP property; for now, just use the first element set as the base:
+        return self.task.template.element_sets[0]
+
+    def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
+        return self.__class__(
+            task=self.task,
+            nesting_order=copy.copy(self.nesting_order),
+            input_sources=copy.deepcopy(self.input_sources, memo),
+        )
+
+    @classmethod
+    def _prepare_propagate_to_dict(
+        cls,
+        propagate_to: (
+            list[ElementPropagation]
+            | Mapping[str, ElementPropagation | Mapping[str, Any]]
+            | None
+        ),
+        workflow: Workflow,
+    ) -> dict[str, ElementPropagation]:
+        if not propagate_to:
+            return {}
+        propagate_to = copy.deepcopy(propagate_to)
+        if isinstance(propagate_to, list):
+            return {prop.task.unique_name: prop for prop in propagate_to}
+
+        return {
+            k: (
+                v
+                if isinstance(v, ElementPropagation)
+                else cls(task=workflow.tasks.get(unique_name=k), **v)
+            )
+            for k, v in propagate_to.items()
+        }
+
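`_prepare_propagate_to_dict` normalises either accepted form into a `dict[str, ElementPropagation]`. A sketch, where `wk` and the task name are hypothetical:

```python
eps = hf.ElementPropagation._prepare_propagate_to_dict(
    {"t2": {"nesting_order": {"inputs.p2": 0.0}}}, wk
)
# eps == {"t2": ElementPropagation(task=wk.tasks.t2,
#                                  nesting_order={"inputs.p2": 0.0})}
```
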
+
+#: A task used as a template for other tasks.
+TaskTemplate: TypeAlias = Task
+
+
+class MetaTask(JSONLike):
+    """A grouping of tasks instantiated together according to a `MetaTaskSchema`."""
+
+    def __init__(self, schema: MetaTaskSchema, tasks: Sequence[Task]):
+        self.schema = schema
+        self.tasks = tasks
+
+        # TODO: validate that the schema's inputs and outputs are inputs and outputs
+        # of the `tasks`' schemas