hpcflow-new2 0.2.0a189__py3-none-any.whl → 0.2.0a199__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
Files changed (176)
  1. hpcflow/__pyinstaller/hook-hpcflow.py +9 -6
  2. hpcflow/_version.py +1 -1
  3. hpcflow/app.py +1 -0
  4. hpcflow/data/scripts/bad_script.py +2 -0
  5. hpcflow/data/scripts/do_nothing.py +2 -0
  6. hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
  7. hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
  8. hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
  9. hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
  10. hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
  11. hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
  12. hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
  13. hpcflow/data/scripts/input_file_generator_basic.py +3 -0
  14. hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
  15. hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
  16. hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
  17. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
  18. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
  19. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
  20. hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
  21. hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
  22. hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
  23. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
  24. hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
  25. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
  26. hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
  27. hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
  28. hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
  29. hpcflow/data/scripts/output_file_parser_basic.py +3 -0
  30. hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
  31. hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
  32. hpcflow/data/scripts/script_exit_test.py +5 -0
  33. hpcflow/data/template_components/environments.yaml +1 -1
  34. hpcflow/sdk/__init__.py +26 -15
  35. hpcflow/sdk/app.py +2192 -768
  36. hpcflow/sdk/cli.py +506 -296
  37. hpcflow/sdk/cli_common.py +105 -7
  38. hpcflow/sdk/config/__init__.py +1 -1
  39. hpcflow/sdk/config/callbacks.py +115 -43
  40. hpcflow/sdk/config/cli.py +126 -103
  41. hpcflow/sdk/config/config.py +674 -318
  42. hpcflow/sdk/config/config_file.py +131 -95
  43. hpcflow/sdk/config/errors.py +125 -84
  44. hpcflow/sdk/config/types.py +148 -0
  45. hpcflow/sdk/core/__init__.py +25 -1
  46. hpcflow/sdk/core/actions.py +1771 -1059
  47. hpcflow/sdk/core/app_aware.py +24 -0
  48. hpcflow/sdk/core/cache.py +139 -79
  49. hpcflow/sdk/core/command_files.py +263 -287
  50. hpcflow/sdk/core/commands.py +145 -112
  51. hpcflow/sdk/core/element.py +828 -535
  52. hpcflow/sdk/core/enums.py +192 -0
  53. hpcflow/sdk/core/environment.py +74 -93
  54. hpcflow/sdk/core/errors.py +455 -52
  55. hpcflow/sdk/core/execute.py +207 -0
  56. hpcflow/sdk/core/json_like.py +540 -272
  57. hpcflow/sdk/core/loop.py +751 -347
  58. hpcflow/sdk/core/loop_cache.py +164 -47
  59. hpcflow/sdk/core/object_list.py +370 -207
  60. hpcflow/sdk/core/parameters.py +1100 -627
  61. hpcflow/sdk/core/rule.py +59 -41
  62. hpcflow/sdk/core/run_dir_files.py +21 -37
  63. hpcflow/sdk/core/skip_reason.py +7 -0
  64. hpcflow/sdk/core/task.py +1649 -1339
  65. hpcflow/sdk/core/task_schema.py +308 -196
  66. hpcflow/sdk/core/test_utils.py +191 -114
  67. hpcflow/sdk/core/types.py +440 -0
  68. hpcflow/sdk/core/utils.py +485 -309
  69. hpcflow/sdk/core/validation.py +82 -9
  70. hpcflow/sdk/core/workflow.py +2544 -1178
  71. hpcflow/sdk/core/zarr_io.py +98 -137
  72. hpcflow/sdk/data/workflow_spec_schema.yaml +2 -0
  73. hpcflow/sdk/demo/cli.py +53 -33
  74. hpcflow/sdk/helper/cli.py +18 -15
  75. hpcflow/sdk/helper/helper.py +75 -63
  76. hpcflow/sdk/helper/watcher.py +61 -28
  77. hpcflow/sdk/log.py +122 -71
  78. hpcflow/sdk/persistence/__init__.py +8 -31
  79. hpcflow/sdk/persistence/base.py +1360 -606
  80. hpcflow/sdk/persistence/defaults.py +6 -0
  81. hpcflow/sdk/persistence/discovery.py +38 -0
  82. hpcflow/sdk/persistence/json.py +568 -188
  83. hpcflow/sdk/persistence/pending.py +382 -179
  84. hpcflow/sdk/persistence/store_resource.py +39 -23
  85. hpcflow/sdk/persistence/types.py +318 -0
  86. hpcflow/sdk/persistence/utils.py +14 -11
  87. hpcflow/sdk/persistence/zarr.py +1337 -433
  88. hpcflow/sdk/runtime.py +44 -41
  89. hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
  90. hpcflow/sdk/submission/jobscript.py +1651 -692
  91. hpcflow/sdk/submission/schedulers/__init__.py +167 -39
  92. hpcflow/sdk/submission/schedulers/direct.py +121 -81
  93. hpcflow/sdk/submission/schedulers/sge.py +170 -129
  94. hpcflow/sdk/submission/schedulers/slurm.py +291 -268
  95. hpcflow/sdk/submission/schedulers/utils.py +12 -2
  96. hpcflow/sdk/submission/shells/__init__.py +14 -15
  97. hpcflow/sdk/submission/shells/base.py +150 -29
  98. hpcflow/sdk/submission/shells/bash.py +283 -173
  99. hpcflow/sdk/submission/shells/os_version.py +31 -30
  100. hpcflow/sdk/submission/shells/powershell.py +228 -170
  101. hpcflow/sdk/submission/submission.py +1014 -335
  102. hpcflow/sdk/submission/types.py +140 -0
  103. hpcflow/sdk/typing.py +182 -12
  104. hpcflow/sdk/utils/arrays.py +71 -0
  105. hpcflow/sdk/utils/deferred_file.py +55 -0
  106. hpcflow/sdk/utils/hashing.py +16 -0
  107. hpcflow/sdk/utils/patches.py +12 -0
  108. hpcflow/sdk/utils/strings.py +33 -0
  109. hpcflow/tests/api/test_api.py +32 -0
  110. hpcflow/tests/conftest.py +27 -6
  111. hpcflow/tests/data/multi_path_sequences.yaml +29 -0
  112. hpcflow/tests/data/workflow_test_run_abort.yaml +34 -35
  113. hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
  114. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
  115. hpcflow/tests/scripts/test_input_file_generators.py +282 -0
  116. hpcflow/tests/scripts/test_main_scripts.py +866 -85
  117. hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
  118. hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
  119. hpcflow/tests/shells/wsl/test_wsl_submission.py +12 -4
  120. hpcflow/tests/unit/test_action.py +262 -75
  121. hpcflow/tests/unit/test_action_rule.py +9 -4
  122. hpcflow/tests/unit/test_app.py +33 -6
  123. hpcflow/tests/unit/test_cache.py +46 -0
  124. hpcflow/tests/unit/test_cli.py +134 -1
  125. hpcflow/tests/unit/test_command.py +71 -54
  126. hpcflow/tests/unit/test_config.py +142 -16
  127. hpcflow/tests/unit/test_config_file.py +21 -18
  128. hpcflow/tests/unit/test_element.py +58 -62
  129. hpcflow/tests/unit/test_element_iteration.py +50 -1
  130. hpcflow/tests/unit/test_element_set.py +29 -19
  131. hpcflow/tests/unit/test_group.py +4 -2
  132. hpcflow/tests/unit/test_input_source.py +116 -93
  133. hpcflow/tests/unit/test_input_value.py +29 -24
  134. hpcflow/tests/unit/test_jobscript_unit.py +757 -0
  135. hpcflow/tests/unit/test_json_like.py +44 -35
  136. hpcflow/tests/unit/test_loop.py +1396 -84
  137. hpcflow/tests/unit/test_meta_task.py +325 -0
  138. hpcflow/tests/unit/test_multi_path_sequences.py +229 -0
  139. hpcflow/tests/unit/test_object_list.py +17 -12
  140. hpcflow/tests/unit/test_parameter.py +29 -7
  141. hpcflow/tests/unit/test_persistence.py +237 -42
  142. hpcflow/tests/unit/test_resources.py +20 -18
  143. hpcflow/tests/unit/test_run.py +117 -6
  144. hpcflow/tests/unit/test_run_directories.py +29 -0
  145. hpcflow/tests/unit/test_runtime.py +2 -1
  146. hpcflow/tests/unit/test_schema_input.py +23 -15
  147. hpcflow/tests/unit/test_shell.py +23 -2
  148. hpcflow/tests/unit/test_slurm.py +8 -7
  149. hpcflow/tests/unit/test_submission.py +38 -89
  150. hpcflow/tests/unit/test_task.py +352 -247
  151. hpcflow/tests/unit/test_task_schema.py +33 -20
  152. hpcflow/tests/unit/test_utils.py +9 -11
  153. hpcflow/tests/unit/test_value_sequence.py +15 -12
  154. hpcflow/tests/unit/test_workflow.py +114 -83
  155. hpcflow/tests/unit/test_workflow_template.py +0 -1
  156. hpcflow/tests/unit/utils/test_arrays.py +40 -0
  157. hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
  158. hpcflow/tests/unit/utils/test_hashing.py +65 -0
  159. hpcflow/tests/unit/utils/test_patches.py +5 -0
  160. hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
  161. hpcflow/tests/workflows/__init__.py +0 -0
  162. hpcflow/tests/workflows/test_directory_structure.py +31 -0
  163. hpcflow/tests/workflows/test_jobscript.py +334 -1
  164. hpcflow/tests/workflows/test_run_status.py +198 -0
  165. hpcflow/tests/workflows/test_skip_downstream.py +696 -0
  166. hpcflow/tests/workflows/test_submission.py +140 -0
  167. hpcflow/tests/workflows/test_workflows.py +160 -15
  168. hpcflow/tests/workflows/test_zip.py +18 -0
  169. hpcflow/viz_demo.ipynb +6587 -3
  170. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/METADATA +8 -4
  171. hpcflow_new2-0.2.0a199.dist-info/RECORD +221 -0
  172. hpcflow/sdk/core/parallel.py +0 -21
  173. hpcflow_new2-0.2.0a189.dist-info/RECORD +0 -158
  174. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/LICENSE +0 -0
  175. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/WHEEL +0 -0
  176. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a199.dist-info}/entry_points.txt +0 -0
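
A large share of the `hpcflow/sdk/core/task.py` changes shown below are type-annotation work: the module gains a `TYPE_CHECKING`-guarded import block, and methods whose return type depends on a boolean flag (for example `get_task_dependencies`) gain `@overload` declarations keyed on `typing.Literal`. The following standalone sketch illustrates that idiom in isolation; the `TaskLike` and `DependencyHolder` names are invented for the example and are not part of hpcflow's API.

from __future__ import annotations

from typing import Literal, overload


class TaskLike:
    """Toy stand-in for a task object; not part of hpcflow."""

    def __init__(self, insert_ID: int) -> None:
        self.insert_ID = insert_ID


class DependencyHolder:
    """Illustrates overloads keyed on a Literal boolean flag."""

    def __init__(self, dep_ids: set[int], registry: dict[int, TaskLike]) -> None:
        self._dep_ids = dep_ids
        self._registry = registry

    @overload
    def get_task_dependencies(self, as_objects: Literal[False] = False) -> set[int]: ...

    @overload
    def get_task_dependencies(self, as_objects: Literal[True]) -> list[TaskLike]: ...

    def get_task_dependencies(self, as_objects: bool = False) -> list[TaskLike] | set[int]:
        # Single runtime implementation; the overloads above only tell the type
        # checker which return type corresponds to which flag value.
        if as_objects:
            return [self._registry[i] for i in sorted(self._dep_ids)]
        return self._dep_ids


if __name__ == "__main__":
    registry = {3: TaskLike(3), 7: TaskLike(7)}
    holder = DependencyHolder({7, 3}, registry)
    print(holder.get_task_dependencies())                 # {3, 7} as a set of IDs
    print(holder.get_task_dependencies(as_objects=True))  # TaskLike objects, sorted by ID
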
hpcflow/sdk/core/task.py CHANGED
@@ -5,19 +5,20 @@ Tasks are components of workflows.
5
5
  from __future__ import annotations
6
6
  from collections import defaultdict
7
7
  import copy
8
- from dataclasses import dataclass
8
+ from dataclasses import dataclass, field
9
+ from itertools import chain
9
10
  from pathlib import Path
10
- from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
11
+ from typing import NamedTuple, cast, overload, TYPE_CHECKING
12
+ from typing_extensions import override
11
13
 
12
- from valida.rules import Rule
13
-
14
-
15
- from hpcflow.sdk import app
16
- from hpcflow.sdk.core.task_schema import TaskSchema
14
+ from hpcflow.sdk.typing import hydrate
15
+ from hpcflow.sdk.core.object_list import AppDataList
17
16
  from hpcflow.sdk.log import TimeIt
18
- from .json_like import ChildObjectSpec, JSONLike
19
- from .element import ElementGroup
20
- from .errors import (
17
+ from hpcflow.sdk.core.app_aware import AppAware
18
+ from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
19
+ from hpcflow.sdk.core.element import ElementGroup
20
+ from hpcflow.sdk.core.enums import InputSourceType, TaskSourceType
21
+ from hpcflow.sdk.core.errors import (
21
22
  ContainerKeyError,
22
23
  ExtraInputs,
23
24
  InapplicableInputSourceElementIters,
@@ -37,8 +38,8 @@ from .errors import (
37
38
  UnrequiredInputSources,
38
39
  UnsetParameterDataError,
39
40
  )
40
- from .parameters import InputSourceType, ParameterValue, TaskSourceType
41
- from .utils import (
41
+ from hpcflow.sdk.core.parameters import ParameterValue
42
+ from hpcflow.sdk.core.utils import (
42
43
  get_duplicate_items,
43
44
  get_in_container,
44
45
  get_item_repeat_index,
@@ -48,8 +49,44 @@ from .utils import (
48
49
  split_param_label,
49
50
  )
50
51
 
52
+ if TYPE_CHECKING:
53
+ from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
54
+ from typing import Any, ClassVar, Literal, TypeVar
55
+ from typing_extensions import Self, TypeAlias, TypeIs
56
+ from ..typing import DataIndex, ParamSource
57
+ from .actions import Action
58
+ from .command_files import InputFile
59
+ from .element import (
60
+ Element,
61
+ ElementIteration,
62
+ ElementFilter,
63
+ ElementParameter,
64
+ _ElementPrefixedParameter as EPP,
65
+ )
66
+ from .parameters import (
67
+ InputValue,
68
+ InputSource,
69
+ ValueSequence,
70
+ MultiPathSequence,
71
+ SchemaInput,
72
+ SchemaOutput,
73
+ ParameterPath,
74
+ )
75
+ from .rule import Rule
76
+ from .task_schema import TaskObjective, TaskSchema, MetaTaskSchema
77
+ from .types import (
78
+ MultiplicityDescriptor,
79
+ RelevantData,
80
+ RelevantPath,
81
+ Resources,
82
+ RepeatsDescriptor,
83
+ )
84
+ from .workflow import Workflow, WorkflowTemplate
85
+
86
+ StrSeq = TypeVar("StrSeq", bound=Sequence[str])
51
87
 
52
- INPUT_SOURCE_TYPES = ["local", "default", "task", "import"]
88
+
89
+ INPUT_SOURCE_TYPES = ("local", "default", "task", "import")
53
90
 
54
91
 
55
92
  @dataclass
@@ -80,7 +117,7 @@ class InputStatus:
80
117
  is_provided: bool
81
118
 
82
119
  @property
83
- def is_extra(self):
120
+ def is_extra(self) -> bool:
84
121
  """True if the input is provided but not required."""
85
122
  return self.is_provided and not self.is_required
86
123
 
@@ -96,6 +133,8 @@ class ElementSet(JSONLike):
96
133
  Input files to the set of elements.
97
134
  sequences: list[~hpcflow.app.ValueSequence]
98
135
  Input value sequences to parameterise over.
136
+ multi_path_sequences: list[~hpcflow.app.MultiPathSequence]
137
+ Multi-path sequences to parameterise over.
99
138
  resources: ~hpcflow.app.ResourceList
100
139
  Resources to use for the set of elements.
101
140
  repeats: list[dict]
@@ -118,13 +157,14 @@ class ElementSet(JSONLike):
118
157
  If True, if more than one parameter is sourced from the same task, then allow
119
158
  these sources to come from distinct element sub-sets. If False (default),
120
159
  only the intersection of element sub-sets for all parameters are included.
121
- merge_envs: bool
122
- If True, merge ``environments`` into ``resources`` using the "any" scope. If
123
- False, ``environments`` are ignored. This is required on first initialisation,
160
+ is_creation: bool
161
+ If True, merge ``environments`` into ``resources`` using the "any" scope, and
162
+ merge sequences belonging to multi-path sequences into the value-sequences list.
163
+ If False, ``environments`` are ignored. This is required on first initialisation,
124
164
  but not on subsequent re-initialisation from a persistent workflow.
125
165
  """
126
166
 
127
- _child_objects = (
167
+ _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
128
168
  ChildObjectSpec(
129
169
  name="inputs",
130
170
  class_name="InputValue",
@@ -152,6 +192,12 @@ class ElementSet(JSONLike):
152
192
  is_multiple=True,
153
193
  parent_ref="_element_set",
154
194
  ),
195
+ ChildObjectSpec(
196
+ name="multi_path_sequences",
197
+ class_name="MultiPathSequence",
198
+ is_multiple=True,
199
+ parent_ref="_element_set",
200
+ ),
155
201
  ChildObjectSpec(
156
202
  name="input_sources",
157
203
  class_name="InputSource",
@@ -168,32 +214,35 @@ class ElementSet(JSONLike):
168
214
 
169
215
  def __init__(
170
216
  self,
171
- inputs: Optional[List[app.InputValue]] = None,
172
- input_files: Optional[List[app.InputFile]] = None,
173
- sequences: Optional[List[app.ValueSequence]] = None,
174
- resources: Optional[Dict[str, Dict]] = None,
175
- repeats: Optional[List[Dict]] = None,
176
- groups: Optional[List[app.ElementGroup]] = None,
177
- input_sources: Optional[Dict[str, app.InputSource]] = None,
178
- nesting_order: Optional[List] = None,
179
- env_preset: Optional[str] = None,
180
- environments: Optional[Dict[str, Dict[str, Any]]] = None,
181
- sourceable_elem_iters: Optional[List[int]] = None,
182
- allow_non_coincident_task_sources: Optional[bool] = False,
183
- merge_envs: Optional[bool] = True,
217
+ inputs: list[InputValue] | dict[str, Any] | None = None,
218
+ input_files: list[InputFile] | None = None,
219
+ sequences: list[ValueSequence] | None = None,
220
+ multi_path_sequences: list[MultiPathSequence] | None = None,
221
+ resources: Resources = None,
222
+ repeats: list[RepeatsDescriptor] | int | None = None,
223
+ groups: list[ElementGroup] | None = None,
224
+ input_sources: dict[str, list[InputSource]] | None = None,
225
+ nesting_order: dict[str, float] | None = None,
226
+ env_preset: str | None = None,
227
+ environments: Mapping[str, Mapping[str, Any]] | None = None,
228
+ sourceable_elem_iters: list[int] | None = None,
229
+ allow_non_coincident_task_sources: bool = False,
230
+ is_creation: bool = True,
184
231
  ):
185
232
  #: Inputs to the set of elements.
186
- self.inputs = inputs or []
233
+ self.inputs = self.__decode_inputs(inputs or [])
187
234
  #: Input files to the set of elements.
188
235
  self.input_files = input_files or []
189
236
  #: Description of how to repeat the set of elements.
190
- self.repeats = repeats or []
237
+ self.repeats = self.__decode_repeats(repeats or [])
191
238
  #: Groupings in the set of elements.
192
239
  self.groups = groups or []
193
240
  #: Resources to use for the set of elements.
194
- self.resources = self.app.ResourceList.normalise(resources)
241
+ self.resources = self._app.ResourceList.normalise(resources)
195
242
  #: Input value sequences to parameterise over.
196
243
  self.sequences = sequences or []
244
+ #: Input value multi-path sequences to parameterise over.
245
+ self.multi_path_sequences = multi_path_sequences or []
197
246
  #: Input source descriptors.
198
247
  self.input_sources = input_sources or {}
199
248
  #: How to handle nesting of iterations.
@@ -208,30 +257,43 @@ class ElementSet(JSONLike):
208
257
  self.sourceable_elem_iters = sourceable_elem_iters
209
258
  #: Whether to allow sources to come from distinct element sub-sets.
210
259
  self.allow_non_coincident_task_sources = allow_non_coincident_task_sources
211
- #: Whether to merge ``environments`` into ``resources`` using the "any" scope
212
- #: on first initialisation.
213
- self.merge_envs = merge_envs
260
+ #: Whether this initialisation is the first for this data (i.e. not a
261
+ #: reconstruction from persistent workflow data), in which case, we merge
262
+ #: ``environments`` into ``resources`` using the "any" scope, and merge any multi-
263
+ #: path sequences into the sequences list.
264
+ self.is_creation = is_creation
265
+ self.original_input_sources: dict[str, list[InputSource]] | None = None
266
+ self.original_nesting_order: dict[str, float] | None = None
214
267
 
215
268
  self._validate()
216
269
  self._set_parent_refs()
217
270
 
218
- self._task_template = None # assigned by parent Task
219
- self._defined_input_types = None # assigned on _task_template assignment
220
- self._element_local_idx_range = None # assigned by WorkflowTask._add_element_set
271
+ # assigned by parent Task
272
+ self._task_template: Task | None = None
273
+ # assigned on _task_template assignment
274
+ self._defined_input_types: set[str] | None = None
275
+ # assigned by WorkflowTask._add_element_set
276
+ self._element_local_idx_range: list[int] | None = None
277
+
278
+ if self.is_creation:
279
+
280
+ # merge `environments` into element set resources (this mutates `resources`, and
281
+ # should only happen on creation of the element set, not re-initialisation from a
282
+ # persistent workflow):
283
+ if self.environments:
284
+ self.resources.merge_one(
285
+ self._app.ResourceSpec(scope="any", environments=self.environments)
286
+ )
287
+ # note: `env_preset` is merged into resources by the Task init.
221
288
 
222
- # merge `environments` into element set resources (this mutates `resources`, and
223
- # should only happen on creation of the element set, not re-initialisation from a
224
- # persistent workflow):
225
- if self.environments and self.merge_envs:
226
- envs_res = self.app.ResourceList(
227
- [self.app.ResourceSpec(scope="any", environments=self.environments)]
228
- )
229
- self.resources.merge_other(envs_res)
230
- self.merge_envs = False
289
+ # merge sequences belonging to multi-path sequences into the value-sequences list:
290
+ if self.multi_path_sequences:
291
+ for mp_seq in self.multi_path_sequences:
292
+ mp_seq._move_to_sequence_list(self.sequences)
231
293
 
232
- # note: `env_preset` is merged into resources by the Task init.
294
+ self.is_creation = False
233
295
 
234
- def __deepcopy__(self, memo):
296
+ def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
235
297
  dct = self.to_dict()
236
298
  orig_inp = dct.pop("original_input_sources", None)
237
299
  orig_nest = dct.pop("original_nesting_order", None)
@@ -240,19 +302,17 @@ class ElementSet(JSONLike):
240
302
  obj._task_template = self._task_template
241
303
  obj._defined_input_types = self._defined_input_types
242
304
  obj.original_input_sources = copy.deepcopy(orig_inp)
243
- obj.original_nesting_order = copy.deepcopy(orig_nest)
244
- obj._element_local_idx_range = copy.deepcopy(elem_local_idx_range)
305
+ obj.original_nesting_order = copy.copy(orig_nest)
306
+ obj._element_local_idx_range = copy.copy(elem_local_idx_range)
245
307
  return obj
246
308
 
247
- def __eq__(self, other):
309
+ def __eq__(self, other: Any) -> bool:
248
310
  if not isinstance(other, self.__class__):
249
311
  return False
250
- if self.to_dict() == other.to_dict():
251
- return True
252
- return False
312
+ return self.to_dict() == other.to_dict()
253
313
 
254
314
  @classmethod
255
- def _json_like_constructor(cls, json_like):
315
+ def _json_like_constructor(cls, json_like) -> Self:
256
316
  """Invoked by `JSONLike.from_json_like` instead of `__init__`."""
257
317
  orig_inp = json_like.pop("original_input_sources", None)
258
318
  orig_nest = json_like.pop("original_nesting_order", None)
@@ -263,7 +323,7 @@ class ElementSet(JSONLike):
263
323
  obj._element_local_idx_range = elem_local_idx_range
264
324
  return obj
265
325
 
266
- def prepare_persistent_copy(self):
326
+ def prepare_persistent_copy(self) -> Self:
267
327
  """Return a copy of self, which will then be made persistent, and save copies of
268
328
  attributes that may be changed during integration with the workflow."""
269
329
  obj = copy.deepcopy(self)
@@ -271,86 +331,98 @@ class ElementSet(JSONLike):
271
331
  obj.original_input_sources = self.input_sources
272
332
  return obj
273
333
 
274
- def to_dict(self):
275
- dct = super().to_dict()
334
+ @override
335
+ def _postprocess_to_dict(self, d: dict[str, Any]) -> dict[str, Any]:
336
+ dct = super()._postprocess_to_dict(d)
276
337
  del dct["_defined_input_types"]
277
338
  del dct["_task_template"]
278
339
  return dct
279
340
 
280
341
  @property
281
- def task_template(self):
342
+ def task_template(self) -> Task:
282
343
  """
283
344
  The abstract task this was derived from.
284
345
  """
346
+ assert self._task_template is not None
285
347
  return self._task_template
286
348
 
287
349
  @task_template.setter
288
- def task_template(self, value):
350
+ def task_template(self, value: Task) -> None:
289
351
  self._task_template = value
290
- self._validate_against_template()
352
+ self.__validate_against_template()
291
353
 
292
354
  @property
293
- def input_types(self):
355
+ def input_types(self) -> list[str]:
294
356
  """
295
357
  The input types of the inputs to this element set.
296
358
  """
297
- return [i.labelled_type for i in self.inputs]
359
+ return [in_.labelled_type for in_ in self.inputs]
298
360
 
299
361
  @property
300
- def element_local_idx_range(self):
362
+ def element_local_idx_range(self) -> tuple[int, ...]:
301
363
  """Indices of elements belonging to this element set."""
302
- return tuple(self._element_local_idx_range)
364
+ return tuple(self._element_local_idx_range or ())
303
365
 
304
- def _validate(self):
305
- # support inputs passed as a dict:
306
- _inputs = []
307
- try:
308
- for k, v in self.inputs.items():
366
+ @classmethod
367
+ def __decode_inputs(
368
+ cls, inputs: list[InputValue] | dict[str, Any]
369
+ ) -> list[InputValue]:
370
+ """support inputs passed as a dict"""
371
+ if isinstance(inputs, dict):
372
+ _inputs: list[InputValue] = []
373
+ for k, v in inputs.items():
309
374
  param, label = split_param_label(k)
310
- _inputs.append(self.app.InputValue(parameter=param, label=label, value=v))
311
- except AttributeError:
312
- pass
375
+ assert param is not None
376
+ _inputs.append(cls._app.InputValue(parameter=param, label=label, value=v))
377
+ return _inputs
313
378
  else:
314
- self.inputs = _inputs
379
+ return inputs
315
380
 
381
+ @classmethod
382
+ def __decode_repeats(
383
+ cls, repeats: list[RepeatsDescriptor] | int
384
+ ) -> list[RepeatsDescriptor]:
316
385
  # support repeats as an int:
317
- if isinstance(self.repeats, int):
318
- self.repeats = [
386
+ if isinstance(repeats, int):
387
+ return [
319
388
  {
320
389
  "name": "",
321
- "number": self.repeats,
322
- "nesting_order": 0,
390
+ "number": repeats,
391
+ "nesting_order": 0.0,
323
392
  }
324
393
  ]
394
+ else:
395
+ return repeats
396
+
397
+ _ALLOWED_NESTING_PATHS: ClassVar[frozenset[str]] = frozenset(
398
+ {"inputs", "resources", "repeats"}
399
+ )
325
400
 
401
+ def _validate(self) -> None:
326
402
  # check `nesting_order` paths:
327
- allowed_nesting_paths = ("inputs", "resources", "repeats")
328
403
  for k in self.nesting_order:
329
- if k.split(".")[0] not in allowed_nesting_paths:
330
- raise MalformedNestingOrderPath(
331
- f"Element set: nesting order path {k!r} not understood. Each key in "
332
- f"`nesting_order` must be start with one of "
333
- f"{allowed_nesting_paths!r}."
334
- )
404
+ if k.split(".")[0] not in self._ALLOWED_NESTING_PATHS:
405
+ raise MalformedNestingOrderPath(k, self._ALLOWED_NESTING_PATHS)
335
406
 
336
- inp_paths = [i.normalised_inputs_path for i in self.inputs]
337
- dup_inp_paths = get_duplicate_items(inp_paths)
338
- if dup_inp_paths:
407
+ inp_paths = [in_.normalised_inputs_path for in_ in self.inputs]
408
+ if dup_paths := get_duplicate_items(inp_paths):
339
409
  raise TaskTemplateMultipleInputValues(
340
410
  f"The following inputs parameters are associated with multiple input value "
341
- f"definitions: {dup_inp_paths!r}."
411
+ f"definitions: {dup_paths!r}."
342
412
  )
343
413
 
344
- inp_seq_paths = [i.normalised_inputs_path for i in self.sequences if i.input_type]
345
- dup_inp_seq_paths = get_duplicate_items(inp_seq_paths)
346
- if dup_inp_seq_paths:
414
+ inp_seq_paths = [
415
+ cast("str", seq.normalised_inputs_path)
416
+ for seq in self.sequences
417
+ if seq.input_type
418
+ ]
419
+ if dup_paths := get_duplicate_items(inp_seq_paths):
347
420
  raise TaskTemplateMultipleInputValues(
348
421
  f"The following input parameters are associated with multiple sequence "
349
- f"value definitions: {dup_inp_seq_paths!r}."
422
+ f"value definitions: {dup_paths!r}."
350
423
  )
351
424
 
352
- inp_and_seq = set(inp_paths).intersection(inp_seq_paths)
353
- if inp_and_seq:
425
+ if inp_and_seq := set(inp_paths).intersection(inp_seq_paths):
354
426
  raise TaskTemplateMultipleInputValues(
355
427
  f"The following input parameters are specified in both the `inputs` and "
356
428
  f"`sequences` lists: {list(inp_and_seq)!r}, but must be specified in at "
@@ -368,64 +440,50 @@ class ElementSet(JSONLike):
368
440
  if self.env_preset and self.environments:
369
441
  raise ValueError("Specify at most one of `env_preset` and `environments`.")
370
442
 
371
- def _validate_against_template(self):
372
- unexpected_types = (
373
- set(self.input_types) - self.task_template.all_schema_input_types
374
- )
375
- if unexpected_types:
376
- raise TaskTemplateUnexpectedInput(
377
- f"The following input parameters are unexpected: {list(unexpected_types)!r}"
378
- )
443
+ def __validate_against_template(self) -> None:
444
+ expected_types = self.task_template.all_schema_input_types
445
+ if unexpected_types := set(self.input_types) - expected_types:
446
+ raise TaskTemplateUnexpectedInput(unexpected_types)
379
447
 
380
- seq_inp_types = []
381
- for seq_i in self.sequences:
382
- inp_type = seq_i.labelled_type
383
- if inp_type:
384
- bad_inp = {inp_type} - self.task_template.all_schema_input_types
385
- allowed_str = ", ".join(
386
- f'"{i}"' for i in self.task_template.all_schema_input_types
387
- )
388
- if bad_inp:
448
+ defined_inp_types = set(self.input_types)
449
+ for seq in self.sequences:
450
+ if inp_type := seq.labelled_type:
451
+ if inp_type not in expected_types:
389
452
  raise TaskTemplateUnexpectedSequenceInput(
390
- f"The input type {inp_type!r} specified in the following sequence"
391
- f" path is unexpected: {seq_i.path!r}. Available input types are: "
392
- f"{allowed_str}."
453
+ inp_type, expected_types, seq
393
454
  )
394
- seq_inp_types.append(inp_type)
395
- if seq_i.path not in self.nesting_order:
396
- self.nesting_order.update({seq_i.path: seq_i.nesting_order})
455
+ defined_inp_types.add(inp_type)
456
+ if seq.path not in self.nesting_order and seq.nesting_order is not None:
457
+ self.nesting_order[seq.path] = seq.nesting_order
397
458
 
398
459
  for rep_spec in self.repeats:
399
- reps_path_i = f'repeats.{rep_spec["name"]}'
400
- if reps_path_i not in self.nesting_order:
460
+ if (reps_path_i := f'repeats.{rep_spec["name"]}') not in self.nesting_order:
401
461
  self.nesting_order[reps_path_i] = rep_spec["nesting_order"]
402
462
 
403
463
  for k, v in self.nesting_order.items():
404
464
  if v < 0:
405
- raise TaskTemplateInvalidNesting(
406
- f"`nesting_order` must be >=0 for all keys, but for key {k!r}, value "
407
- f"of {v!r} was specified."
408
- )
465
+ raise TaskTemplateInvalidNesting(k, v)
409
466
 
410
- self._defined_input_types = set(self.input_types + seq_inp_types)
467
+ self._defined_input_types = defined_inp_types
411
468
 
412
469
  @classmethod
413
470
  def ensure_element_sets(
414
471
  cls,
415
- inputs=None,
416
- input_files=None,
417
- sequences=None,
418
- resources=None,
419
- repeats=None,
420
- groups=None,
421
- input_sources=None,
422
- nesting_order=None,
423
- env_preset=None,
424
- environments=None,
425
- allow_non_coincident_task_sources=None,
426
- element_sets=None,
427
- sourceable_elem_iters=None,
428
- ):
472
+ inputs: list[InputValue] | dict[str, Any] | None = None,
473
+ input_files: list[InputFile] | None = None,
474
+ sequences: list[ValueSequence] | None = None,
475
+ multi_path_sequences: list[MultiPathSequence] | None = None,
476
+ resources: Resources = None,
477
+ repeats: list[RepeatsDescriptor] | int | None = None,
478
+ groups: list[ElementGroup] | None = None,
479
+ input_sources: dict[str, list[InputSource]] | None = None,
480
+ nesting_order: dict[str, float] | None = None,
481
+ env_preset: str | None = None,
482
+ environments: Mapping[str, Mapping[str, Any]] | None = None,
483
+ allow_non_coincident_task_sources: bool = False,
484
+ element_sets: list[Self] | None = None,
485
+ sourceable_elem_iters: list[int] | None = None,
486
+ ) -> list[Self]:
429
487
  """
430
488
  Make an instance after validating some argument combinations.
431
489
  """
@@ -433,6 +491,7 @@ class ElementSet(JSONLike):
433
491
  inputs,
434
492
  input_files,
435
493
  sequences,
494
+ multi_path_sequences,
436
495
  resources,
437
496
  repeats,
438
497
  groups,
@@ -441,21 +500,19 @@ class ElementSet(JSONLike):
441
500
  env_preset,
442
501
  environments,
443
502
  )
444
- args_not_none = [i is not None for i in args]
445
503
 
446
- if any(args_not_none):
504
+ if any(arg is not None for arg in args):
447
505
  if element_sets is not None:
448
506
  raise ValueError(
449
507
  "If providing an `element_set`, no other arguments are allowed."
450
508
  )
451
- else:
452
- element_sets = [
453
- cls(
454
- *args,
455
- sourceable_elem_iters=sourceable_elem_iters,
456
- allow_non_coincident_task_sources=allow_non_coincident_task_sources,
457
- )
458
- ]
509
+ element_sets = [
510
+ cls(
511
+ *args,
512
+ sourceable_elem_iters=sourceable_elem_iters,
513
+ allow_non_coincident_task_sources=allow_non_coincident_task_sources,
514
+ )
515
+ ]
459
516
  else:
460
517
  if element_sets is None:
461
518
  element_sets = [
@@ -469,127 +526,140 @@ class ElementSet(JSONLike):
469
526
  return element_sets
470
527
 
471
528
  @property
472
- def defined_input_types(self):
529
+ def defined_input_types(self) -> set[str]:
473
530
  """
474
531
  The input types to this element set.
475
532
  """
533
+ assert self._defined_input_types is not None
476
534
  return self._defined_input_types
477
535
 
478
536
  @property
479
- def undefined_input_types(self):
537
+ def undefined_input_types(self) -> set[str]:
480
538
  """
481
539
  The input types to the abstract task that aren't related to this element set.
482
540
  """
483
541
  return self.task_template.all_schema_input_types - self.defined_input_types
484
542
 
485
- def get_sequence_from_path(self, sequence_path):
543
+ def get_sequence_from_path(self, sequence_path: str) -> ValueSequence | None:
486
544
  """
487
545
  Get the value sequence for the given path, if it exists.
488
546
  """
489
- for seq in self.sequences:
490
- if seq.path == sequence_path:
491
- return seq
547
+ return next((seq for seq in self.sequences if seq.path == sequence_path), None)
492
548
 
493
- def get_defined_parameter_types(self):
549
+ def get_defined_parameter_types(self) -> list[str]:
494
550
  """
495
551
  Get the parameter types of this element set.
496
552
  """
497
- out = []
553
+ out: list[str] = []
498
554
  for inp in self.inputs:
499
555
  if not inp.is_sub_value:
500
556
  out.append(inp.normalised_inputs_path)
501
557
  for seq in self.sequences:
502
558
  if seq.parameter and not seq.is_sub_value: # ignore resource sequences
559
+ assert seq.normalised_inputs_path is not None
503
560
  out.append(seq.normalised_inputs_path)
504
561
  return out
505
562
 
506
- def get_defined_sub_parameter_types(self):
563
+ def get_defined_sub_parameter_types(self) -> list[str]:
507
564
  """
508
565
  Get the sub-parameter types of this element set.
509
566
  """
510
- out = []
567
+ out: list[str] = []
511
568
  for inp in self.inputs:
512
569
  if inp.is_sub_value:
513
570
  out.append(inp.normalised_inputs_path)
514
571
  for seq in self.sequences:
515
572
  if seq.parameter and seq.is_sub_value: # ignore resource sequences
573
+ assert seq.normalised_inputs_path is not None
516
574
  out.append(seq.normalised_inputs_path)
517
575
  return out
518
576
 
519
- def get_locally_defined_inputs(self):
577
+ def get_locally_defined_inputs(self) -> list[str]:
520
578
  """
521
579
  Get the input types that this element set defines.
522
580
  """
523
581
  return self.get_defined_parameter_types() + self.get_defined_sub_parameter_types()
524
582
 
525
583
  @property
526
- def index(self):
584
+ def index(self) -> int | None:
527
585
  """
528
586
  The index of this element set in its' template task's collection of sets.
529
587
  """
530
- for idx, element_set in enumerate(self.task_template.element_sets):
531
- if element_set is self:
532
- return idx
588
+ return next(
589
+ (
590
+ idx
591
+ for idx, element_set in enumerate(self.task_template.element_sets)
592
+ if element_set is self
593
+ ),
594
+ None,
595
+ )
533
596
 
534
597
  @property
535
- def task(self):
598
+ def task(self) -> WorkflowTask:
536
599
  """
537
600
  The concrete task corresponding to this element set.
538
601
  """
539
- return self.task_template.workflow_template.workflow.tasks[
540
- self.task_template.index
541
- ]
602
+ t = self.task_template.workflow_template
603
+ assert t
604
+ w = t.workflow
605
+ assert w
606
+ i = self.task_template.index
607
+ assert i is not None
608
+ return w.tasks[i]
542
609
 
543
610
  @property
544
- def elements(self):
611
+ def elements(self) -> list[Element]:
545
612
  """
546
613
  The elements in this element set.
547
614
  """
548
615
  return self.task.elements[slice(*self.element_local_idx_range)]
549
616
 
550
617
  @property
551
- def element_iterations(self):
618
+ def element_iterations(self) -> list[ElementIteration]:
552
619
  """
553
620
  The iterations in this element set.
554
621
  """
555
- return [j for i in self.elements for j in i.iterations]
622
+ return list(chain.from_iterable(elem.iterations for elem in self.elements))
556
623
 
557
624
  @property
558
- def elem_iter_IDs(self):
625
+ def elem_iter_IDs(self) -> list[int]:
559
626
  """
560
627
  The IDs of the iterations in this element set.
561
628
  """
562
- return [i.id_ for i in self.element_iterations]
629
+ return [it.id_ for it in self.element_iterations]
630
+
631
+ @overload
632
+ def get_task_dependencies(self, as_objects: Literal[False] = False) -> set[int]:
633
+ ...
634
+
635
+ @overload
636
+ def get_task_dependencies(self, as_objects: Literal[True]) -> list[WorkflowTask]:
637
+ ...
563
638
 
564
- def get_task_dependencies(self, as_objects=False):
639
+ def get_task_dependencies(
640
+ self, as_objects: bool = False
641
+ ) -> list[WorkflowTask] | set[int]:
565
642
  """Get upstream tasks that this element set depends on."""
566
- deps = []
643
+ deps: set[int] = set()
567
644
  for element in self.elements:
568
- for dep_i in element.get_task_dependencies(as_objects=False):
569
- if dep_i not in deps:
570
- deps.append(dep_i)
571
- deps = sorted(deps)
645
+ deps.update(element.get_task_dependencies())
572
646
  if as_objects:
573
- deps = [self.task.workflow.tasks.get(insert_ID=i) for i in deps]
574
-
647
+ return [self.task.workflow.tasks.get(insert_ID=id_) for id_ in sorted(deps)]
575
648
  return deps
576
649
 
577
650
  def is_input_type_provided(self, labelled_path: str) -> bool:
578
651
  """Check if an input is provided locally as an InputValue or a ValueSequence."""
579
-
580
- for inp in self.inputs:
581
- if labelled_path == inp.normalised_inputs_path:
582
- return True
583
-
584
- for seq in self.sequences:
585
- if seq.parameter:
586
- # i.e. not a resource:
587
- if labelled_path == seq.normalised_inputs_path:
588
- return True
589
-
590
- return False
652
+ return any(
653
+ labelled_path == inp.normalised_inputs_path for inp in self.inputs
654
+ ) or any(
655
+ seq.parameter
656
+ # i.e. not a resource:
657
+ and labelled_path == seq.normalised_inputs_path
658
+ for seq in self.sequences
659
+ )
591
660
 
592
661
 
662
+ @hydrate
593
663
  class OutputLabel(JSONLike):
594
664
  """
595
665
  Schema input labels that should be applied to a subset of task outputs.
@@ -604,7 +674,7 @@ class OutputLabel(JSONLike):
604
674
  Optional filtering rule
605
675
  """
606
676
 
607
- _child_objects = (
677
+ _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
608
678
  ChildObjectSpec(
609
679
  name="where",
610
680
  class_name="ElementFilter",
@@ -615,7 +685,7 @@ class OutputLabel(JSONLike):
615
685
  self,
616
686
  parameter: str,
617
687
  label: str,
618
- where: Optional[List[app.ElementFilter]] = None,
688
+ where: Rule | None = None,
619
689
  ) -> None:
620
690
  #: Name of a parameter.
621
691
  self.parameter = parameter
@@ -625,6 +695,7 @@ class OutputLabel(JSONLike):
625
695
  self.where = where
626
696
 
627
697
 
698
+ @hydrate
628
699
  class Task(JSONLike):
629
700
  """
630
701
  Parametrisation of an isolated task for which a subset of input values are given
@@ -645,6 +716,9 @@ class Task(JSONLike):
645
716
  A list of `InputValue` objects.
646
717
  input_files: list[~hpcflow.app.InputFile]
647
718
  sequences: list[~hpcflow.app.ValueSequence]
719
+ Input value sequences to parameterise over.
720
+ multi_path_sequences: list[~hpcflow.app.MultiPathSequence]
721
+ Multi-path sequences to parameterise over.
648
722
  input_sources: dict[str, ~hpcflow.app.InputSource]
649
723
  nesting_order: list
650
724
  env_preset: str
@@ -663,7 +737,7 @@ class Task(JSONLike):
663
737
  re-initialisation from a persistent workflow.
664
738
  """
665
739
 
666
- _child_objects = (
740
+ _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
667
741
  ChildObjectSpec(
668
742
  name="schema",
669
743
  class_name="TaskSchema",
@@ -685,24 +759,29 @@ class Task(JSONLike):
685
759
  ),
686
760
  )
687
761
 
762
+ @classmethod
763
+ def __is_TaskSchema(cls, value) -> TypeIs[TaskSchema]:
764
+ return isinstance(value, cls._app.TaskSchema)
765
+
688
766
  def __init__(
689
767
  self,
690
- schema: Union[app.TaskSchema, str, List[app.TaskSchema], List[str]],
691
- repeats: Optional[List[Dict]] = None,
692
- groups: Optional[List[app.ElementGroup]] = None,
693
- resources: Optional[Dict[str, Dict]] = None,
694
- inputs: Optional[List[app.InputValue]] = None,
695
- input_files: Optional[List[app.InputFile]] = None,
696
- sequences: Optional[List[app.ValueSequence]] = None,
697
- input_sources: Optional[Dict[str, app.InputSource]] = None,
698
- nesting_order: Optional[List] = None,
699
- env_preset: Optional[str] = None,
700
- environments: Optional[Dict[str, Dict[str, Any]]] = None,
701
- allow_non_coincident_task_sources: Optional[bool] = False,
702
- element_sets: Optional[List[app.ElementSet]] = None,
703
- output_labels: Optional[List[app.OutputLabel]] = None,
704
- sourceable_elem_iters: Optional[List[int]] = None,
705
- merge_envs: Optional[bool] = True,
768
+ schema: TaskSchema | str | list[TaskSchema] | list[str],
769
+ repeats: list[RepeatsDescriptor] | int | None = None,
770
+ groups: list[ElementGroup] | None = None,
771
+ resources: Resources = None,
772
+ inputs: list[InputValue] | dict[str, Any] | None = None,
773
+ input_files: list[InputFile] | None = None,
774
+ sequences: list[ValueSequence] | None = None,
775
+ multi_path_sequences: list[MultiPathSequence] | None = None,
776
+ input_sources: dict[str, list[InputSource]] | None = None,
777
+ nesting_order: dict[str, float] | None = None,
778
+ env_preset: str | None = None,
779
+ environments: Mapping[str, Mapping[str, Any]] | None = None,
780
+ allow_non_coincident_task_sources: bool = False,
781
+ element_sets: list[ElementSet] | None = None,
782
+ output_labels: list[OutputLabel] | None = None,
783
+ sourceable_elem_iters: list[int] | None = None,
784
+ merge_envs: bool = True,
706
785
  ):
707
786
  # TODO: allow init via specifying objective and/or method and/or implementation
708
787
  # (lists of) strs e.g.: Task(
@@ -718,28 +797,28 @@ class Task(JSONLike):
718
797
  # 'simulate_VE_loading_taylor_damask'
719
798
  # ])
720
799
 
721
- if not isinstance(schema, list):
722
- schema = [schema]
723
-
724
- _schemas = []
725
- for i in schema:
726
- if isinstance(i, str):
800
+ _schemas: list[TaskSchema] = []
801
+ for item in schema if isinstance(schema, list) else [schema]:
802
+ if isinstance(item, str):
727
803
  try:
728
- i = self.app.TaskSchema.get_by_key(
729
- i
804
+ _schemas.append(
805
+ self._app.TaskSchema.get_by_key(item)
730
806
  ) # TODO: document that we need to use the actual app instance here?
807
+ continue
731
808
  except KeyError:
732
- raise KeyError(f"TaskSchema {i!r} not found.")
733
- elif not isinstance(i, TaskSchema):
734
- raise TypeError(f"Not a TaskSchema object: {i!r}")
735
- _schemas.append(i)
809
+ raise KeyError(f"TaskSchema {item!r} not found.")
810
+ elif self.__is_TaskSchema(item):
811
+ _schemas.append(item)
812
+ else:
813
+ raise TypeError(f"Not a TaskSchema object: {item!r}")
736
814
 
737
815
  self._schemas = _schemas
738
816
 
739
- self._element_sets = self.app.ElementSet.ensure_element_sets(
817
+ self._element_sets = self._app.ElementSet.ensure_element_sets(
740
818
  inputs=inputs,
741
819
  input_files=input_files,
742
820
  sequences=sequences,
821
+ multi_path_sequences=multi_path_sequences,
743
822
  resources=resources,
744
823
  repeats=repeats,
745
824
  groups=groups,
@@ -755,26 +834,31 @@ class Task(JSONLike):
755
834
  #: Whether to merge ``environments`` into ``resources`` using the "any" scope
756
835
  #: on first initialisation.
757
836
  self.merge_envs = merge_envs
837
+ self.__groups: AppDataList[ElementGroup] = AppDataList(
838
+ groups or [], access_attribute="name"
839
+ )
758
840
 
759
841
  # appended to when new element sets are added and reset on dump to disk:
760
- self._pending_element_sets = []
842
+ self._pending_element_sets: list[ElementSet] = []
761
843
 
762
844
  self._validate()
763
- self._name = self._get_name()
845
+ self._name = self.__get_name()
764
846
 
765
847
  #: The template workflow that this task is within.
766
- self.workflow_template = None # assigned by parent WorkflowTemplate
767
- self._insert_ID = None
768
- self._dir_name = None
848
+ self.workflow_template: WorkflowTemplate | None = (
849
+ None # assigned by parent WorkflowTemplate
850
+ )
851
+ self._insert_ID: int | None = None
852
+ self._dir_name: str | None = None
769
853
 
770
854
  if self.merge_envs:
771
- self._merge_envs_into_resources()
855
+ self.__merge_envs_into_resources()
772
856
 
773
857
  # TODO: consider adding a new element_set; will need to merge new environments?
774
858
 
775
859
  self._set_parent_refs({"schema": "schemas"})
776
860
 
777
- def _merge_envs_into_resources(self):
861
+ def __merge_envs_into_resources(self) -> None:
778
862
  # for each element set, merge `env_preset` into `resources` (this mutates
779
863
  # `resources`, and should only happen on creation of the task, not
780
864
  # re-initialisation from a persistent workflow):
@@ -782,81 +866,71 @@ class Task(JSONLike):
782
866
 
783
867
  # TODO: required so we don't raise below; can be removed once we consider multiple
784
868
  # schemas:
785
- has_presets = False
786
869
  for es in self.element_sets:
787
- if es.env_preset:
788
- has_presets = True
870
+ if es.env_preset or any(seq.path == "env_preset" for seq in es.sequences):
789
871
  break
790
- for seq in es.sequences:
791
- if seq.path == "env_preset":
792
- has_presets = True
793
- break
794
- if has_presets:
795
- break
796
-
797
- if not has_presets:
872
+ else:
873
+ # No presets
798
874
  return
875
+
799
876
  try:
800
877
  env_presets = self.schema.environment_presets
801
- except ValueError:
878
+ except ValueError as e:
802
879
  # TODO: consider multiple schemas
803
880
  raise NotImplementedError(
804
881
  "Cannot merge environment presets into a task with multiple schemas."
805
- )
882
+ ) from e
806
883
 
807
884
  for es in self.element_sets:
808
885
  if es.env_preset:
809
886
  # retrieve env specifiers from presets defined in the schema:
810
887
  try:
811
- env_specs = env_presets[es.env_preset]
888
+ env_specs = env_presets[es.env_preset] # type: ignore[index]
812
889
  except (TypeError, KeyError):
813
- raise UnknownEnvironmentPresetError(
814
- f"There is no environment preset named {es.env_preset!r} "
815
- f"defined in the task schema {self.schema.name}."
816
- )
817
- envs_res = self.app.ResourceList(
818
- [self.app.ResourceSpec(scope="any", environments=env_specs)]
890
+ raise UnknownEnvironmentPresetError(es.env_preset, self.schema.name)
891
+ es.resources.merge_one(
892
+ self._app.ResourceSpec(scope="any", environments=env_specs)
819
893
  )
820
- es.resources.merge_other(envs_res)
821
894
 
822
895
  for seq in es.sequences:
823
896
  if seq.path == "env_preset":
824
897
  # change to a resources path:
825
- seq.path = f"resources.any.environments"
898
+ seq.path = "resources.any.environments"
826
899
  _values = []
827
- for i in seq.values:
900
+ for val in seq.values or ():
828
901
  try:
829
- _values.append(env_presets[i])
830
- except (TypeError, KeyError):
902
+ _values.append(env_presets[val]) # type: ignore[index]
903
+ except (TypeError, KeyError) as e:
831
904
  raise UnknownEnvironmentPresetError(
832
- f"There is no environment preset named {i!r} defined "
833
- f"in the task schema {self.schema.name}."
834
- )
905
+ val, self.schema.name
906
+ ) from e
835
907
  seq._values = _values
836
908
 
837
- def _reset_pending_element_sets(self):
909
+ def _reset_pending_element_sets(self) -> None:
838
910
  self._pending_element_sets = []
839
911
 
840
- def _accept_pending_element_sets(self):
912
+ def _accept_pending_element_sets(self) -> None:
841
913
  self._element_sets += self._pending_element_sets
842
914
  self._reset_pending_element_sets()
843
915
 
844
- def __eq__(self, other):
916
+ def __eq__(self, other: Any) -> bool:
845
917
  if not isinstance(other, self.__class__):
846
918
  return False
847
- if self.to_dict() == other.to_dict():
848
- return True
849
- return False
919
+ return self.to_dict() == other.to_dict()
850
920
 
851
- def _add_element_set(self, element_set: app.ElementSet):
921
+ def _add_element_set(self, element_set: ElementSet):
852
922
  """Invoked by WorkflowTask._add_element_set."""
853
923
  self._pending_element_sets.append(element_set)
854
- self.workflow_template.workflow._store.add_element_set(
855
- self.insert_ID, element_set.to_json_like()[0]
924
+ wt = self.workflow_template
925
+ assert wt
926
+ w = wt.workflow
927
+ assert w
928
+ w._store.add_element_set(
929
+ self.insert_ID, cast("Mapping", element_set.to_json_like()[0])
856
930
  )
857
931
 
858
932
  @classmethod
859
- def _json_like_constructor(cls, json_like):
933
+ def _json_like_constructor(cls, json_like: dict) -> Self:
860
934
  """Invoked by `JSONLike.from_json_like` instead of `__init__`."""
861
935
  insert_ID = json_like.pop("insert_ID", None)
862
936
  dir_name = json_like.pop("dir_name", None)
@@ -865,10 +939,10 @@ class Task(JSONLike):
865
939
  obj._dir_name = dir_name
866
940
  return obj
867
941
 
868
- def __repr__(self):
942
+ def __repr__(self) -> str:
869
943
  return f"{self.__class__.__name__}(name={self.name!r})"
870
944
 
871
- def __deepcopy__(self, memo):
945
+ def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
872
946
  kwargs = self.to_dict()
873
947
  _insert_ID = kwargs.pop("insert_ID")
874
948
  _dir_name = kwargs.pop("dir_name")
@@ -881,28 +955,34 @@ class Task(JSONLike):
881
955
  obj._pending_element_sets = self._pending_element_sets
882
956
  return obj
883
957
 
884
- def to_persistent(self, workflow, insert_ID):
958
+ def to_persistent(
959
+ self, workflow: Workflow, insert_ID: int
960
+ ) -> tuple[Self, list[int | list[int]]]:
885
961
  """Return a copy where any schema input defaults are saved to a persistent
886
962
  workflow. Element set data is not made persistent."""
887
963
 
888
964
  obj = copy.deepcopy(self)
889
- new_refs = []
890
- source = {"type": "default_input", "task_insert_ID": insert_ID}
891
- for schema in obj.schemas:
892
- new_refs.extend(schema.make_persistent(workflow, source))
965
+ source: ParamSource = {"type": "default_input", "task_insert_ID": insert_ID}
966
+ new_refs = list(
967
+ chain.from_iterable(
968
+ schema.make_persistent(workflow, source) for schema in obj.schemas
969
+ )
970
+ )
893
971
 
894
972
  return obj, new_refs
895
973
 
896
- def to_dict(self):
897
- out = super().to_dict()
974
+ @override
975
+ def _postprocess_to_dict(self, d: dict[str, Any]) -> dict[str, Any]:
976
+ out = super()._postprocess_to_dict(d)
898
977
  out["_schema"] = out.pop("_schemas")
899
- return {
978
+ res = {
900
979
  k.lstrip("_"): v
901
980
  for k, v in out.items()
902
- if k not in ("_name", "_pending_element_sets")
981
+ if k not in ("_name", "_pending_element_sets", "_Task__groups")
903
982
  }
983
+ return res
904
984
 
905
- def set_sequence_parameters(self, element_set):
985
+ def set_sequence_parameters(self, element_set: ElementSet) -> None:
906
986
  """
907
987
  Set up parameters parsed by value sequences.
908
988
  """
@@ -914,18 +994,14 @@ class Task(JSONLike):
914
994
  if inp_j.typ == seq.input_type:
915
995
  seq._parameter = inp_j.parameter
916
996
 
917
- def _validate(self):
997
+ def _validate(self) -> None:
918
998
  # TODO: check a nesting order specified for each sequence?
919
999
 
920
- names = set(i.objective.name for i in self.schemas)
921
- if len(names) > 1:
922
- raise TaskTemplateMultipleSchemaObjectives(
923
- f"All task schemas used within a task must have the same "
924
- f"objective, but found multiple objectives: {list(names)!r}"
925
- )
1000
+ if len(names := set(schema.objective.name for schema in self.schemas)) > 1:
1001
+ raise TaskTemplateMultipleSchemaObjectives(names)
926
1002
 
927
- def _get_name(self):
928
- out = f"{self.objective.name}"
1003
+ def __get_name(self) -> str:
1004
+ out = self.objective.name
929
1005
  for idx, schema_i in enumerate(self.schemas, start=1):
930
1006
  need_and = idx < len(self.schemas) and (
931
1007
  self.schemas[idx].method or self.schemas[idx].implementation
@@ -938,13 +1014,12 @@ class Task(JSONLike):
938
1014
  return out
939
1015
 
940
1016
  @staticmethod
941
- def get_task_unique_names(tasks: List[app.Task]):
1017
+ def get_task_unique_names(tasks: list[Task]) -> Sequence[str]:
942
1018
  """Get the unique name of each in a list of tasks.
943
1019
 
944
1020
  Returns
945
1021
  -------
946
1022
  list of str
947
-
948
1023
  """
949
1024
 
950
1025
  task_name_rep_idx = get_item_repeat_index(
@@ -953,72 +1028,77 @@ class Task(JSONLike):
953
1028
  distinguish_singular=True,
954
1029
  )
955
1030
 
956
- names = []
957
- for idx, task in enumerate(tasks):
958
- add_rep = f"_{task_name_rep_idx[idx]}" if task_name_rep_idx[idx] > 0 else ""
- names.append(f"{task.name}{add_rep}")
-
- return names
+ return [
+ (
+ f"{task.name}_{task_name_rep_idx[idx]}"
+ if task_name_rep_idx[idx] > 0
+ else task.name
+ )
+ for idx, task in enumerate(tasks)
+ ]

  @TimeIt.decorator
- def _prepare_persistent_outputs(self, workflow, local_element_idx_range):
+ def _prepare_persistent_outputs(
+ self, workflow: Workflow, local_element_idx_range: Sequence[int]
+ ) -> Mapping[str, Sequence[int]]:
  # TODO: check that schema is present when adding task? (should this be here?)

  # allocate schema-level output parameter; precise EAR index will not be known
  # until we initialise EARs:
- output_data_indices = {}
+ output_data_indices: dict[str, list[int]] = {}
  for schema in self.schemas:
  for output in schema.outputs:
  # TODO: consider multiple schemas in action index?

  path = f"outputs.{output.typ}"
- output_data_indices[path] = []
- for idx in range(*local_element_idx_range):
+ output_data_indices[path] = [
  # iteration_idx, action_idx, and EAR_idx are not known until
  # `initialise_EARs`:
- param_src = {
- "type": "EAR_output",
- # "task_insert_ID": self.insert_ID,
- # "element_idx": idx,
- # "run_idx": 0,
- }
- data_ref = workflow._add_unset_parameter_data(param_src)
- output_data_indices[path].append(data_ref)
+ workflow._add_unset_parameter_data(
+ {
+ "type": "EAR_output",
+ # "task_insert_ID": self.insert_ID,
+ # "element_idx": idx,
+ # "run_idx": 0,
+ }
+ )
+ for idx in range(*local_element_idx_range)
+ ]

  return output_data_indices

- def prepare_element_resolution(self, element_set, input_data_indices):
+ def prepare_element_resolution(
+ self, element_set: ElementSet, input_data_indices: Mapping[str, Sequence]
+ ) -> list[MultiplicityDescriptor]:
  """
  Set up the resolution of details of elements
  (especially multiplicities and how iterations are nested)
  within an element set.
  """
- multiplicities = []
- for path_i, inp_idx_i in input_data_indices.items():
- multiplicities.append(
- {
- "multiplicity": len(inp_idx_i),
- "nesting_order": element_set.nesting_order.get(path_i, -1),
- "path": path_i,
- }
- )
+ multiplicities: list[MultiplicityDescriptor] = [
+ {
+ "multiplicity": len(inp_idx_i),
+ "nesting_order": element_set.nesting_order.get(path_i, -1.0),
+ "path": path_i,
+ }
+ for path_i, inp_idx_i in input_data_indices.items()
+ ]

  # if all inputs with non-unit multiplicity have the same multiplicity and a
  # default nesting order of -1 or 0 (which will have probably been set by a
  # `ValueSequence` default), set the non-unit multiplicity inputs to a nesting
  # order of zero:
- non_unit_multis = {}
- unit_multis = []
+ non_unit_multis: dict[int, int] = {}
+ unit_multis: list[int] = []
  change = True
- for idx, i in enumerate(multiplicities):
- if i["multiplicity"] == 1:
+ for idx, descriptor in enumerate(multiplicities):
+ if descriptor["multiplicity"] == 1:
  unit_multis.append(idx)
+ elif descriptor["nesting_order"] in (-1.0, 0.0):
+ non_unit_multis[idx] = descriptor["multiplicity"]
  else:
- if i["nesting_order"] in (-1, 0):
- non_unit_multis[idx] = i["multiplicity"]
- else:
- change = False
- break
+ change = False
+ break

  if change and len(set(non_unit_multis.values())) == 1:
  for i_idx in non_unit_multis:
@@ -1027,7 +1107,7 @@ class Task(JSONLike):
  return multiplicities

  @property
- def index(self):
+ def index(self) -> int | None:
  """
  The index of this task within the workflow's tasks.
  """
@@ -1037,19 +1117,26 @@ class Task(JSONLike):
  return None

  @property
- def output_labels(self):
+ def output_labels(self) -> Sequence[OutputLabel]:
  """
  The labels on the outputs of the task.
  """
  return self._output_labels

  @property
- def _element_indices(self):
- return self.workflow_template.workflow.tasks[self.index].element_indices
+ def _element_indices(self) -> list[int] | None:
+ if (
+ self.workflow_template
+ and self.workflow_template.workflow
+ and self.index is not None
+ ):
+ task = self.workflow_template.workflow.tasks[self.index]
+ return [element._index for element in task.elements]
+ return None

- def _get_task_source_element_iters(
- self, in_or_out: str, src_task, labelled_path, element_set
- ) -> List[int]:
+ def __get_task_source_element_iters(
+ self, in_or_out: str, src_task: Task, labelled_path: str, element_set: ElementSet
+ ) -> list[int]:
  """Get a sorted list of element iteration IDs that provide either inputs or
  outputs from the provided source task."""

@@ -1061,12 +1148,16 @@ class Task(JSONLike):
  es_idx = src_task.get_param_provided_element_sets(labelled_path)
  for es_i in src_task.element_sets:
  # add any element set that has task sources for this parameter
- for inp_src_i in es_i.input_sources.get(labelled_path, []):
- if inp_src_i.source_type is InputSourceType.TASK:
- if es_i.index not in es_idx:
- es_idx.append(es_i.index)
- break
-
+ es_i_idx = es_i.index
+ if (
+ es_i_idx is not None
+ and es_i_idx not in es_idx
+ and any(
+ inp_src_i.source_type is InputSourceType.TASK
+ for inp_src_i in es_i.input_sources.get(labelled_path, ())
+ )
+ ):
+ es_idx.append(es_i_idx)
  else:
  # outputs are always available, so consider all source task
  # element sets:
@@ -1074,11 +1165,11 @@ class Task(JSONLike):

  if not es_idx:
  raise NoAvailableElementSetsError()
- else:
- src_elem_iters = []
- for es_idx_i in es_idx:
- es_i = src_task.element_sets[es_idx_i]
- src_elem_iters += es_i.elem_iter_IDs # should be sorted already
+
+ src_elem_iters: list[int] = []
+ for es_i_idx in es_idx:
+ es_i = src_task.element_sets[es_i_idx]
+ src_elem_iters.extend(es_i.elem_iter_IDs) # should be sorted already

  if element_set.sourceable_elem_iters is not None:
  # can only use a subset of element iterations (this is the
@@ -1087,36 +1178,73 @@ class Task(JSONLike):
  # added upstream elements when adding elements from this
  # element set):
  src_elem_iters = sorted(
- list(set(element_set.sourceable_elem_iters) & set(src_elem_iters))
+ set(element_set.sourceable_elem_iters).intersection(src_elem_iters)
  )

  return src_elem_iters

+ @staticmethod
+ def __get_common_path(labelled_path: str, inputs_path: str) -> str | None:
+ lab_s = labelled_path.split(".")
+ inp_s = inputs_path.split(".")
+ try:
+ get_relative_path(lab_s, inp_s)
+ return labelled_path
+ except ValueError:
+ pass
+ try:
+ get_relative_path(inp_s, lab_s)
+ return inputs_path
+ except ValueError:
+ # no intersection between paths
+ return None
+
+ @staticmethod
+ def __filtered_iters(wk_task: WorkflowTask, where: Rule) -> list[int]:
+ param_path = cast("str", where.path)
+ param_prefix, param_name, *param_tail = param_path.split(".")
+ src_elem_iters: list[int] = []
+
+ for elem in wk_task.elements:
+ params: EPP = getattr(elem, param_prefix)
+ param: ElementParameter = getattr(params, param_name)
+ param_dat = param.value
+
+ # for remaining paths components try both getattr and
+ # getitem:
+ for path_k in param_tail:
+ try:
+ param_dat = param_dat[path_k]
+ except TypeError:
+ param_dat = getattr(param_dat, path_k)
+
+ if where._valida_check(param_dat):
+ src_elem_iters.append(elem.iterations[0].id_)
+
+ return src_elem_iters
+
  def get_available_task_input_sources(
  self,
- element_set: app.ElementSet,
- source_tasks: Optional[List[app.WorkflowTask]] = None,
- ) -> List[app.InputSource]:
+ element_set: ElementSet,
+ source_tasks: Sequence[WorkflowTask] = (),
+ ) -> Mapping[str, Sequence[InputSource]]:
  """For each input parameter of this task, generate a list of possible input sources
  that derive from inputs or outputs of this and other provided tasks.

  Note this only produces a subset of available input sources for each input
  parameter; other available input sources may exist from workflow imports."""

- if source_tasks:
- # ensure parameters provided by later tasks are added to the available sources
- # list first, meaning they take precedence when choosing an input source:
- source_tasks = sorted(source_tasks, key=lambda x: x.index, reverse=True)
- else:
- source_tasks = []
+ # ensure parameters provided by later tasks are added to the available sources
+ # list first, meaning they take precedence when choosing an input source:
+ source_tasks = sorted(source_tasks, key=lambda x: x.index, reverse=True)

- available = {}
+ available: dict[str, list[InputSource]] = {}
  for inputs_path, inp_status in self.get_input_statuses(element_set).items():
  # local specification takes precedence:
  if inputs_path in element_set.get_locally_defined_inputs():
- if inputs_path not in available:
- available[inputs_path] = []
- available[inputs_path].append(self.app.InputSource.local())
+ available.setdefault(inputs_path, []).append(
+ self._app.InputSource.local()
+ )

  # search for task sources:
  for src_wk_task_i in source_tasks:
@@ -1129,77 +1257,47 @@ class Task(JSONLike):
  key=lambda x: x[0],
  reverse=True,
  ):
- src_elem_iters = []
- lab_s = labelled_path.split(".")
- inp_s = inputs_path.split(".")
- try:
- get_relative_path(lab_s, inp_s)
- except ValueError:
+ src_elem_iters: list[int] = []
+ common = self.__get_common_path(labelled_path, inputs_path)
+ if common is not None:
+ avail_src_path = common
+ else:
+ # no intersection between paths
+ inputs_path_label = None
+ out_label = None
+ unlabelled, inputs_path_label = split_param_label(inputs_path)
+ if unlabelled is None:
+ continue
  try:
- get_relative_path(inp_s, lab_s)
+ get_relative_path(
+ unlabelled.split("."), labelled_path.split(".")
+ )
+ avail_src_path = inputs_path
  except ValueError:
- # no intersection between paths
- inputs_path_label = None
- out_label = None
- unlabelled, inputs_path_label = split_param_label(inputs_path)
- try:
- get_relative_path(unlabelled.split("."), lab_s)
- avail_src_path = inputs_path
- except ValueError:
- continue
-
- if inputs_path_label:
- for out_lab_i in src_task_i.output_labels:
- if out_lab_i.label == inputs_path_label:
- out_label = out_lab_i
+ continue
+ if not inputs_path_label:
+ continue
+ for out_lab_i in src_task_i.output_labels:
+ if out_lab_i.label == inputs_path_label:
+ out_label = out_lab_i
+
+ # consider output labels
+ if out_label and in_or_out == "output":
+ # find element iteration IDs that match the output label
+ # filter:
+ if out_label.where:
+ src_elem_iters = self.__filtered_iters(
+ src_wk_task_i, out_label.where
+ )
  else:
- continue
-
- # consider output labels
- if out_label and in_or_out == "output":
- # find element iteration IDs that match the output label
- # filter:
- if out_label.where:
- param_path_split = out_label.where.path.split(".")
-
- for elem_i in src_wk_task_i.elements:
- params = getattr(elem_i, param_path_split[0])
- param_dat = getattr(
- params, param_path_split[1]
- ).value
-
- # for remaining paths components try both getattr and
- # getitem:
- for path_k in param_path_split[2:]:
- try:
- param_dat = param_dat[path_k]
- except TypeError:
- param_dat = getattr(param_dat, path_k)
-
- rule = Rule(
- path=[0],
- condition=out_label.where.condition,
- cast=out_label.where.cast,
- )
- if rule.test([param_dat]).is_valid:
- src_elem_iters.append(
- elem_i.iterations[0].id_
- )
- else:
- src_elem_iters = [
- elem_i.iterations[0].id_
- for elem_i in src_wk_task_i.elements
- ]
-
- else:
- avail_src_path = inputs_path
-
- else:
- avail_src_path = labelled_path
+ src_elem_iters = [
+ elem_i.iterations[0].id_
+ for elem_i in src_wk_task_i.elements
+ ]

  if not src_elem_iters:
  try:
- src_elem_iters = self._get_task_source_element_iters(
+ src_elem_iters = self.__get_task_source_element_iters(
  in_or_out=in_or_out,
  src_task=src_task_i,
  labelled_path=labelled_path,
@@ -1207,36 +1305,32 @@ class Task(JSONLike):
  )
  except NoAvailableElementSetsError:
  continue
+ if not src_elem_iters:
+ continue

- if not src_elem_iters:
- continue
-
- task_source = self.app.InputSource.task(
- task_ref=src_task_i.insert_ID,
- task_source_type=in_or_out,
- element_iters=src_elem_iters,
+ available.setdefault(avail_src_path, []).append(
+ self._app.InputSource.task(
+ task_ref=src_task_i.insert_ID,
+ task_source_type=in_or_out,
+ element_iters=src_elem_iters,
+ )
  )

- if avail_src_path not in available:
- available[avail_src_path] = []
-
- available[avail_src_path].append(task_source)
-
  if inp_status.has_default:
- if inputs_path not in available:
- available[inputs_path] = []
- available[inputs_path].append(self.app.InputSource.default())
+ available.setdefault(inputs_path, []).append(
+ self._app.InputSource.default()
+ )
  return available

  @property
- def schemas(self) -> List[app.TaskSchema]:
+ def schemas(self) -> list[TaskSchema]:
  """
  All the task schemas.
  """
  return self._schemas

  @property
- def schema(self) -> app.TaskSchema:
+ def schema(self) -> TaskSchema:
  """The single task schema, if only one, else raises."""
  if len(self._schemas) == 1:
  return self._schemas[0]
@@ -1247,79 +1341,86 @@ class Task(JSONLike):
  )

  @property
- def element_sets(self):
+ def element_sets(self) -> list[ElementSet]:
  """
  The element sets.
  """
  return self._element_sets + self._pending_element_sets

  @property
- def num_element_sets(self):
+ def num_element_sets(self) -> int:
  """
  The number of element sets.
  """
- return len(self.element_sets)
+ return len(self._element_sets) + len(self._pending_element_sets)

  @property
- def insert_ID(self):
+ def insert_ID(self) -> int:
  """
  Insertion ID.
  """
+ assert self._insert_ID is not None
  return self._insert_ID

  @property
- def dir_name(self):
+ def dir_name(self) -> str:
  """
  Artefact directory name.
  """
+ assert self._dir_name is not None
  return self._dir_name

  @property
- def name(self):
+ def name(self) -> str:
  """
  Task name.
  """
  return self._name

  @property
- def objective(self):
+ def objective(self) -> TaskObjective:
  """
  The goal of this task.
  """
- return self.schemas[0].objective
+ obj = self.schemas[0].objective
+ return obj

  @property
- def all_schema_inputs(self) -> Tuple[app.SchemaInput]:
+ def all_schema_inputs(self) -> tuple[SchemaInput, ...]:
  """
  The inputs to this task's schemas.
  """
  return tuple(inp_j for schema_i in self.schemas for inp_j in schema_i.inputs)

  @property
- def all_schema_outputs(self) -> Tuple[app.SchemaOutput]:
+ def all_schema_outputs(self) -> tuple[SchemaOutput, ...]:
  """
  The outputs from this task's schemas.
  """
  return tuple(inp_j for schema_i in self.schemas for inp_j in schema_i.outputs)

  @property
- def all_schema_input_types(self):
- """The set of all schema input types (over all specified schemas)."""
+ def all_schema_input_types(self) -> set[str]:
+ """
+ The set of all schema input types (over all specified schemas).
+ """
  return {inp_j for schema_i in self.schemas for inp_j in schema_i.input_types}

  @property
- def all_schema_input_normalised_paths(self):
+ def all_schema_input_normalised_paths(self) -> set[str]:
  """
  Normalised paths for all schema input types.
  """
- return {f"inputs.{i}" for i in self.all_schema_input_types}
+ return {f"inputs.{typ}" for typ in self.all_schema_input_types}

  @property
- def all_schema_output_types(self):
- """The set of all schema output types (over all specified schemas)."""
+ def all_schema_output_types(self) -> set[str]:
+ """
+ The set of all schema output types (over all specified schemas).
+ """
  return {out_j for schema_i in self.schemas for out_j in schema_i.output_types}

- def get_schema_action(self, idx):
+ def get_schema_action(self, idx: int) -> Action: #
  """
  Get the schema action at the given index.
  """
@@ -1331,7 +1432,7 @@ class Task(JSONLike):
  _idx += 1
  raise ValueError(f"No action in task {self.name!r} with index {idx!r}.")

- def all_schema_actions(self) -> Iterator[Tuple[int, app.Action]]:
+ def all_schema_actions(self) -> Iterator[tuple[int, Action]]:
  """
  Get all the schema actions and their indices.
  """
@@ -1346,54 +1447,56 @@ class Task(JSONLike):
1346
1447
  """
1347
1448
  The total number of schema actions.
1348
1449
  """
1349
- num = 0
1350
- for schema in self.schemas:
1351
- for _ in schema.actions:
1352
- num += 1
1353
- return num
1450
+ return sum(len(schema.actions) for schema in self.schemas)
1354
1451
 
1355
1452
  @property
1356
- def all_sourced_normalised_paths(self):
1453
+ def all_sourced_normalised_paths(self) -> set[str]:
1357
1454
  """
1358
1455
  All the sourced normalised paths, including of sub-values.
1359
1456
  """
1360
- sourced_input_types = []
1457
+ sourced_input_types: set[str] = set()
1361
1458
  for elem_set in self.element_sets:
1362
- for inp in elem_set.inputs:
1363
- if inp.is_sub_value:
1364
- sourced_input_types.append(inp.normalised_path)
1365
- for seq in elem_set.sequences:
1366
- if seq.is_sub_value:
1367
- sourced_input_types.append(seq.normalised_path)
1368
- return set(sourced_input_types) | self.all_schema_input_normalised_paths
1369
-
1370
- def is_input_type_required(self, typ: str, element_set: app.ElementSet) -> bool:
1459
+ sourced_input_types.update(
1460
+ inp.normalised_path for inp in elem_set.inputs if inp.is_sub_value
1461
+ )
1462
+ sourced_input_types.update(
1463
+ seq.normalised_path for seq in elem_set.sequences if seq.is_sub_value
1464
+ )
1465
+ return sourced_input_types | self.all_schema_input_normalised_paths
1466
+
1467
+ def is_input_type_required(self, typ: str, element_set: ElementSet) -> bool:
1371
1468
  """Check if an given input type must be specified in the parametrisation of this
1372
1469
  element set.
1373
1470
 
1374
1471
  A schema input need not be specified if it is only required to generate an input
1375
1472
  file, and that input file is passed directly."""
1376
1473
 
1377
- provided_files = [i.file for i in element_set.input_files]
1474
+ provided_files = {in_file.file for in_file in element_set.input_files}
1378
1475
  for schema in self.schemas:
1379
1476
  if not schema.actions:
1380
1477
  return True # for empty tasks that are used merely for defining inputs
1381
- for act in schema.actions:
1382
- if act.is_input_type_required(typ, provided_files):
1383
- return True
1478
+ if any(
1479
+ act.is_input_type_required(typ, provided_files) for act in schema.actions
1480
+ ):
1481
+ return True
1384
1482
 
1385
1483
  return False
1386
1484
 
1387
- def get_param_provided_element_sets(self, labelled_path: str) -> List[int]:
1485
+ def get_param_provided_element_sets(self, labelled_path: str) -> list[int]:
1388
1486
  """Get the element set indices of this task for which a specified parameter type
1389
- is locally provided."""
1390
- es_idx = []
1391
- for idx, src_es in enumerate(self.element_sets):
1392
- if src_es.is_input_type_provided(labelled_path):
1393
- es_idx.append(idx)
1394
- return es_idx
1395
-
1396
- def get_input_statuses(self, elem_set: app.ElementSet) -> Dict[str, InputStatus]:
1487
+ is locally provided.
1488
+
1489
+ Note
1490
+ ----
1491
+ Caller may freely modify this result.
1492
+ """
1493
+ return [
1494
+ idx
1495
+ for idx, src_es in enumerate(self.element_sets)
1496
+ if src_es.is_input_type_provided(labelled_path)
1497
+ ]
1498
+
1499
+ def get_input_statuses(self, elem_set: ElementSet) -> Mapping[str, InputStatus]:
1397
1500
  """Get a dict whose keys are normalised input paths (without the "inputs" prefix),
1398
1501
  and whose values are InputStatus objects.
1399
1502
 
@@ -1401,10 +1504,9 @@ class Task(JSONLike):
1401
1504
  ----------
1402
1505
  elem_set
1403
1506
  The element set for which input statuses should be returned.
1404
-
1405
1507
  """
1406
1508
 
1407
- status = {}
1509
+ status: dict[str, InputStatus] = {}
1408
1510
  for schema_input in self.all_schema_inputs:
1409
1511
  for lab_info in schema_input.labelled_info():
1410
1512
  labelled_type = lab_info["labelled_type"]
@@ -1427,31 +1529,36 @@ class Task(JSONLike):
1427
1529
  return status
1428
1530
 
1429
1531
  @property
1430
- def universal_input_types(self):
1532
+ def universal_input_types(self) -> set[str]:
1431
1533
  """Get input types that are associated with all schemas"""
1432
1534
  raise NotImplementedError()
1433
1535
 
1434
1536
  @property
1435
- def non_universal_input_types(self):
1537
+ def non_universal_input_types(self) -> set[str]:
1436
1538
  """Get input types for each schema that are non-universal."""
1437
1539
  raise NotImplementedError()
1438
1540
 
1439
1541
  @property
1440
- def defined_input_types(self):
1542
+ def defined_input_types(self) -> set[str]:
1441
1543
  """
1442
- The input types defined by this task.
1544
+ The input types defined by this task, being the input types defined by any of
1545
+ its element sets.
1443
1546
  """
1444
- return self._defined_input_types
1547
+ dit: set[str] = set()
1548
+ for es in self.element_sets:
1549
+ dit.update(es.defined_input_types)
1550
+ return dit
1551
+ # TODO: Is this right?
1445
1552
 
1446
1553
  @property
1447
- def undefined_input_types(self):
1554
+ def undefined_input_types(self) -> set[str]:
1448
1555
  """
1449
1556
  The schema's input types that this task doesn't define.
1450
1557
  """
1451
1558
  return self.all_schema_input_types - self.defined_input_types
1452
1559
 
1453
1560
  @property
1454
- def undefined_inputs(self):
1561
+ def undefined_inputs(self) -> list[SchemaInput]:
1455
1562
  """
1456
1563
  The task's inputs that are undefined.
1457
1564
  """
@@ -1462,41 +1569,36 @@ class Task(JSONLike):
1462
1569
  if inp_j.typ in self.undefined_input_types
1463
1570
  ]
1464
1571
 
1465
- def provides_parameters(self) -> Tuple[Tuple[str, str]]:
1572
+ def provides_parameters(self) -> tuple[tuple[str, str], ...]:
1466
1573
  """Get all provided parameter labelled types and whether they are inputs and
1467
1574
  outputs, considering all element sets.
1468
1575
 
1469
1576
  """
1470
- out = []
1577
+ out: dict[tuple[str, str], None] = {}
1471
1578
  for schema in self.schemas:
1472
- for in_or_out, labelled_type in schema.provides_parameters:
1473
- out.append((in_or_out, labelled_type))
1579
+ out.update(dict.fromkeys(schema.provides_parameters))
1474
1580
 
1475
1581
  # add sub-parameter input values and sequences:
1476
1582
  for es_i in self.element_sets:
1477
1583
  for inp_j in es_i.inputs:
1478
1584
  if inp_j.is_sub_value:
1479
- val_j = ("input", inp_j.normalised_inputs_path)
1480
- if val_j not in out:
1481
- out.append(val_j)
1585
+ out["input", inp_j.normalised_inputs_path] = None
1482
1586
  for seq_j in es_i.sequences:
1483
- if seq_j.is_sub_value:
1484
- val_j = ("input", seq_j.normalised_inputs_path)
1485
- if val_j not in out:
1486
- out.append(val_j)
1587
+ if seq_j.is_sub_value and (path := seq_j.normalised_inputs_path):
1588
+ out["input", path] = None
1487
1589
 
1488
1590
  return tuple(out)
1489
1591
 
1490
1592
  def add_group(
1491
- self, name: str, where: app.ElementFilter, group_by_distinct: app.ParameterPath
1593
+ self, name: str, where: ElementFilter, group_by_distinct: ParameterPath
1492
1594
  ):
1493
1595
  """
1494
1596
  Add an element group to this task.
1495
1597
  """
1496
1598
  group = ElementGroup(name=name, where=where, group_by_distinct=group_by_distinct)
1497
- self.groups.add_object(group)
1599
+ self.__groups.add_object(group)
1498
1600
 
1499
- def _get_single_label_lookup(self, prefix="") -> Dict[str, str]:
1601
+ def _get_single_label_lookup(self, prefix: str = "") -> Mapping[str, str]:
1500
1602
  """Get a mapping between schema input types that have a single label (i.e.
1501
1603
  labelled but with `multiple=False`) and the non-labelled type string.
1502
1604
 
@@ -1509,13 +1611,18 @@ class Task(JSONLike):
1509
1611
  `{"inputs.p1[one]": "inputs.p1"}`.
1510
1612
 
1511
1613
  """
1512
- lookup = {}
1513
- for i in self.schemas:
1514
- lookup.update(i._get_single_label_lookup(prefix=prefix))
1614
+ lookup: dict[str, str] = {}
1615
+ for schema in self.schemas:
1616
+ lookup.update(schema._get_single_label_lookup(prefix=prefix))
1515
1617
  return lookup
1516
1618
 
1517
1619
 
1518
- class WorkflowTask:
1620
+ class _ESIdx(NamedTuple):
1621
+ ordered: list[int]
1622
+ uniq: frozenset[int]
1623
+
1624
+
1625
+ class WorkflowTask(AppAware):
1519
1626
  """
1520
1627
  Represents a :py:class:`Task` that is bound to a :py:class:`Workflow`.
1521
1628
 
@@ -1531,14 +1638,12 @@ class WorkflowTask:
1531
1638
  The IDs of the elements of this task.
1532
1639
  """
1533
1640
 
1534
- _app_attr = "app"
1535
-
1536
1641
  def __init__(
1537
1642
  self,
1538
- workflow: app.Workflow,
1539
- template: app.Task,
1643
+ workflow: Workflow,
1644
+ template: Task,
1540
1645
  index: int,
1541
- element_IDs: List[int],
1646
+ element_IDs: list[int],
1542
1647
  ):
1543
1648
  self._workflow = workflow
1544
1649
  self._template = template
@@ -1546,9 +1651,9 @@ class WorkflowTask:
1546
1651
  self._element_IDs = element_IDs
1547
1652
 
1548
1653
  # appended to when new elements are added and reset on dump to disk:
1549
- self._pending_element_IDs = []
1654
+ self._pending_element_IDs: list[int] = []
1550
1655
 
1551
- self._elements = None # assigned on `elements` first access
1656
+ self._elements: Elements | None = None # assigned on `elements` first access
1552
1657
 
1553
1658
  def __repr__(self) -> str:
1554
1659
  return f"{self.__class__.__name__}(name={self.unique_name!r})"
@@ -1561,7 +1666,7 @@ class WorkflowTask:
1561
1666
  self._reset_pending_element_IDs()
1562
1667
 
1563
1668
  @classmethod
1564
- def new_empty_task(cls, workflow: app.Workflow, template: app.Task, index: int):
1669
+ def new_empty_task(cls, workflow: Workflow, template: Task, index: int) -> Self:
1565
1670
  """
1566
1671
  Make a new instance without any elements set up yet.
1567
1672
 
@@ -1574,86 +1679,88 @@ class WorkflowTask:
1574
1679
  index:
1575
1680
  Where in the workflow's list of tasks is this one.
1576
1681
  """
1577
- obj = cls(
1682
+ return cls(
1578
1683
  workflow=workflow,
1579
1684
  template=template,
1580
1685
  index=index,
1581
1686
  element_IDs=[],
1582
1687
  )
1583
- return obj
1584
1688
 
1585
1689
  @property
1586
- def workflow(self):
1690
+ def workflow(self) -> Workflow:
1587
1691
  """
1588
1692
  The workflow this task is bound to.
1589
1693
  """
1590
1694
  return self._workflow
1591
1695
 
1592
1696
  @property
1593
- def template(self):
1697
+ def template(self) -> Task:
1594
1698
  """
1595
1699
  The template for this task.
1596
1700
  """
1597
1701
  return self._template
1598
1702
 
1599
1703
  @property
1600
- def index(self):
1704
+ def index(self) -> int:
1601
1705
  """
1602
1706
  The index of this task within its workflow.
1603
1707
  """
1604
1708
  return self._index
1605
1709
 
1606
1710
  @property
1607
- def element_IDs(self):
1711
+ def element_IDs(self) -> list[int]:
1608
1712
  """
1609
1713
  The IDs of elements associated with this task.
1610
1714
  """
1611
1715
  return self._element_IDs + self._pending_element_IDs
1612
1716
 
1613
1717
  @property
1614
- def num_elements(self):
1718
+ @TimeIt.decorator
1719
+ def num_elements(self) -> int:
1615
1720
  """
1616
1721
  The number of elements associated with this task.
1617
1722
  """
1618
- return len(self.element_IDs)
1723
+ return len(self._element_IDs) + len(self._pending_element_IDs)
1619
1724
 
1620
1725
  @property
1621
- def num_actions(self):
1726
+ def num_actions(self) -> int:
1622
1727
  """
1623
1728
  The number of actions in this task.
1624
1729
  """
1625
1730
  return self.template.num_all_schema_actions
1626
1731
 
1627
1732
  @property
1628
- def name(self):
1733
+ def name(self) -> str:
1629
1734
  """
1630
1735
  The name of this task based on its template.
1631
1736
  """
1632
1737
  return self.template.name
1633
1738
 
1634
1739
  @property
1635
- def unique_name(self):
1740
+ def unique_name(self) -> str:
1636
1741
  """
1637
1742
  The unique name for this task specifically.
1638
1743
  """
1639
1744
  return self.workflow.get_task_unique_names()[self.index]
1640
1745
 
1641
1746
  @property
1642
- def insert_ID(self):
1747
+ def insert_ID(self) -> int:
1643
1748
  """
1644
1749
  The insertion ID of the template task.
1645
1750
  """
1646
1751
  return self.template.insert_ID
1647
1752
 
1648
1753
  @property
1649
- def dir_name(self):
1754
+ def dir_name(self) -> str:
1650
1755
  """
1651
1756
  The name of the directory for the task's temporary files.
1652
1757
  """
1653
- return self.template.dir_name
1758
+ dn = self.template.dir_name
1759
+ assert dn is not None
1760
+ return dn
1654
1761
 
1655
1762
  @property
1656
- def num_element_sets(self):
1763
+ def num_element_sets(self) -> int:
1657
1764
  """
1658
1765
  The number of element sets associated with this task.
1659
1766
  """
@@ -1661,15 +1768,15 @@ class WorkflowTask:
1661
1768
 
1662
1769
  @property
1663
1770
  @TimeIt.decorator
1664
- def elements(self):
1771
+ def elements(self) -> Elements:
1665
1772
  """
1666
1773
  The elements associated with this task.
1667
1774
  """
1668
1775
  if self._elements is None:
1669
- self._elements = self.app.Elements(self)
1776
+ self._elements = Elements(self)
1670
1777
  return self._elements
1671
1778
 
1672
- def get_dir_name(self, loop_idx: Dict[str, int] = None) -> str:
1779
+ def get_dir_name(self, loop_idx: Mapping[str, int] | None = None) -> str:
1673
1780
  """
1674
1781
  Get the directory name for a particular iteration.
1675
1782
  """
@@ -1677,15 +1784,102 @@ class WorkflowTask:
1677
1784
  return self.dir_name
1678
1785
  return self.dir_name + "_" + "_".join((f"{k}-{v}" for k, v in loop_idx.items()))
1679
1786
 
1680
- def get_all_element_iterations(self) -> Dict[int, app.ElementIteration]:
1787
+ def get_all_element_iterations(self) -> Mapping[int, ElementIteration]:
1681
1788
  """
1682
1789
  Get the iterations known by the task's elements.
1683
1790
  """
1684
- return {j.id_: j for i in self.elements for j in i.iterations}
1791
+ return {itr.id_: itr for elem in self.elements for itr in elem.iterations}
1685
1792
 
1686
- def _make_new_elements_persistent(
1687
- self, element_set, element_set_idx, padded_elem_iters
1688
- ):
1793
+ @staticmethod
1794
+ def __get_src_elem_iters(
1795
+ src_task: WorkflowTask, inp_src: InputSource
1796
+ ) -> tuple[Iterable[ElementIteration], list[int]]:
1797
+ src_iters = src_task.get_all_element_iterations()
1798
+
1799
+ if inp_src.element_iters:
1800
+ # only include "sourceable" element iterations:
1801
+ src_iters_list = [src_iters[itr_id] for itr_id in inp_src.element_iters]
1802
+ set_indices = [el.element.element_set_idx for el in src_iters.values()]
1803
+ return src_iters_list, set_indices
1804
+ return src_iters.values(), []
1805
+
1806
+ def __get_task_group_index(
1807
+ self,
1808
+ labelled_path_i: str,
1809
+ inp_src: InputSource,
1810
+ padded_elem_iters: Mapping[str, Sequence[int]],
1811
+ inp_group_name: str | None,
1812
+ ) -> None | Sequence[int | list[int]]:
1813
+ src_task = inp_src.get_task(self.workflow)
1814
+ assert src_task
1815
+ src_elem_iters, src_elem_set_idx = self.__get_src_elem_iters(src_task, inp_src)
1816
+
1817
+ if not src_elem_iters:
1818
+ return None
1819
+
1820
+ task_source_type = inp_src.task_source_type
1821
+ assert task_source_type is not None
1822
+ if task_source_type == TaskSourceType.OUTPUT and "[" in labelled_path_i:
1823
+ src_key = f"{task_source_type.name.lower()}s.{labelled_path_i.split('[')[0]}"
1824
+ else:
1825
+ src_key = f"{task_source_type.name.lower()}s.{labelled_path_i}"
1826
+
1827
+ padded_iters = padded_elem_iters.get(labelled_path_i, [])
1828
+ grp_idx = [
1829
+ (itr.get_data_idx()[src_key] if itr_idx not in padded_iters else -1)
1830
+ for itr_idx, itr in enumerate(src_elem_iters)
1831
+ ]
1832
+
1833
+ if not inp_group_name:
1834
+ return grp_idx
1835
+
1836
+ group_dat_idx: list[int | list[int]] = []
1837
+ element_sets = src_task.template.element_sets
1838
+ for dat_idx, src_set_idx, src_iter in zip(
1839
+ grp_idx, src_elem_set_idx, src_elem_iters
1840
+ ):
1841
+ src_es = element_sets[src_set_idx]
1842
+ if any(inp_group_name == grp.name for grp in src_es.groups):
1843
+ group_dat_idx.append(dat_idx)
1844
+ continue
1845
+ # if for any recursive iteration dependency, this group is
1846
+ # defined, assign:
1847
+ src_iter_deps = self.workflow.get_element_iterations_from_IDs(
1848
+ src_iter.get_element_iteration_dependencies(),
1849
+ )
1850
+
1851
+ if any(
1852
+ inp_group_name == grp.name
1853
+ for el_iter in src_iter_deps
1854
+ for grp in el_iter.element.element_set.groups
1855
+ ):
1856
+ group_dat_idx.append(dat_idx)
1857
+ continue
1858
+
1859
+ # also check input dependencies
1860
+ for p_src in src_iter.element.get_input_dependencies().values():
1861
+ k_es = self.workflow.tasks.get(
1862
+ insert_ID=p_src["task_insert_ID"]
1863
+ ).template.element_sets[p_src["element_set_idx"]]
1864
+ if any(inp_group_name == grp.name for grp in k_es.groups):
1865
+ group_dat_idx.append(dat_idx)
1866
+ break
1867
+
1868
+ # TODO: this only goes to one level of dependency
1869
+
1870
+ if not group_dat_idx:
1871
+ raise MissingElementGroup(self.unique_name, inp_group_name, labelled_path_i)
1872
+
1873
+ return [cast("int", group_dat_idx)] # TODO: generalise to multiple groups
1874
+
1875
+ def __make_new_elements_persistent(
1876
+ self,
1877
+ element_set: ElementSet,
1878
+ element_set_idx: int,
1879
+ padded_elem_iters: Mapping[str, Sequence[int]],
1880
+ ) -> tuple[
1881
+ dict[str, list[int | list[int]]], dict[str, Sequence[int]], dict[str, list[int]]
1882
+ ]:
1689
1883
  """Save parameter data to the persistent workflow."""
1690
1884
 
1691
1885
  # TODO: rewrite. This method is a little hard to follow and results in somewhat
@@ -1693,25 +1887,25 @@ class WorkflowTask:
1693
1887
  # given input, the local source element(s) will always come first, regardless of
1694
1888
  # the ordering in element_set.input_sources.
1695
1889
 
1696
- input_data_idx = {}
1697
- sequence_idx = {}
1698
- source_idx = {}
1890
+ input_data_idx: dict[str, list[int | list[int]]] = {}
1891
+ sequence_idx: dict[str, Sequence[int]] = {}
1892
+ source_idx: dict[str, list[int]] = {}
1699
1893
 
1700
1894
  # Assign first assuming all locally defined values are to be used:
1701
- param_src = {
1895
+ param_src: ParamSource = {
1702
1896
  "type": "local_input",
1703
1897
  "task_insert_ID": self.insert_ID,
1704
1898
  "element_set_idx": element_set_idx,
1705
1899
  }
1706
- loc_inp_src = self.app.InputSource.local()
1900
+ loc_inp_src = self._app.InputSource.local()
1707
1901
  for res_i in element_set.resources:
1708
1902
  key, dat_ref, _ = res_i.make_persistent(self.workflow, param_src)
1709
- input_data_idx[key] = dat_ref
1903
+ input_data_idx[key] = list(dat_ref)
1710
1904
 
1711
1905
  for inp_i in element_set.inputs:
1712
1906
  key, dat_ref, _ = inp_i.make_persistent(self.workflow, param_src)
1713
- input_data_idx[key] = dat_ref
1714
- key_ = key.split("inputs.")[1]
1907
+ input_data_idx[key] = list(dat_ref)
1908
+ key_ = key.removeprefix("inputs.")
1715
1909
  try:
1716
1910
  # TODO: wouldn't need to do this if we raise when an InputValue is
1717
1911
  # provided for a parameter whose inputs sources do not include the local
@@ -1721,39 +1915,39 @@ class WorkflowTask:
1721
1915
  pass
1722
1916
 
1723
1917
  for inp_file_i in element_set.input_files:
1724
- key, dat_ref, _ = inp_file_i.make_persistent(self.workflow, param_src)
1725
- input_data_idx[key] = dat_ref
1918
+ key, input_dat_ref, _ = inp_file_i.make_persistent(self.workflow, param_src)
1919
+ input_data_idx[key] = list(input_dat_ref)
1726
1920
 
1727
1921
  for seq_i in element_set.sequences:
1728
- key, dat_ref, _ = seq_i.make_persistent(self.workflow, param_src)
1729
- input_data_idx[key] = dat_ref
1730
- sequence_idx[key] = list(range(len(dat_ref)))
1922
+ key, seq_dat_ref, _ = seq_i.make_persistent(self.workflow, param_src)
1923
+ input_data_idx[key] = list(seq_dat_ref)
1924
+ sequence_idx[key] = list(range(len(seq_dat_ref)))
1731
1925
  try:
1732
1926
  key_ = key.split("inputs.")[1]
1733
1927
  except IndexError:
1734
- pass
1928
+ # e.g. "resources."
1929
+ key_ = ""
1735
1930
  try:
1736
1931
  # TODO: wouldn't need to do this if we raise when an ValueSequence is
1737
1932
  # provided for a parameter whose inputs sources do not include the local
1738
1933
  # value.
1739
- source_idx[key] = [
1740
- element_set.input_sources[key_].index(loc_inp_src)
1741
- ] * len(dat_ref)
1934
+ if key_:
1935
+ source_idx[key] = [
1936
+ element_set.input_sources[key_].index(loc_inp_src)
1937
+ ] * len(seq_dat_ref)
1742
1938
  except ValueError:
1743
1939
  pass
1744
1940
 
1745
1941
  for rep_spec in element_set.repeats:
1746
1942
  seq_key = f"repeats.{rep_spec['name']}"
1747
- num_range = list(range(rep_spec["number"]))
1748
- input_data_idx[seq_key] = num_range
1943
+ num_range = range(rep_spec["number"])
1944
+ input_data_idx[seq_key] = list(num_range)
1749
1945
  sequence_idx[seq_key] = num_range
1750
1946
 
1751
1947
  # Now check for task- and default-sources and overwrite or append to local sources:
1752
1948
  inp_stats = self.template.get_input_statuses(element_set)
1753
1949
  for labelled_path_i, sources_i in element_set.input_sources.items():
1754
- path_i_split = labelled_path_i.split(".")
1755
- is_path_i_sub = len(path_i_split) > 1
1756
- if is_path_i_sub:
1950
+ if len(path_i_split := labelled_path_i.split(".")) > 1:
1757
1951
  path_i_root = path_i_split[0]
1758
1952
  else:
1759
1953
  path_i_root = labelled_path_i
@@ -1773,114 +1967,36 @@ class WorkflowTask:
1773
1967
 
1774
1968
  for inp_src_idx, inp_src in enumerate(sources_i):
1775
1969
  if inp_src.source_type is InputSourceType.TASK:
1776
- src_task = inp_src.get_task(self.workflow)
1777
- src_elem_iters = src_task.get_all_element_iterations()
1778
-
1779
- if inp_src.element_iters:
1780
- # only include "sourceable" element iterations:
1781
- src_elem_iters = [
1782
- src_elem_iters[i] for i in inp_src.element_iters
1783
- ]
1784
- src_elem_set_idx = [
1785
- i.element.element_set_idx for i in src_elem_iters
1786
- ]
1787
-
1788
- if not src_elem_iters:
1970
+ grp_idx = self.__get_task_group_index(
1971
+ labelled_path_i, inp_src, padded_elem_iters, inp_group_name
1972
+ )
1973
+ if grp_idx is None:
1789
1974
  continue
1790
1975
 
1791
- task_source_type = inp_src.task_source_type.name.lower()
1792
- if task_source_type == "output" and "[" in labelled_path_i:
1793
- src_key = f"{task_source_type}s.{labelled_path_i.split('[')[0]}"
1794
- else:
1795
- src_key = f"{task_source_type}s.{labelled_path_i}"
1796
-
1797
- padded_iters = padded_elem_iters.get(labelled_path_i, [])
1798
- grp_idx = [
1799
- (
1800
- iter_i.get_data_idx()[src_key]
1801
- if iter_i_idx not in padded_iters
1802
- else -1
1803
- )
1804
- for iter_i_idx, iter_i in enumerate(src_elem_iters)
1805
- ]
1806
-
1807
- if inp_group_name:
1808
- group_dat_idx = []
1809
- for dat_idx_i, src_set_idx_i, src_iter in zip(
1810
- grp_idx, src_elem_set_idx, src_elem_iters
1811
- ):
1812
- src_es = src_task.template.element_sets[src_set_idx_i]
1813
- if inp_group_name in [i.name for i in src_es.groups or []]:
1814
- group_dat_idx.append(dat_idx_i)
1815
- else:
1816
- # if for any recursive iteration dependency, this group is
1817
- # defined, assign:
1818
- src_iter_i = src_iter
1819
- src_iter_deps = (
1820
- self.workflow.get_element_iterations_from_IDs(
1821
- src_iter_i.get_element_iteration_dependencies(),
1822
- )
1823
- )
1824
-
1825
- src_iter_deps_groups = [
1826
- j
1827
- for i in src_iter_deps
1828
- for j in i.element.element_set.groups
1829
- ]
1830
-
1831
- if inp_group_name in [
1832
- i.name for i in src_iter_deps_groups or []
1833
- ]:
1834
- group_dat_idx.append(dat_idx_i)
1835
-
1836
- # also check input dependencies
1837
- for (
1838
- k,
1839
- v,
1840
- ) in src_iter.element.get_input_dependencies().items():
1841
- k_es_idx = v["element_set_idx"]
1842
- k_task_iID = v["task_insert_ID"]
1843
- k_es = self.workflow.tasks.get(
1844
- insert_ID=k_task_iID
1845
- ).template.element_sets[k_es_idx]
1846
- if inp_group_name in [
1847
- i.name for i in k_es.groups or []
1848
- ]:
1849
- group_dat_idx.append(dat_idx_i)
1850
-
1851
- # TODO: this only goes to one level of dependency
1852
-
1853
- if not group_dat_idx:
1854
- raise MissingElementGroup(
1855
- f"Adding elements to task {self.unique_name!r}: no "
1856
- f"element group named {inp_group_name!r} found for input "
1857
- f"{labelled_path_i!r}."
1858
- )
1859
-
1860
- grp_idx = [group_dat_idx] # TODO: generalise to multiple groups
1861
-
1862
- if self.app.InputSource.local() in sources_i:
1976
+ if self._app.InputSource.local() in sources_i:
1863
1977
  # add task source to existing local source:
1864
- input_data_idx[key] += grp_idx
1865
- source_idx[key] += [inp_src_idx] * len(grp_idx)
1978
+ input_data_idx[key].extend(grp_idx)
1979
+ source_idx[key].extend([inp_src_idx] * len(grp_idx))
1866
1980
 
1867
1981
  else: # BUG: doesn't work for multiple task inputs sources
1868
1982
  # overwrite existing local source (if it exists):
1869
- input_data_idx[key] = grp_idx
1983
+ input_data_idx[key] = list(grp_idx)
1870
1984
  source_idx[key] = [inp_src_idx] * len(grp_idx)
1871
1985
  if key in sequence_idx:
1872
1986
  sequence_idx.pop(key)
1873
- seq = element_set.get_sequence_from_path(key)
1987
+ # TODO: Use the value retrieved below?
1988
+ _ = element_set.get_sequence_from_path(key)
1874
1989
 
1875
1990
  elif inp_src.source_type is InputSourceType.DEFAULT:
1876
- grp_idx = [def_val._value_group_idx]
1877
- if self.app.InputSource.local() in sources_i:
1878
- input_data_idx[key] += grp_idx
1879
- source_idx[key] += [inp_src_idx] * len(grp_idx)
1880
-
1991
+ assert def_val is not None
1992
+ assert def_val._value_group_idx is not None
1993
+ grp_idx_ = def_val._value_group_idx
1994
+ if self._app.InputSource.local() in sources_i:
1995
+ input_data_idx[key].append(grp_idx_)
1996
+ source_idx[key].append(inp_src_idx)
1881
1997
  else:
1882
- input_data_idx[key] = grp_idx
1883
- source_idx[key] = [inp_src_idx] * len(grp_idx)
1998
+ input_data_idx[key] = [grp_idx_]
1999
+ source_idx[key] = [inp_src_idx]
1884
2000
 
1885
2001
  # sort smallest to largest path, so more-specific items overwrite less-specific
1886
2002
  # items in parameter retrieval in `WorkflowTask._get_merged_parameter_data`:
@@ -1888,7 +2004,9 @@ class WorkflowTask:
1888
2004
 
1889
2005
  return (input_data_idx, sequence_idx, source_idx)
1890
2006
 
1891
- def ensure_input_sources(self, element_set) -> Dict[str, List[int]]:
2007
+ def ensure_input_sources(
2008
+ self, element_set: ElementSet
2009
+ ) -> Mapping[str, Sequence[int]]:
1892
2010
  """Check valid input sources are specified for a new task to be added to the
1893
2011
  workflow in a given position. If none are specified, set them according to the
1894
2012
  default behaviour.
@@ -1903,18 +2021,8 @@ class WorkflowTask:
1903
2021
  source_tasks=self.workflow.tasks[: self.index],
1904
2022
  )
1905
2023
 
1906
- unreq_sources = set(element_set.input_sources.keys()) - set(
1907
- available_sources.keys()
1908
- )
1909
- if unreq_sources:
1910
- unreq_src_str = ", ".join(f"{i!r}" for i in unreq_sources)
1911
- raise UnrequiredInputSources(
1912
- message=(
1913
- f"The following input sources are not required but have been "
1914
- f"specified: {unreq_src_str}."
1915
- ),
1916
- unrequired_sources=unreq_sources,
1917
- )
2024
+ if unreq := set(element_set.input_sources).difference(available_sources):
2025
+ raise UnrequiredInputSources(unreq)
1918
2026
 
1919
2027
  # TODO: get available input sources from workflow imports
1920
2028
 
@@ -1929,11 +2037,8 @@ class WorkflowTask:
1929
2037
  for path_i, avail_i in available_sources.items():
1930
2038
  # for each sub-path in available sources, if the "root-path" source is
1931
2039
  # required, then add the sub-path source to `req_types` as well:
1932
- path_i_split = path_i.split(".")
1933
- is_path_i_sub = len(path_i_split) > 1
1934
- if is_path_i_sub:
1935
- path_i_root = path_i_split[0]
1936
- if path_i_root in req_types:
2040
+ if len(path_i_split := path_i.split(".")) > 1:
2041
+ if path_i_split[0] in req_types:
1937
2042
  req_types.add(path_i)
1938
2043
 
1939
2044
  for s_idx, specified_source in enumerate(
@@ -1943,35 +2048,36 @@ class WorkflowTask:
1943
2048
  specified_source, self.unique_name
1944
2049
  )
1945
2050
  avail_idx = specified_source.is_in(avail_i)
2051
+ if avail_idx is None:
2052
+ raise UnavailableInputSource(specified_source, path_i, avail_i)
2053
+ available_source: InputSource
1946
2054
  try:
1947
2055
  available_source = avail_i[avail_idx]
1948
2056
  except TypeError:
1949
2057
  raise UnavailableInputSource(
1950
- f"The input source {specified_source.to_string()!r} is not "
1951
- f"available for input path {path_i!r}. Available "
1952
- f"input sources are: {[i.to_string() for i in avail_i]}."
2058
+ specified_source, path_i, avail_i
1953
2059
  ) from None
1954
2060
 
1955
2061
  elem_iters_IDs = available_source.element_iters
1956
2062
  if specified_source.element_iters:
1957
2063
  # user-specified iter IDs; these must be a subset of available
1958
2064
  # element_iters:
1959
- if not set(specified_source.element_iters).issubset(elem_iters_IDs):
2065
+ if not set(specified_source.element_iters).issubset(
2066
+ elem_iters_IDs or ()
2067
+ ):
1960
2068
  raise InapplicableInputSourceElementIters(
1961
- f"The specified `element_iters` for input source "
1962
- f"{specified_source.to_string()!r} are not all applicable. "
1963
- f"Applicable element iteration IDs for this input source "
1964
- f"are: {elem_iters_IDs!r}."
2069
+ specified_source, elem_iters_IDs
1965
2070
  )
1966
2071
  elem_iters_IDs = specified_source.element_iters
1967
2072
 
1968
2073
  if specified_source.where:
1969
2074
  # filter iter IDs by user-specified rules, maintaining order:
1970
2075
  elem_iters = self.workflow.get_element_iterations_from_IDs(
1971
- elem_iters_IDs
2076
+ elem_iters_IDs or ()
1972
2077
  )
1973
- filtered = specified_source.where.filter(elem_iters)
1974
- elem_iters_IDs = [i.id_ for i in filtered]
2078
+ elem_iters_IDs = [
2079
+ ei.id_ for ei in specified_source.where.filter(elem_iters)
2080
+ ]
1975
2081
 
1976
2082
  available_source.element_iters = elem_iters_IDs
1977
2083
  element_set.input_sources[path_i][s_idx] = available_source
@@ -1979,24 +2085,16 @@ class WorkflowTask:
1979
2085
  # sorting ensures that root parameters come before sub-parameters, which is
1980
2086
  # necessary when considering if we want to include a sub-parameter, when setting
1981
2087
  # missing sources below:
1982
- unsourced_inputs = sorted(req_types - set(element_set.input_sources.keys()))
1983
-
1984
- extra_types = set(k for k, v in all_stats.items() if v.is_extra)
1985
- if extra_types:
1986
- extra_str = ", ".join(f"{i!r}" for i in extra_types)
1987
- raise ExtraInputs(
1988
- message=(
1989
- f"The following inputs are not required, but have been passed: "
1990
- f"{extra_str}."
1991
- ),
1992
- extra_inputs=extra_types,
1993
- )
2088
+ unsourced_inputs = sorted(req_types.difference(element_set.input_sources))
2089
+
2090
+ if extra_types := {k for k, v in all_stats.items() if v.is_extra}:
2091
+ raise ExtraInputs(extra_types)
1994
2092
 
1995
2093
  # set source for any unsourced inputs:
1996
- missing = []
2094
+ missing: list[str] = []
1997
2095
  # track which root params we have set according to default behaviour (not
1998
2096
  # specified by user):
1999
- set_root_params = []
2097
+ set_root_params: set[str] = set()
2000
2098
  for input_type in unsourced_inputs:
2001
2099
  input_split = input_type.split(".")
2002
2100
  has_root_param = input_split[0] if len(input_split) > 1 else None
@@ -2028,89 +2126,85 @@ class WorkflowTask:
2028
2126
  ):
2029
2127
  continue
2030
2128
 
2031
- element_set.input_sources.update({input_type: [source]})
2129
+ element_set.input_sources[input_type] = [source]
2032
2130
  if not has_root_param:
2033
- set_root_params.append(input_type)
2131
+ set_root_params.add(input_type)
2034
2132
 
2035
2133
  # for task sources that span multiple element sets, pad out sub-parameter
2036
2134
  # `element_iters` to include the element iterations from other element sets in
2037
2135
  # which the "root" parameter is defined:
2038
- sources_by_task = defaultdict(dict)
2039
- elem_iter_by_task = defaultdict(dict)
2040
- all_elem_iters = set()
2136
+ sources_by_task: dict[int, dict[str, InputSource]] = defaultdict(dict)
2137
+ elem_iter_by_task: dict[int, dict[str, list[int]]] = defaultdict(dict)
2138
+ all_elem_iters: set[int] = set()
2041
2139
  for inp_type, sources in element_set.input_sources.items():
2042
2140
  source = sources[0]
2043
2141
  if source.source_type is InputSourceType.TASK:
2142
+ assert source.task_ref is not None
2143
+ assert source.element_iters is not None
2044
2144
  sources_by_task[source.task_ref][inp_type] = source
2045
2145
  all_elem_iters.update(source.element_iters)
2046
2146
  elem_iter_by_task[source.task_ref][inp_type] = source.element_iters
2047
2147
 
2048
- all_elem_iter_objs = self.workflow.get_element_iterations_from_IDs(all_elem_iters)
2049
- all_elem_iters_by_ID = {i.id_: i for i in all_elem_iter_objs}
2148
+ all_elem_iters_by_ID = {
2149
+ el_iter.id_: el_iter
2150
+ for el_iter in self.workflow.get_element_iterations_from_IDs(all_elem_iters)
2151
+ }
2050
2152
 
2051
2153
  # element set indices:
2052
2154
  padded_elem_iters = defaultdict(list)
2053
- es_idx_by_task = defaultdict(dict)
2155
+ es_idx_by_task: dict[int, dict[str, _ESIdx]] = defaultdict(dict)
2054
2156
  for task_ref, task_iters in elem_iter_by_task.items():
2055
2157
  for inp_type, inp_iters in task_iters.items():
2056
2158
  es_indices = [
2057
- all_elem_iters_by_ID[i].element.element_set_idx for i in inp_iters
2159
+ all_elem_iters_by_ID[id_].element.element_set_idx for id_ in inp_iters
2058
2160
  ]
2059
- es_idx_by_task[task_ref][inp_type] = (es_indices, set(es_indices))
2060
- root_params = {k for k in task_iters if "." not in k}
2061
- root_param_nesting = {
2062
- k: element_set.nesting_order.get(f"inputs.{k}", None) for k in root_params
2063
- }
2064
- for root_param_i in root_params:
2065
- sub_params = {
2066
- k
2067
- for k in task_iters
2068
- if k.split(".")[0] == root_param_i and k != root_param_i
2069
- }
2070
- rp_elem_sets = es_idx_by_task[task_ref][root_param_i][0]
2071
- rp_elem_sets_uniq = es_idx_by_task[task_ref][root_param_i][1]
2072
-
2073
- for sub_param_j in sub_params:
2161
+ es_idx_by_task[task_ref][inp_type] = _ESIdx(
2162
+ es_indices, frozenset(es_indices)
2163
+ )
2164
+ for root_param in {k for k in task_iters if "." not in k}:
2165
+ rp_nesting = element_set.nesting_order.get(f"inputs.{root_param}", None)
2166
+ rp_elem_sets, rp_elem_sets_uniq = es_idx_by_task[task_ref][root_param]
2167
+
2168
+ root_param_prefix = f"{root_param}."
2169
+ for sub_param_j in {
2170
+ k for k in task_iters if k.startswith(root_param_prefix)
2171
+ }:
2074
2172
  sub_param_nesting = element_set.nesting_order.get(
2075
2173
  f"inputs.{sub_param_j}", None
2076
2174
  )
2077
- if sub_param_nesting == root_param_nesting[root_param_i]:
2078
-
2079
- sp_elem_sets_uniq = es_idx_by_task[task_ref][sub_param_j][1]
2175
+ if sub_param_nesting == rp_nesting:
2176
+ sp_elem_sets_uniq = es_idx_by_task[task_ref][sub_param_j].uniq
2080
2177
 
2081
2178
  if sp_elem_sets_uniq != rp_elem_sets_uniq:
2082
-
2083
2179
  # replace elem_iters in sub-param sequence with those from the
2084
2180
  # root parameter, but re-order the elem iters to match their
2085
2181
  # original order:
2086
- iters_copy = elem_iter_by_task[task_ref][root_param_i][:]
2182
+ iters = elem_iter_by_task[task_ref][root_param]
2087
2183
 
2088
2184
  # "mask" iter IDs corresponding to the sub-parameter's element
2089
2185
  # sets, and keep track of the extra indices so they can be
2090
2186
  # ignored later:
2091
- sp_iters_new = []
2092
- for idx, (i, j) in enumerate(zip(iters_copy, rp_elem_sets)):
2093
- if j in sp_elem_sets_uniq:
2187
+ sp_iters_new: list[int | None] = []
2188
+ for idx, (it_id, es_idx) in enumerate(
2189
+ zip(iters, rp_elem_sets)
2190
+ ):
2191
+ if es_idx in sp_elem_sets_uniq:
2094
2192
  sp_iters_new.append(None)
2095
2193
  else:
2096
- sp_iters_new.append(i)
2194
+ sp_iters_new.append(it_id)
2097
2195
  padded_elem_iters[sub_param_j].append(idx)
2098
2196
 
2099
- # fill in sub-param elem_iters in their specified order
2100
- sub_iters_it = iter(elem_iter_by_task[task_ref][sub_param_j])
2101
- sp_iters_new = [
2102
- i if i is not None else next(sub_iters_it)
2103
- for i in sp_iters_new
2104
- ]
2105
-
2106
2197
  # update sub-parameter element iters:
2107
- for src_idx, src in enumerate(
2108
- element_set.input_sources[sub_param_j]
2109
- ):
2198
+ for src in element_set.input_sources[sub_param_j]:
2110
2199
  if src.source_type is InputSourceType.TASK:
2111
- element_set.input_sources[sub_param_j][
2112
- src_idx
2113
- ].element_iters = sp_iters_new
2200
+ # fill in sub-param elem_iters in their specified order
2201
+ sub_iters_it = iter(
2202
+ elem_iter_by_task[task_ref][sub_param_j]
2203
+ )
2204
+ src.element_iters = [
2205
+ it_id if it_id is not None else next(sub_iters_it)
2206
+ for it_id in sp_iters_new
2207
+ ]
2114
2208
  # assumes only a single task-type source for this
2115
2209
  # parameter
2116
2210
  break
@@ -2120,131 +2214,125 @@ class WorkflowTask:
2120
2214
  # results in no available elements due to `allow_non_coincident_task_sources`.
2121
2215
 
2122
2216
  if not element_set.allow_non_coincident_task_sources:
2123
- # if multiple parameters are sourced from the same upstream task, only use
2124
- # element iterations for which all parameters are available (the set
2125
- # intersection):
2126
- for task_ref, sources in sources_by_task.items():
2127
- # if a parameter has multiple labels, disregard from this by removing all
2128
- # parameters:
2129
- seen_labelled = {}
2130
- for src_i in sources.keys():
2131
- if "[" in src_i:
2132
- unlabelled, _ = split_param_label(src_i)
2133
- if unlabelled not in seen_labelled:
2134
- seen_labelled[unlabelled] = 1
2135
- else:
2136
- seen_labelled[unlabelled] += 1
2137
-
2138
- for unlabelled, count in seen_labelled.items():
2139
- if count > 1:
2140
- # remove:
2141
- sources = {
2142
- k: v
2143
- for k, v in sources.items()
2144
- if not k.startswith(unlabelled)
2145
- }
2217
+ self.__enforce_some_sanity(sources_by_task, element_set)
2146
2218
 
2147
- if len(sources) < 2:
2148
- continue
2149
-
2150
- first_src = next(iter(sources.values()))
2151
- intersect_task_i = set(first_src.element_iters)
2152
- for src_i in sources.values():
2153
- intersect_task_i.intersection_update(src_i.element_iters)
2154
- if not intersect_task_i:
2155
- raise NoCoincidentInputSources(
2156
- f"Task {self.name!r}: input sources from task {task_ref!r} have "
2157
- f"no coincident applicable element iterations. Consider setting "
2158
- f"the element set (or task) argument "
2159
- f"`allow_non_coincident_task_sources` to `True`, which will "
2160
- f"allow for input sources from the same task to use different "
2161
- f"(non-coinciding) subsets of element iterations from the "
2162
- f"source task."
2163
- )
2219
+ if missing:
2220
+ raise MissingInputs(missing)
2221
+ return padded_elem_iters
2164
2222
 
2165
- # now change elements for the affected input sources.
2166
- # sort by original order of first_src.element_iters
2167
- int_task_i_lst = [
2168
- i for i in first_src.element_iters if i in intersect_task_i
2169
- ]
2170
- for inp_type in sources.keys():
2171
- element_set.input_sources[inp_type][0].element_iters = int_task_i_lst
2223
+ def __enforce_some_sanity(
2224
+ self, sources_by_task: dict[int, dict[str, InputSource]], element_set: ElementSet
2225
+ ) -> None:
2226
+ """
2227
+ if multiple parameters are sourced from the same upstream task, only use
2228
+ element iterations for which all parameters are available (the set
2229
+ intersection)
2230
+ """
2231
+ for task_ref, sources in sources_by_task.items():
2232
+ # if a parameter has multiple labels, disregard from this by removing all
2233
+ # parameters:
2234
+ seen_labelled: dict[str, int] = defaultdict(int)
2235
+ for src_i in sources:
2236
+ if "[" in src_i:
2237
+ unlabelled, _ = split_param_label(src_i)
2238
+ assert unlabelled is not None
2239
+ seen_labelled[unlabelled] += 1
2240
+
2241
+ for prefix, count in seen_labelled.items():
2242
+ if count > 1:
2243
+ # remove:
2244
+ sources = {
2245
+ k: v for k, v in sources.items() if not k.startswith(prefix)
2246
+ }
2172
2247
 
2173
- if missing:
2174
- missing_str = ", ".join(f"{i!r}" for i in missing)
2175
- raise MissingInputs(
2176
- message=f"The following inputs have no sources: {missing_str}.",
2177
- missing_inputs=missing,
2178
- )
2248
+ if len(sources) < 2:
2249
+ continue
2179
2250
 
2180
- return padded_elem_iters
2251
+ first_src = next(iter(sources.values()))
2252
+ intersect_task_i = set(first_src.element_iters or ())
2253
+ for inp_src in sources.values():
2254
+ intersect_task_i.intersection_update(inp_src.element_iters or ())
2255
+ if not intersect_task_i:
2256
+ raise NoCoincidentInputSources(self.name, task_ref)
2257
+
2258
+ # now change elements for the affected input sources.
2259
+ # sort by original order of first_src.element_iters
2260
+ int_task_i_lst = [
2261
+ i for i in first_src.element_iters or () if i in intersect_task_i
2262
+ ]
2263
+ for inp_type in sources:
2264
+ element_set.input_sources[inp_type][0].element_iters = int_task_i_lst
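
For reference, the order-preserving intersection applied by `__enforce_some_sanity` can be sketched in isolation; this stand-alone snippet is illustrative only (the function and variable names are not from the package):

def coincident_iters(iter_lists):
    # Intersect element-iteration ID lists from several input sources of the same
    # upstream task, preserving the ordering of the first source's list.
    common = set(iter_lists[0])
    for other in iter_lists[1:]:
        common.intersection_update(other)
    if not common:
        raise ValueError("no coincident element iterations")
    return [i for i in iter_lists[0] if i in common]

assert coincident_iters([[3, 1, 2], [2, 3, 5]]) == [3, 2]
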
2181
2265
 
2182
2266
  def generate_new_elements(
2183
2267
  self,
2184
- input_data_indices,
2185
- output_data_indices,
2186
- element_data_indices,
2187
- sequence_indices,
2188
- source_indices,
2189
- ):
2268
+ input_data_indices: Mapping[str, Sequence[int | list[int]]],
2269
+ output_data_indices: Mapping[str, Sequence[int]],
2270
+ element_data_indices: Sequence[Mapping[str, int]],
2271
+ sequence_indices: Mapping[str, Sequence[int]],
2272
+ source_indices: Mapping[str, Sequence[int]],
2273
+ ) -> tuple[
2274
+ Sequence[DataIndex], Mapping[str, Sequence[int]], Mapping[str, Sequence[int]]
2275
+ ]:
2190
2276
  """
2191
2277
  Create information about new elements in this task.
2192
2278
  """
2193
- new_elements = []
2194
- element_sequence_indices = {}
2195
- element_src_indices = {}
2196
- for i_idx, i in enumerate(element_data_indices):
2279
+ new_elements: list[DataIndex] = []
2280
+ element_sequence_indices: dict[str, list[int]] = {}
2281
+ element_src_indices: dict[str, list[int]] = {}
2282
+ for i_idx, data_idx in enumerate(element_data_indices):
2197
2283
  elem_i = {
2198
2284
  k: input_data_indices[k][v]
2199
- for k, v in i.items()
2285
+ for k, v in data_idx.items()
2200
2286
  if input_data_indices[k][v] != -1
2201
2287
  }
2202
- elem_i.update({k: v[i_idx] for k, v in output_data_indices.items()})
2288
+ elem_i.update((k, v2[i_idx]) for k, v2 in output_data_indices.items())
2203
2289
  new_elements.append(elem_i)
2204
2290
 
2205
- for k, v in i.items():
2291
+ for k, v3 in data_idx.items():
2206
2292
  # track which sequence value indices (if any) are used for each new
2207
2293
  # element:
2208
2294
  if k in sequence_indices:
2209
- if k not in element_sequence_indices:
2210
- element_sequence_indices[k] = []
2211
- element_sequence_indices[k].append(sequence_indices[k][v])
2295
+ element_sequence_indices.setdefault(k, []).append(
2296
+ sequence_indices[k][v3]
2297
+ )
2212
2298
 
2213
2299
  # track original InputSource associated with each new element:
2214
2300
  if k in source_indices:
2215
- if k not in element_src_indices:
2216
- element_src_indices[k] = []
2217
- if input_data_indices[k][v] != -1:
2218
- src_idx_k = source_indices[k][v]
2301
+ if input_data_indices[k][v3] != -1:
2302
+ src_idx_k = source_indices[k][v3]
2219
2303
  else:
2220
2304
  src_idx_k = -1
2221
- element_src_indices[k].append(src_idx_k)
2305
+ element_src_indices.setdefault(k, []).append(src_idx_k)
2222
2306
 
2223
2307
  return new_elements, element_sequence_indices, element_src_indices
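
A toy illustration of the index mapping performed by `generate_new_elements` above; all numbers are invented and `-1` stands for an input value that is absent at this point:

input_data_indices = {"inputs.p1": [11, 12], "inputs.p2": [21, 22, -1]}
output_data_indices = {"outputs.q1": [31, 32]}
element_data_indices = [{"inputs.p1": 0, "inputs.p2": 2}, {"inputs.p1": 1, "inputs.p2": 0}]

new_elements = []
for i_idx, data_idx in enumerate(element_data_indices):
    # map per-element positions to persistent data IDs, skipping -1 entries:
    elem = {k: input_data_indices[k][v] for k, v in data_idx.items()
            if input_data_indices[k][v] != -1}
    # outputs are indexed directly by the element's position:
    elem.update((k, v[i_idx]) for k, v in output_data_indices.items())
    new_elements.append(elem)

assert new_elements == [
    {"inputs.p1": 11, "outputs.q1": 31},
    {"inputs.p1": 12, "inputs.p2": 21, "outputs.q1": 32},
]
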
2224
2308
 
2225
2309
  @property
2226
- def upstream_tasks(self):
2310
+ def upstream_tasks(self) -> Iterator[WorkflowTask]:
2227
2311
  """All workflow tasks that are upstream from this task."""
2228
- return [task for task in self.workflow.tasks[: self.index]]
2312
+ tasks = self.workflow.tasks
2313
+ for idx in range(0, self.index):
2314
+ yield tasks[idx]
2229
2315
 
2230
2316
  @property
2231
- def downstream_tasks(self):
2317
+ def downstream_tasks(self) -> Iterator[WorkflowTask]:
2232
2318
  """All workflow tasks that are downstream from this task."""
2233
- return [task for task in self.workflow.tasks[self.index + 1 :]]
2319
+ tasks = self.workflow.tasks
2320
+ for idx in range(self.index + 1, len(tasks)):
2321
+ yield tasks[idx]
2234
2322
 
2235
2323
  @staticmethod
2236
- def resolve_element_data_indices(multiplicities):
2237
- """Find the index of the Zarr parameter group index list corresponding to each
2324
+ def resolve_element_data_indices(
2325
+ multiplicities: list[MultiplicityDescriptor],
2326
+ ) -> Sequence[Mapping[str, int]]:
2327
+ """Find the index of the parameter group index list corresponding to each
2238
2328
  input data for all elements.
2239
2329
 
2240
- # TODO: update docstring; shouldn't reference Zarr.
2241
-
2242
2330
  Parameters
2243
2331
  ----------
2244
- multiplicities : list of dict
2332
+ multiplicities : list of MultiplicityDescriptor
2245
2333
  Each list item represents a sequence of values with keys:
2246
2334
  multiplicity: int
2247
- nesting_order: int
2335
+ nesting_order: float
2248
2336
  path : str
2249
2337
 
2250
2338
  Returns
@@ -2254,37 +2342,40 @@ class WorkflowTask:
2254
2342
  input data paths and whose values are indices that index the values of the
2255
2343
  dict returned by the `task.make_persistent` method.
2256
2344
 
2345
+ Note
2346
+ ----
2347
+ A non-integer nesting order results in a dot product of that sequence with
2348
+ all of the sequences accumulated so far, rather than only with the other
2349
+ sequences at the same nesting order (or a cross product with sequences at other, whole-number, nesting orders).
2257
2350
  """
2258
2351
 
2259
2352
  # order by nesting order (lower nesting orders will be slowest-varying):
2260
2353
  multi_srt = sorted(multiplicities, key=lambda x: x["nesting_order"])
2261
- multi_srt_grp = group_by_dict_key_values(
2262
- multi_srt, "nesting_order"
2263
- ) # TODO: is tested?
2354
+ multi_srt_grp = group_by_dict_key_values(multi_srt, "nesting_order")
2264
2355
 
2265
- element_dat_idx = [{}]
2266
- last_nest_ord = None
2356
+ element_dat_idx: list[dict[str, int]] = [{}]
2357
+ last_nest_ord: int | None = None
2267
2358
  for para_sequences in multi_srt_grp:
2268
2359
  # check all equivalent nesting_orders have equivalent multiplicities
2269
- all_multis = {i["multiplicity"] for i in para_sequences}
2360
+ all_multis = {md["multiplicity"] for md in para_sequences}
2270
2361
  if len(all_multis) > 1:
2271
2362
  raise ValueError(
2272
2363
  f"All inputs with the same `nesting_order` must have the same "
2273
2364
  f"multiplicity, but for paths "
2274
- f"{[i['path'] for i in para_sequences]} with "
2365
+ f"{[md['path'] for md in para_sequences]} with "
2275
2366
  f"`nesting_order` {para_sequences[0]['nesting_order']} found "
2276
- f"multiplicities {[i['multiplicity'] for i in para_sequences]}."
2367
+ f"multiplicities {[md['multiplicity'] for md in para_sequences]}."
2277
2368
  )
2278
2369
 
2279
- cur_nest_ord = para_sequences[0]["nesting_order"]
2280
- new_elements = []
2370
+ cur_nest_ord = int(para_sequences[0]["nesting_order"])
2371
+ new_elements: list[dict[str, int]] = []
2281
2372
  for elem_idx, element in enumerate(element_dat_idx):
2282
- if last_nest_ord is not None and int(cur_nest_ord) == int(last_nest_ord):
2373
+ if last_nest_ord is not None and cur_nest_ord == last_nest_ord:
2283
2374
  # merge in parallel with existing elements:
2284
2375
  new_elements.append(
2285
2376
  {
2286
2377
  **element,
2287
- **{i["path"]: elem_idx for i in para_sequences},
2378
+ **{md["path"]: elem_idx for md in para_sequences},
2288
2379
  }
2289
2380
  )
2290
2381
  else:
@@ -2293,7 +2384,7 @@ class WorkflowTask:
2293
2384
  new_elements.append(
2294
2385
  {
2295
2386
  **element,
2296
- **{i["path"]: val_idx for i in para_sequences},
2387
+ **{md["path"]: val_idx for md in para_sequences},
2297
2388
  }
2298
2389
  )
2299
2390
  element_dat_idx = new_elements
@@ -2302,7 +2393,7 @@ class WorkflowTask:
2302
2393
  return element_dat_idx
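
To make the nesting-order behaviour concrete, here is a simplified, self-contained re-implementation of the zip/cross logic (it ignores the multiplicity-consistency check and the non-integer subtleties noted above; names are illustrative):

from itertools import groupby

def sketch_resolve(multiplicities):
    elements = [{}]
    srt = sorted(multiplicities, key=lambda m: m["nesting_order"])
    for _, grp in groupby(srt, key=lambda m: int(m["nesting_order"])):
        grp = list(grp)
        elements = [
            # sequences in the same group share the index (zip); each new group
            # multiplies the element count (cross product):
            {**elem, **{m["path"]: idx for m in grp}}
            for elem in elements
            for idx in range(grp[0]["multiplicity"])
        ]
    return elements

elems = sketch_resolve([
    {"path": "inputs.a", "multiplicity": 2, "nesting_order": 0},
    {"path": "inputs.b", "multiplicity": 2, "nesting_order": 0},   # zipped with inputs.a
    {"path": "inputs.c", "multiplicity": 3, "nesting_order": 1},   # crossed with the above
])
assert len(elems) == 6
assert elems[0] == {"inputs.a": 0, "inputs.b": 0, "inputs.c": 0}
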
2303
2394
 
2304
2395
  @TimeIt.decorator
2305
- def initialise_EARs(self, iter_IDs: Optional[List[int]] = None) -> List[int]:
2396
+ def initialise_EARs(self, iter_IDs: list[int] | None = None) -> Sequence[int]:
2306
2397
  """Try to initialise any uninitialised EARs of this task."""
2307
2398
  if iter_IDs:
2308
2399
  iters = self.workflow.get_element_iterations_from_IDs(iter_IDs)
@@ -2317,16 +2408,16 @@ class WorkflowTask:
2317
2408
  # objects.
2318
2409
  iters.extend(element.iterations)
2319
2410
 
2320
- initialised = []
2411
+ initialised: list[int] = []
2321
2412
  for iter_i in iters:
2322
2413
  if not iter_i.EARs_initialised:
2323
2414
  try:
2324
- self._initialise_element_iter_EARs(iter_i)
2415
+ self.__initialise_element_iter_EARs(iter_i)
2325
2416
  initialised.append(iter_i.id_)
2326
2417
  except UnsetParameterDataError:
2327
2418
  # raised by `Action.test_rules`; cannot yet initialise EARs
2328
- self.app.logger.debug(
2329
- f"UnsetParameterDataError raised: cannot yet initialise runs."
2419
+ self._app.logger.debug(
2420
+ "UnsetParameterDataError raised: cannot yet initialise runs."
2330
2421
  )
2331
2422
  pass
2332
2423
  else:
@@ -2335,13 +2426,13 @@ class WorkflowTask:
2335
2426
  return initialised
2336
2427
 
2337
2428
  @TimeIt.decorator
2338
- def _initialise_element_iter_EARs(self, element_iter: app.ElementIteration) -> None:
2429
+ def __initialise_element_iter_EARs(self, element_iter: ElementIteration) -> None:
2339
2430
  # keys are (act_idx, EAR_idx):
2340
- all_data_idx = {}
2341
- action_runs = {}
2431
+ all_data_idx: dict[tuple[int, int], DataIndex] = {}
2432
+ action_runs: dict[tuple[int, int], dict[str, Any]] = {}
2342
2433
 
2343
2434
  # keys are parameter indices, values are EAR_IDs to update those sources to
2344
- param_src_updates = {}
2435
+ param_src_updates: dict[int, ParamSource] = {}
2345
2436
 
2346
2437
  count = 0
2347
2438
  for act_idx, action in self.template.all_schema_actions():
@@ -2355,9 +2446,9 @@ class WorkflowTask:
2355
2446
  # it previously failed)
2356
2447
  act_valid, cmds_idx = action.test_rules(element_iter=element_iter)
2357
2448
  if act_valid:
2358
- self.app.logger.info(f"All action rules evaluated to true {log_common}")
2449
+ self._app.logger.info(f"All action rules evaluated to true {log_common}")
2359
2450
  EAR_ID = self.workflow.num_EARs + count
2360
- param_source = {
2451
+ param_source: ParamSource = {
2361
2452
  "type": "EAR_output",
2362
2453
  "EAR_ID": EAR_ID,
2363
2454
  }
@@ -2374,17 +2465,19 @@ class WorkflowTask:
2374
2465
  # with EARs initialised, we can update the pre-allocated schema-level
2375
2466
  # parameters with the correct EAR reference:
2376
2467
  for i in psrc_update:
2377
- param_src_updates[i] = {"EAR_ID": EAR_ID}
2468
+ param_src_updates[cast("int", i)] = {"EAR_ID": EAR_ID}
2378
2469
  run_0 = {
2379
2470
  "elem_iter_ID": element_iter.id_,
2380
2471
  "action_idx": act_idx,
2381
2472
  "commands_idx": cmds_idx,
2382
2473
  "metadata": {},
2383
2474
  }
2384
- action_runs[(act_idx, EAR_ID)] = run_0
2475
+ action_runs[act_idx, EAR_ID] = run_0
2385
2476
  count += 1
2386
2477
  else:
2387
- self.app.logger.info(f"Some action rules evaluated to false {log_common}")
2478
+ self._app.logger.info(
2479
+ f"Some action rules evaluated to false {log_common}"
2480
+ )
2388
2481
 
2389
2482
  # `generate_data_index` can modify data index for previous actions, so only assign
2390
2483
  # this at the end:
@@ -2393,14 +2486,13 @@ class WorkflowTask:
2393
2486
  elem_iter_ID=element_iter.id_,
2394
2487
  action_idx=act_idx,
2395
2488
  commands_idx=run["commands_idx"],
2396
- data_idx=all_data_idx[(act_idx, EAR_ID_i)],
2397
- metadata={},
2489
+ data_idx=all_data_idx[act_idx, EAR_ID_i],
2398
2490
  )
2399
2491
 
2400
2492
  self.workflow._store.update_param_source(param_src_updates)
2401
2493
 
2402
2494
  @TimeIt.decorator
2403
- def _add_element_set(self, element_set):
2495
+ def _add_element_set(self, element_set: ElementSet) -> list[int]:
2404
2496
  """
2405
2497
  Returns
2406
2498
  -------
@@ -2414,7 +2506,7 @@ class WorkflowTask:
2414
2506
  # may modify element_set.input_sources:
2415
2507
  padded_elem_iters = self.ensure_input_sources(element_set)
2416
2508
 
2417
- (input_data_idx, seq_idx, src_idx) = self._make_new_elements_persistent(
2509
+ (input_data_idx, seq_idx, src_idx) = self.__make_new_elements_persistent(
2418
2510
  element_set=element_set,
2419
2511
  element_set_idx=self.num_element_sets,
2420
2512
  padded_elem_iters=padded_elem_iters,
@@ -2448,10 +2540,10 @@ class WorkflowTask:
2448
2540
  src_idx,
2449
2541
  )
2450
2542
 
2451
- iter_IDs = []
2452
- elem_IDs = []
2543
+ iter_IDs: list[int] = []
2544
+ elem_IDs: list[int] = []
2453
2545
  for elem_idx, data_idx in enumerate(element_data_idx):
2454
- schema_params = set(i for i in data_idx.keys() if len(i.split(".")) == 2)
2546
+ schema_params = set(i for i in data_idx if len(i.split(".")) == 2)
2455
2547
  elem_ID_i = self.workflow._store.add_element(
2456
2548
  task_ID=self.insert_ID,
2457
2549
  es_idx=self.num_element_sets - 1,
@@ -2471,21 +2563,72 @@ class WorkflowTask:
2471
2563
 
2472
2564
  return iter_IDs
2473
2565
 
2566
+ @overload
2567
+ def add_elements(
2568
+ self,
2569
+ *,
2570
+ base_element: Element | None = None,
2571
+ inputs: list[InputValue] | dict[str, Any] | None = None,
2572
+ input_files: list[InputFile] | None = None,
2573
+ sequences: list[ValueSequence] | None = None,
2574
+ resources: Resources = None,
2575
+ repeats: list[RepeatsDescriptor] | int | None = None,
2576
+ input_sources: dict[str, list[InputSource]] | None = None,
2577
+ nesting_order: dict[str, float] | None = None,
2578
+ element_sets: list[ElementSet] | None = None,
2579
+ sourceable_elem_iters: list[int] | None = None,
2580
+ propagate_to: (
2581
+ list[ElementPropagation]
2582
+ | Mapping[str, ElementPropagation | Mapping[str, Any]]
2583
+ | None
2584
+ ) = None,
2585
+ return_indices: Literal[True],
2586
+ ) -> list[int]:
2587
+ ...
2588
+
2589
+ @overload
2474
2590
  def add_elements(
2475
2591
  self,
2476
- base_element=None,
2477
- inputs=None,
2478
- input_files=None,
2479
- sequences=None,
2480
- resources=None,
2481
- repeats=None,
2482
- input_sources=None,
2483
- nesting_order=None,
2484
- element_sets=None,
2485
- sourceable_elem_iters=None,
2486
- propagate_to=None,
2592
+ *,
2593
+ base_element: Element | None = None,
2594
+ inputs: list[InputValue] | dict[str, Any] | None = None,
2595
+ input_files: list[InputFile] | None = None,
2596
+ sequences: list[ValueSequence] | None = None,
2597
+ resources: Resources = None,
2598
+ repeats: list[RepeatsDescriptor] | int | None = None,
2599
+ input_sources: dict[str, list[InputSource]] | None = None,
2600
+ nesting_order: dict[str, float] | None = None,
2601
+ element_sets: list[ElementSet] | None = None,
2602
+ sourceable_elem_iters: list[int] | None = None,
2603
+ propagate_to: (
2604
+ list[ElementPropagation]
2605
+ | Mapping[str, ElementPropagation | Mapping[str, Any]]
2606
+ | None
2607
+ ) = None,
2608
+ return_indices: Literal[False] = False,
2609
+ ) -> None:
2610
+ ...
2611
+
2612
+ def add_elements(
2613
+ self,
2614
+ *,
2615
+ base_element: Element | None = None,
2616
+ inputs: list[InputValue] | dict[str, Any] | None = None,
2617
+ input_files: list[InputFile] | None = None,
2618
+ sequences: list[ValueSequence] | None = None,
2619
+ resources: Resources = None,
2620
+ repeats: list[RepeatsDescriptor] | int | None = None,
2621
+ input_sources: dict[str, list[InputSource]] | None = None,
2622
+ nesting_order: dict[str, float] | None = None,
2623
+ element_sets: list[ElementSet] | None = None,
2624
+ sourceable_elem_iters: list[int] | None = None,
2625
+ propagate_to: (
2626
+ list[ElementPropagation]
2627
+ | Mapping[str, ElementPropagation | Mapping[str, Any]]
2628
+ | None
2629
+ ) = None,
2487
2630
  return_indices=False,
2488
- ):
2631
+ ) -> list[int] | None:
2489
2632
  """
2490
2633
  Add elements to this task.
2491
2634
 
@@ -2502,11 +2645,11 @@ class WorkflowTask:
2502
2645
  default.
2503
2646
 
2504
2647
  """
2505
- propagate_to = self.app.ElementPropagation._prepare_propagate_to_dict(
2648
+ real_propagate_to = self._app.ElementPropagation._prepare_propagate_to_dict(
2506
2649
  propagate_to, self.workflow
2507
2650
  )
2508
2651
  with self.workflow.batch_update():
2509
- return self._add_elements(
2652
+ indices = self._add_elements(
2510
2653
  base_element=base_element,
2511
2654
  inputs=inputs,
2512
2655
  input_files=input_files,
@@ -2517,27 +2660,37 @@ class WorkflowTask:
2517
2660
  nesting_order=nesting_order,
2518
2661
  element_sets=element_sets,
2519
2662
  sourceable_elem_iters=sourceable_elem_iters,
2520
- propagate_to=propagate_to,
2521
- return_indices=return_indices,
2663
+ propagate_to=real_propagate_to,
2522
2664
  )
2665
+ return indices if return_indices else None
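
The three-way `@overload` pattern used for `add_elements` above can be reduced to a minimal stand-alone sketch; the function below is illustrative only:

from __future__ import annotations
from typing import Literal, overload

@overload
def add_items(n: int, *, return_indices: Literal[True]) -> list[int]: ...
@overload
def add_items(n: int, *, return_indices: Literal[False] = False) -> None: ...
def add_items(n: int, *, return_indices: bool = False) -> list[int] | None:
    indices = list(range(n))        # do the work unconditionally
    return indices if return_indices else None

assert add_items(3, return_indices=True) == [0, 1, 2]   # checkers infer list[int]
assert add_items(3) is None                             # checkers infer None
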
2523
2666
 
2524
2667
  @TimeIt.decorator
2525
2668
  def _add_elements(
2526
2669
  self,
2527
- base_element=None,
2528
- inputs=None,
2529
- input_files=None,
2530
- sequences=None,
2531
- resources=None,
2532
- repeats=None,
2533
- input_sources=None,
2534
- nesting_order=None,
2535
- element_sets=None,
2536
- sourceable_elem_iters=None,
2537
- propagate_to: Dict[str, app.ElementPropagation] = None,
2538
- return_indices: bool = False,
2539
- ):
2540
- """Add more elements to this task."""
2670
+ *,
2671
+ base_element: Element | None = None,
2672
+ inputs: list[InputValue] | dict[str, Any] | None = None,
2673
+ input_files: list[InputFile] | None = None,
2674
+ sequences: list[ValueSequence] | None = None,
2675
+ resources: Resources = None,
2676
+ repeats: list[RepeatsDescriptor] | int | None = None,
2677
+ input_sources: dict[str, list[InputSource]] | None = None,
2678
+ nesting_order: dict[str, float] | None = None,
2679
+ element_sets: list[ElementSet] | None = None,
2680
+ sourceable_elem_iters: list[int] | None = None,
2681
+ propagate_to: dict[str, ElementPropagation],
2682
+ ) -> list[int] | None:
2683
+ """Add more elements to this task.
2684
+
2685
+ Parameters
2686
+ ----------
2687
+ sourceable_elem_iters : list[int]
2688
+ If specified, a list of global element iteration indices from which inputs
2689
+ may be sourced. If not specified, all workflow element iterations are
2690
+ considered sourceable.
2691
+ propagate_to : dict[str, ElementPropagation]
2692
+ Propagate the new elements downstream to the specified tasks.
2693
+ """
2541
2694
 
2542
2695
  if base_element is not None:
2543
2696
  if base_element.task is not self:
@@ -2546,7 +2699,7 @@ class WorkflowTask:
2546
2699
  inputs = inputs or b_inputs
2547
2700
  resources = resources or b_resources
2548
2701
 
2549
- element_sets = self.app.ElementSet.ensure_element_sets(
2702
+ element_sets = self._app.ElementSet.ensure_element_sets(
2550
2703
  inputs=inputs,
2551
2704
  input_files=input_files,
2552
2705
  sequences=sequences,
@@ -2558,24 +2711,22 @@ class WorkflowTask:
2558
2711
  sourceable_elem_iters=sourceable_elem_iters,
2559
2712
  )
2560
2713
 
2561
- elem_idx = []
2714
+ elem_idx: list[int] = []
2562
2715
  for elem_set_i in element_sets:
2563
- # copy:
2564
- elem_set_i = elem_set_i.prepare_persistent_copy()
2716
+ # copy and add the new element set:
2717
+ elem_idx.extend(self._add_element_set(elem_set_i.prepare_persistent_copy()))
2565
2718
 
2566
- # add the new element set:
2567
- elem_idx += self._add_element_set(elem_set_i)
2719
+ if not propagate_to:
2720
+ return elem_idx
2568
2721
 
2569
2722
  for task in self.get_dependent_tasks(as_objects=True):
2570
- elem_prop = propagate_to.get(task.unique_name)
2571
- if elem_prop is None:
2723
+ if (elem_prop := propagate_to.get(task.unique_name)) is None:
2572
2724
  continue
2573
2725
 
2574
- task_dep_names = [
2575
- i.unique_name
2576
- for i in elem_prop.element_set.get_task_dependencies(as_objects=True)
2577
- ]
2578
- if self.unique_name not in task_dep_names:
2726
+ if all(
2727
+ self.unique_name != task.unique_name
2728
+ for task in elem_prop.element_set.get_task_dependencies(as_objects=True)
2729
+ ):
2579
2730
  # TODO: why can't we just do
2580
2731
  # `if self not in elem_prop.element_set.task_dependencies:`?
2581
2732
  continue
@@ -2584,11 +2735,11 @@ class WorkflowTask:
2584
2735
  # Assume for now we use a single base element set.
2585
2736
  # Later, allow combining multiple element sets.
2586
2737
  src_elem_iters = elem_idx + [
2587
- j for i in element_sets for j in i.sourceable_elem_iters or []
2738
+ j for el_set in element_sets for j in el_set.sourceable_elem_iters or ()
2588
2739
  ]
2589
2740
 
2590
2741
  # note we must pass `resources` as a list since it is already persistent:
2591
- elem_set_i = self.app.ElementSet(
2742
+ elem_set_i = self._app.ElementSet(
2592
2743
  inputs=elem_prop.element_set.inputs,
2593
2744
  input_files=elem_prop.element_set.input_files,
2594
2745
  sequences=elem_prop.element_set.sequences,
@@ -2602,40 +2753,57 @@ class WorkflowTask:
2602
2753
  del propagate_to[task.unique_name]
2603
2754
  prop_elem_idx = task._add_elements(
2604
2755
  element_sets=[elem_set_i],
2605
- return_indices=True,
2606
2756
  propagate_to=propagate_to,
2607
2757
  )
2608
- elem_idx.extend(prop_elem_idx)
2758
+ elem_idx.extend(prop_elem_idx or ())
2609
2759
 
2610
- if return_indices:
2611
- return elem_idx
2760
+ return elem_idx
2761
+
2762
+ @overload
2763
+ def get_element_dependencies(
2764
+ self,
2765
+ as_objects: Literal[False] = False,
2766
+ ) -> set[int]:
2767
+ ...
2768
+
2769
+ @overload
2770
+ def get_element_dependencies(
2771
+ self,
2772
+ as_objects: Literal[True],
2773
+ ) -> list[Element]:
2774
+ ...
2612
2775
 
2613
2776
  def get_element_dependencies(
2614
2777
  self,
2615
2778
  as_objects: bool = False,
2616
- ) -> List[Union[int, app.Element]]:
2779
+ ) -> set[int] | list[Element]:
2617
2780
  """Get elements from upstream tasks that this task depends on."""
2618
2781
 
2619
- deps = []
2620
- for element in self.elements[:]:
2782
+ deps: set[int] = set()
2783
+ for element in self.elements:
2621
2784
  for iter_i in element.iterations:
2622
- for dep_elem_i in iter_i.get_element_dependencies(as_objects=True):
2623
- if (
2624
- dep_elem_i.task.insert_ID != self.insert_ID
2625
- and dep_elem_i not in deps
2626
- ):
2627
- deps.append(dep_elem_i.id_)
2785
+ deps.update(
2786
+ dep_elem_i.id_
2787
+ for dep_elem_i in iter_i.get_element_dependencies(as_objects=True)
2788
+ if dep_elem_i.task.insert_ID != self.insert_ID
2789
+ )
2628
2790
 
2629
- deps = sorted(deps)
2630
2791
  if as_objects:
2631
- deps = self.workflow.get_elements_from_IDs(deps)
2632
-
2792
+ return self.workflow.get_elements_from_IDs(sorted(deps))
2633
2793
  return deps
2634
2794
 
2795
+ @overload
2796
+ def get_task_dependencies(self, as_objects: Literal[False] = False) -> set[int]:
2797
+ ...
2798
+
2799
+ @overload
2800
+ def get_task_dependencies(self, as_objects: Literal[True]) -> list[WorkflowTask]:
2801
+ ...
2802
+
2635
2803
  def get_task_dependencies(
2636
2804
  self,
2637
2805
  as_objects: bool = False,
2638
- ) -> List[Union[int, app.WorkflowTask]]:
2806
+ ) -> set[int] | list[WorkflowTask]:
2639
2807
  """Get tasks (insert ID or WorkflowTask objects) that this task depends on.
2640
2808
 
2641
2809
  Dependencies may come from either elements from upstream tasks, or from locally
@@ -2645,87 +2813,106 @@ class WorkflowTask:
2645
2813
  # new "task_iteration" input source type, which may take precedence over any
2646
2814
  # other input source types.
2647
2815
 
2648
- deps = []
2816
+ deps: set[int] = set()
2649
2817
  for element_set in self.template.element_sets:
2650
2818
  for sources in element_set.input_sources.values():
2651
- for src in sources:
2819
+ deps.update(
2820
+ src.task_ref
2821
+ for src in sources
2652
2822
  if (
2653
2823
  src.source_type is InputSourceType.TASK
2654
- and src.task_ref not in deps
2655
- ):
2656
- deps.append(src.task_ref)
2824
+ and src.task_ref is not None
2825
+ )
2826
+ )
2657
2827
 
2658
- deps = sorted(deps)
2659
2828
  if as_objects:
2660
- deps = [self.workflow.tasks.get(insert_ID=i) for i in deps]
2661
-
2829
+ return [self.workflow.tasks.get(insert_ID=id_) for id_ in sorted(deps)]
2662
2830
  return deps
2663
2831
 
2832
+ @overload
2833
+ def get_dependent_elements(
2834
+ self,
2835
+ as_objects: Literal[False] = False,
2836
+ ) -> set[int]:
2837
+ ...
2838
+
2839
+ @overload
2840
+ def get_dependent_elements(self, as_objects: Literal[True]) -> list[Element]:
2841
+ ...
2842
+
2664
2843
  def get_dependent_elements(
2665
2844
  self,
2666
2845
  as_objects: bool = False,
2667
- ) -> List[Union[int, app.Element]]:
2846
+ ) -> set[int] | list[Element]:
2668
2847
  """Get elements from downstream tasks that depend on this task."""
2669
- deps = []
2848
+ deps: set[int] = set()
2670
2849
  for task in self.downstream_tasks:
2671
- for element in task.elements[:]:
2672
- for iter_i in element.iterations:
2673
- for dep_i in iter_i.get_task_dependencies(as_objects=False):
2674
- if dep_i == self.insert_ID and element.id_ not in deps:
2675
- deps.append(element.id_)
2850
+ deps.update(
2851
+ element.id_
2852
+ for element in task.elements
2853
+ if any(
2854
+ self.insert_ID in iter_i.get_task_dependencies()
2855
+ for iter_i in element.iterations
2856
+ )
2857
+ )
2676
2858
 
2677
- deps = sorted(deps)
2678
2859
  if as_objects:
2679
- deps = self.workflow.get_elements_from_IDs(deps)
2680
-
2860
+ return self.workflow.get_elements_from_IDs(sorted(deps))
2681
2861
  return deps
2682
2862
 
2863
+ @overload
2864
+ def get_dependent_tasks(self, as_objects: Literal[False] = False) -> set[int]:
2865
+ ...
2866
+
2867
+ @overload
2868
+ def get_dependent_tasks(self, as_objects: Literal[True]) -> list[WorkflowTask]:
2869
+ ...
2870
+
2683
2871
  @TimeIt.decorator
2684
2872
  def get_dependent_tasks(
2685
2873
  self,
2686
2874
  as_objects: bool = False,
2687
- ) -> List[Union[int, app.WorkflowTask]]:
2875
+ ) -> set[int] | list[WorkflowTask]:
2688
2876
  """Get tasks (insert ID or WorkflowTask objects) that depends on this task."""
2689
2877
 
2690
2878
  # TODO: this method might become insufficient if/when we start considering a
2691
2879
  # new "task_iteration" input source type, which may take precedence over any
2692
2880
  # other input source types.
2693
2881
 
2694
- deps = []
2882
+ deps: set[int] = set()
2695
2883
  for task in self.downstream_tasks:
2696
- for element_set in task.template.element_sets:
2697
- for sources in element_set.input_sources.values():
2698
- for src in sources:
2699
- if (
2700
- src.source_type is InputSourceType.TASK
2701
- and src.task_ref == self.insert_ID
2702
- and task.insert_ID not in deps
2703
- ):
2704
- deps.append(task.insert_ID)
2705
- deps = sorted(deps)
2884
+ if task.insert_ID not in deps and any(
2885
+ src.source_type is InputSourceType.TASK and src.task_ref == self.insert_ID
2886
+ for element_set in task.template.element_sets
2887
+ for sources in element_set.input_sources.values()
2888
+ for src in sources
2889
+ ):
2890
+ deps.add(task.insert_ID)
2706
2891
  if as_objects:
2707
- deps = [self.workflow.tasks.get(insert_ID=i) for i in deps]
2892
+ return [self.workflow.tasks.get(insert_ID=id_) for id_ in sorted(deps)]
2708
2893
  return deps
2709
2894
 
2710
2895
  @property
2711
- def inputs(self):
2896
+ def inputs(self) -> TaskInputParameters:
2712
2897
  """
2713
2898
  Inputs to this task.
2714
2899
  """
2715
- return self.app.TaskInputParameters(self)
2900
+ return self._app.TaskInputParameters(self)
2716
2901
 
2717
2902
  @property
2718
- def outputs(self):
2903
+ def outputs(self) -> TaskOutputParameters:
2719
2904
  """
2720
2905
  Outputs from this task.
2721
2906
  """
2722
- return self.app.TaskOutputParameters(self)
2907
+ return self._app.TaskOutputParameters(self)
2723
2908
 
2724
- def get(self, path, raise_on_missing=False, default=None):
2909
+ def get(
2910
+ self, path: str, *, raise_on_missing=False, default: Any | None = None
2911
+ ) -> Parameters:
2725
2912
  """
2726
2913
  Get a parameter known to this task by its path.
2727
2914
  """
2728
- return self.app.Parameters(
2915
+ return self._app.Parameters(
2729
2916
  self,
2730
2917
  path=path,
2731
2918
  return_element_parameters=False,
@@ -2733,12 +2920,15 @@ class WorkflowTask:
2733
2920
  default=default,
2734
2921
  )
2735
2922
 
2736
- def _paths_to_PV_classes(self, paths: Iterable[str]) -> Dict:
2923
+ def _paths_to_PV_classes(self, *paths: str | None) -> dict[str, type[ParameterValue]]:
2737
2924
  """Return a dict mapping dot-delimited string input paths to `ParameterValue`
2738
2925
  classes."""
2739
2926
 
2740
- params = {}
2927
+ params: dict[str, type[ParameterValue]] = {}
2741
2928
  for path in paths:
2929
+ if not path:
2930
+ # Skip None/empty
2931
+ continue
2742
2932
  path_split = path.split(".")
2743
2933
  if len(path_split) == 1 or path_split[0] not in ("inputs", "outputs"):
2744
2934
  continue
@@ -2751,22 +2941,22 @@ class WorkflowTask:
2751
2941
  path_1, _ = split_param_label(
2752
2942
  path_split[1]
2753
2943
  ) # remove label if present
2754
- for i in self.template.schemas:
2755
- for j in i.inputs:
2756
- if j.parameter.typ == path_1 and j.parameter._value_class:
2757
- params[key_0] = j.parameter._value_class
2944
+ for schema in self.template.schemas:
2945
+ for inp in schema.inputs:
2946
+ if inp.parameter.typ == path_1 and inp.parameter._value_class:
2947
+ params[key_0] = inp.parameter._value_class
2758
2948
 
2759
2949
  elif path_split[0] == "outputs":
2760
- for i in self.template.schemas:
2761
- for j in i.outputs:
2950
+ for schema in self.template.schemas:
2951
+ for out in schema.outputs:
2762
2952
  if (
2763
- j.parameter.typ == path_split[1]
2764
- and j.parameter._value_class
2953
+ out.parameter.typ == path_split[1]
2954
+ and out.parameter._value_class
2765
2955
  ):
2766
- params[key_0] = j.parameter._value_class
2956
+ params[key_0] = out.parameter._value_class
2767
2957
 
2768
2958
  if path_split[2:]:
2769
- pv_classes = ParameterValue.__subclasses__()
2959
+ pv_classes = {cls._typ: cls for cls in ParameterValue.__subclasses__()}
2770
2960
 
2771
2961
  # now proceed by searching for sub-parameters in each ParameterValue
2772
2962
  # sub-class:
@@ -2776,306 +2966,370 @@ class WorkflowTask:
2776
2966
  key_i = ".".join(child)
2777
2967
  if key_i in params:
2778
2968
  continue
2779
- parent_param = params.get(".".join(parent))
2780
- if parent_param:
2969
+ if parent_param := params.get(".".join(parent)):
2781
2970
  for attr_name, sub_type in parent_param._sub_parameters.items():
2782
2971
  if part_i == attr_name:
2783
2972
  # find the class with this `typ` attribute:
2784
- for cls_i in pv_classes:
2785
- if cls_i._typ == sub_type:
2786
- params[key_i] = cls_i
2787
- break
2973
+ if cls := pv_classes.get(sub_type):
2974
+ params[key_i] = cls
2788
2975
 
2789
2976
  return params
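
The `_typ`-keyed subclass lookup introduced above replaces a linear scan over `ParameterValue.__subclasses__()`; a minimal illustration with made-up classes:

class ParameterValue:
    _typ = ""

class Orientations(ParameterValue):      # invented example subclasses
    _typ = "orientations"

class Microstructure(ParameterValue):
    _typ = "microstructure"
    _sub_parameters = {"orientations": "orientations"}   # attribute name -> child _typ

pv_classes = {cls._typ: cls for cls in ParameterValue.__subclasses__()}
# resolving a sub-parameter path such as "inputs.microstructure.orientations":
sub_typ = Microstructure._sub_parameters["orientations"]
assert pv_classes.get(sub_typ) is Orientations
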
2790
2977
 
2791
- @TimeIt.decorator
2792
- def _get_merged_parameter_data(
2793
- self,
2794
- data_index,
2795
- path=None,
2796
- raise_on_missing=False,
2797
- raise_on_unset=False,
2798
- default: Any = None,
2799
- ):
2800
- """Get element data from the persistent store."""
2801
-
2802
- def _get_relevant_paths(
2803
- data_index: Dict, path: List[str], children_of: str = None
2804
- ):
2805
- relevant_paths = {}
2806
- # first extract out relevant paths in `data_index`:
2807
- for path_i in data_index:
2808
- path_i_split = path_i.split(".")
2978
+ @staticmethod
2979
+ def _get_relevant_paths(
2980
+ data_index: Mapping[str, Any], path: list[str], children_of: str | None = None
2981
+ ) -> Mapping[str, RelevantPath]:
2982
+ relevant_paths: dict[str, RelevantPath] = {}
2983
+ # first extract out relevant paths in `data_index`:
2984
+ for path_i in data_index:
2985
+ path_i_split = path_i.split(".")
2986
+ try:
2987
+ rel_path = get_relative_path(path, path_i_split)
2988
+ relevant_paths[path_i] = {"type": "parent", "relative_path": rel_path}
2989
+ except ValueError:
2809
2990
  try:
2810
- rel_path = get_relative_path(path, path_i_split)
2811
- relevant_paths[path_i] = {"type": "parent", "relative_path": rel_path}
2991
+ update_path = get_relative_path(path_i_split, path)
2992
+ relevant_paths[path_i] = {
2993
+ "type": "update",
2994
+ "update_path": update_path,
2995
+ }
2812
2996
  except ValueError:
2813
- try:
2814
- update_path = get_relative_path(path_i_split, path)
2815
- relevant_paths[path_i] = {
2816
- "type": "update",
2817
- "update_path": update_path,
2818
- }
2819
- except ValueError:
2820
- # no intersection between paths
2821
- if children_of and path_i.startswith(children_of):
2822
- relevant_paths[path_i] = {"type": "sibling"}
2823
- continue
2997
+ # no intersection between paths
2998
+ if children_of and path_i.startswith(children_of):
2999
+ relevant_paths[path_i] = {"type": "sibling"}
3000
+ continue
2824
3001
 
2825
- return relevant_paths
2826
-
2827
- def _get_relevant_data(relevant_data_idx: Dict, raise_on_unset: bool, path: str):
2828
- relevant_data = {}
2829
- for path_i, data_idx_i in relevant_data_idx.items():
2830
- is_multi = isinstance(data_idx_i, list)
2831
- if not is_multi:
2832
- data_idx_i = [data_idx_i]
2833
-
2834
- data_i = []
2835
- methods_i = []
2836
- is_param_set_i = []
2837
- for data_idx_ij in data_idx_i:
2838
- meth_i = None
2839
- is_set_i = True
2840
- if path_i[0] == "repeats":
2841
- # data is an integer repeats index, rather than a parameter ID:
2842
- data_j = data_idx_ij
2843
- else:
2844
- param_j = self.workflow.get_parameter(data_idx_ij)
2845
- is_set_i = param_j.is_set
2846
- if param_j.file:
2847
- if param_j.file["store_contents"]:
2848
- data_j = Path(self.workflow.path) / param_j.file["path"]
2849
- else:
2850
- data_j = Path(param_j.file["path"])
2851
- data_j = data_j.as_posix()
2852
- else:
2853
- meth_i = param_j.source.get("value_class_method")
2854
- if param_j.is_pending:
2855
- # if pending, we need to convert `ParameterValue` objects
2856
- # to their dict representation, so they can be merged with
2857
- # other data:
2858
- try:
2859
- data_j = param_j.data.to_dict()
2860
- except AttributeError:
2861
- data_j = param_j.data
2862
- else:
2863
- # if not pending, data will be the result of an encode-
2864
- # decode cycle, and it will not be initialised as an
2865
- # object if the parameter is associated with a
2866
- # `ParameterValue` class.
2867
- data_j = param_j.data
2868
- if raise_on_unset and not is_set_i:
2869
- raise UnsetParameterDataError(
2870
- f"Element data path {path!r} resolves to unset data for "
2871
- f"(at least) data-index path: {path_i!r}."
2872
- )
2873
- if is_multi:
2874
- data_i.append(data_j)
2875
- methods_i.append(meth_i)
2876
- is_param_set_i.append(is_set_i)
2877
- else:
2878
- data_i = data_j
2879
- methods_i = meth_i
2880
- is_param_set_i = is_set_i
3002
+ return relevant_paths
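
To make the parent/update/sibling classification concrete, here is a hedged re-implementation using plain prefix tests (the real code delegates to the package's `get_relative_path` helper):

def classify(requested: str, data_key: str) -> str:
    req, key = requested.split("."), data_key.split(".")
    if req[: len(key)] == key:
        return "parent"    # the data-index key is an ancestor of (or equals) the request
    if key[: len(req)] == req:
        return "update"    # the data-index key refines the requested path
    return "sibling"       # no overlap; only kept when it is a child of `children_of`

assert classify("inputs.p1.a", "inputs.p1") == "parent"
assert classify("inputs.p1", "inputs.p1.a.b") == "update"
assert classify("inputs.p1.a", "inputs.p2") == "sibling"
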
2881
3003
 
3004
+ def __get_relevant_data_item(
3005
+ self,
3006
+ path: str | None,
3007
+ path_i: str,
3008
+ data_idx_ij: int,
3009
+ raise_on_unset: bool,
3010
+ len_dat_idx: int = 1,
3011
+ ) -> tuple[Any, bool, str | None]:
3012
+ if path_i.startswith("repeats."):
3013
+ # data is an integer repeats index, rather than a parameter ID:
3014
+ return data_idx_ij, True, None
3015
+
3016
+ meth_i: str | None = None
3017
+ data_j: Any
3018
+ param_j = self.workflow.get_parameter(data_idx_ij)
3019
+ is_set_i = param_j.is_set
3020
+ if param_j.file:
3021
+ if param_j.file["store_contents"]:
3022
+ file_j = Path(self.workflow.path) / param_j.file["path"]
3023
+ else:
3024
+ file_j = Path(param_j.file["path"])
3025
+ data_j = file_j.as_posix()
3026
+ else:
3027
+ meth_i = param_j.source.get("value_class_method")
3028
+ if param_j.is_pending:
3029
+ # if pending, we need to convert `ParameterValue` objects
3030
+ # to their dict representation, so they can be merged with
3031
+ # other data:
3032
+ try:
3033
+ data_j = cast("ParameterValue", param_j.data).to_dict()
3034
+ except AttributeError:
3035
+ data_j = param_j.data
3036
+ else:
3037
+ # if not pending, data will be the result of an encode-
3038
+ # decode cycle, and it will not be initialised as an
3039
+ # object if the parameter is associated with a
3040
+ # `ParameterValue` class.
3041
+ data_j = param_j.data
3042
+ if raise_on_unset and not is_set_i:
3043
+ raise UnsetParameterDataError(path, path_i)
3044
+ if not is_set_i and self.workflow._is_tracking_unset:
3045
+ src_run_id = param_j.source.get("EAR_ID")
3046
+ unset_trackers = self.workflow._tracked_unset
3047
+ assert src_run_id is not None
3048
+ assert unset_trackers is not None
3049
+ unset_trackers[path_i].run_ids.add(src_run_id)
3050
+ unset_trackers[path_i].group_size = len_dat_idx
3051
+ return data_j, is_set_i, meth_i
3052
+
3053
+ def __get_relevant_data(
3054
+ self,
3055
+ relevant_data_idx: Mapping[str, list[int] | int],
3056
+ raise_on_unset: bool,
3057
+ path: str | None,
3058
+ ) -> Mapping[str, RelevantData]:
3059
+ relevant_data: dict[str, RelevantData] = {}
3060
+ for path_i, data_idx_i in relevant_data_idx.items():
3061
+ if not isinstance(data_idx_i, list):
3062
+ data, is_set, meth = self.__get_relevant_data_item(
3063
+ path, path_i, data_idx_i, raise_on_unset
3064
+ )
2882
3065
  relevant_data[path_i] = {
2883
- "data": data_i,
2884
- "value_class_method": methods_i,
2885
- "is_set": is_param_set_i,
2886
- "is_multi": is_multi,
2887
- }
2888
- if not raise_on_unset:
2889
- to_remove = []
2890
- for key, dat_info in relevant_data.items():
2891
- if not dat_info["is_set"] and ((path and path in key) or not path):
2892
- # remove sub-paths, as they cannot be merged with this parent
2893
- to_remove.extend(
2894
- k for k in relevant_data if k != key and k.startswith(key)
2895
- )
2896
- relevant_data = {
2897
- k: v for k, v in relevant_data.items() if k not in to_remove
3066
+ "data": data,
3067
+ "value_class_method": meth,
3068
+ "is_set": is_set,
3069
+ "is_multi": False,
2898
3070
  }
3071
+ continue
2899
3072
 
2900
- return relevant_data
3073
+ data_i: list[Any] = []
3074
+ methods_i: list[str | None] = []
3075
+ is_param_set_i: list[bool] = []
3076
+ for data_idx_ij in data_idx_i:
3077
+ data_j, is_set_i, meth_i = self.__get_relevant_data_item(
3078
+ path, path_i, data_idx_ij, raise_on_unset, len_dat_idx=len(data_idx_i)
3079
+ )
3080
+ data_i.append(data_j)
3081
+ methods_i.append(meth_i)
3082
+ is_param_set_i.append(is_set_i)
3083
+
3084
+ relevant_data[path_i] = {
3085
+ "data": data_i,
3086
+ "value_class_method": methods_i,
3087
+ "is_set": is_param_set_i,
3088
+ "is_multi": True,
3089
+ }
2901
3090
 
2902
- def _merge_relevant_data(
2903
- relevant_data, relevant_paths, PV_classes, path: str, raise_on_missing: bool
2904
- ):
2905
- current_val = None
2906
- assigned_from_parent = False
2907
- val_cls_method = None
2908
- path_is_multi = False
2909
- path_is_set = False
2910
- all_multi_len = None
2911
- for path_i, data_info_i in relevant_data.items():
2912
- data_i = data_info_i["data"]
2913
- if path_i == path:
2914
- val_cls_method = data_info_i["value_class_method"]
2915
- path_is_multi = data_info_i["is_multi"]
2916
- path_is_set = data_info_i["is_set"]
2917
-
2918
- if data_info_i["is_multi"]:
2919
- if all_multi_len:
2920
- if len(data_i) != all_multi_len:
2921
- raise RuntimeError(
2922
- f"Cannot merge group values of different lengths."
2923
- )
2924
- else:
2925
- # keep track of group lengths, only merge equal-length groups;
2926
- all_multi_len = len(data_i)
2927
-
2928
- path_info = relevant_paths[path_i]
2929
- path_type = path_info["type"]
2930
- if path_type == "parent":
2931
- try:
2932
- if data_info_i["is_multi"]:
2933
- current_val = [
2934
- get_in_container(
2935
- i,
2936
- path_info["relative_path"],
2937
- cast_indices=True,
2938
- )
2939
- for i in data_i
2940
- ]
2941
- path_is_multi = True
2942
- path_is_set = data_info_i["is_set"]
2943
- val_cls_method = data_info_i["value_class_method"]
2944
- else:
2945
- current_val = get_in_container(
2946
- data_i,
3091
+ if not raise_on_unset:
3092
+ to_remove: set[str] = set()
3093
+ for key, dat_info in relevant_data.items():
3094
+ if not dat_info["is_set"] and (not path or path in key):
3095
+ # remove sub-paths, as they cannot be merged with this parent
3096
+ prefix = f"{key}."
3097
+ to_remove.update(k for k in relevant_data if k.startswith(prefix))
3098
+ for key in to_remove:
3099
+ relevant_data.pop(key, None)
3100
+
3101
+ return relevant_data
3102
+
3103
+ @classmethod
3104
+ def __merge_relevant_data(
3105
+ cls,
3106
+ relevant_data: Mapping[str, RelevantData],
3107
+ relevant_paths: Mapping[str, RelevantPath],
3108
+ PV_classes,
3109
+ path: str | None,
3110
+ raise_on_missing: bool,
3111
+ ):
3112
+ current_val: list | dict | Any | None = None
3113
+ assigned_from_parent = False
3114
+ val_cls_method: str | None | list[str | None] = None
3115
+ path_is_multi = False
3116
+ path_is_set: bool | list[bool] = False
3117
+ all_multi_len: int | None = None
3118
+ for path_i, data_info_i in relevant_data.items():
3119
+ data_i = data_info_i["data"]
3120
+ if path_i == path:
3121
+ val_cls_method = data_info_i["value_class_method"]
3122
+ path_is_multi = data_info_i["is_multi"]
3123
+ path_is_set = data_info_i["is_set"]
3124
+
3125
+ if data_info_i["is_multi"]:
3126
+ if all_multi_len:
3127
+ if len(data_i) != all_multi_len:
3128
+ raise RuntimeError(
3129
+ "Cannot merge group values of different lengths."
3130
+ )
3131
+ else:
3132
+ # keep track of group lengths, only merge equal-length groups;
3133
+ all_multi_len = len(data_i)
3134
+
3135
+ path_info = relevant_paths[path_i]
3136
+ if path_info["type"] == "parent":
3137
+ try:
3138
+ if data_info_i["is_multi"]:
3139
+ current_val = [
3140
+ get_in_container(
3141
+ item,
2947
3142
  path_info["relative_path"],
2948
3143
  cast_indices=True,
2949
3144
  )
2950
- except ContainerKeyError as err:
2951
- if path_i in PV_classes:
2952
- err_path = ".".join([path_i] + err.path[:-1])
2953
- raise MayNeedObjectError(path=err_path)
2954
- continue
2955
- except (IndexError, ValueError) as err:
2956
- if raise_on_missing:
2957
- raise err
2958
- continue
2959
- else:
2960
- assigned_from_parent = True
2961
- elif path_type == "update":
2962
- current_val = current_val or {}
2963
- if all_multi_len:
2964
- if len(path_i.split(".")) == 2:
2965
- # groups can only be "created" at the parameter level
2966
- set_in_container(
2967
- cont=current_val,
2968
- path=path_info["update_path"],
2969
- value=data_i,
2970
- ensure_path=True,
2971
- cast_indices=True,
2972
- )
2973
- else:
2974
- # update group
2975
- update_path = path_info["update_path"]
2976
- if len(update_path) > 1:
2977
- for idx, j in enumerate(data_i):
2978
- path_ij = update_path[0:1] + [idx] + update_path[1:]
2979
- set_in_container(
2980
- cont=current_val,
2981
- path=path_ij,
2982
- value=j,
2983
- ensure_path=True,
2984
- cast_indices=True,
2985
- )
2986
- else:
2987
- for idx, (i, j) in enumerate(zip(current_val, data_i)):
2988
- set_in_container(
2989
- cont=i,
2990
- path=update_path,
2991
- value=j,
2992
- ensure_path=True,
2993
- cast_indices=True,
2994
- )
2995
-
3145
+ for item in data_i
3146
+ ]
3147
+ path_is_multi = True
3148
+ path_is_set = data_info_i["is_set"]
3149
+ val_cls_method = data_info_i["value_class_method"]
2996
3150
  else:
2997
- set_in_container(
2998
- current_val,
2999
- path_info["update_path"],
3151
+ current_val = get_in_container(
3000
3152
  data_i,
3153
+ path_info["relative_path"],
3154
+ cast_indices=True,
3155
+ )
3156
+ except ContainerKeyError as err:
3157
+ if path_i in PV_classes:
3158
+ raise MayNeedObjectError(path=".".join([path_i, *err.path[:-1]]))
3159
+ continue
3160
+ except (IndexError, ValueError) as err:
3161
+ if raise_on_missing:
3162
+ raise err
3163
+ continue
3164
+ else:
3165
+ assigned_from_parent = True
3166
+ elif path_info["type"] == "update":
3167
+ current_val = current_val or {}
3168
+ if all_multi_len:
3169
+ if len(path_i.split(".")) == 2:
3170
+ # groups can only be "created" at the parameter level
3171
+ set_in_container(
3172
+ cont=current_val,
3173
+ path=path_info["update_path"],
3174
+ value=data_i,
3001
3175
  ensure_path=True,
3002
3176
  cast_indices=True,
3003
3177
  )
3004
- if path in PV_classes:
3005
- if path not in relevant_data:
3006
- # requested data must be a sub-path of relevant data, so we can assume
3007
- # path is set (if the parent was not set the sub-paths would be
3008
- # removed in `_get_relevant_data`):
3009
- path_is_set = path_is_set or True
3010
-
3011
- if not assigned_from_parent:
3012
- # search for unset parents in `relevant_data`:
3013
- path_split = path.split(".")
3014
- for parent_i_span in range(len(path_split) - 1, 1, -1):
3015
- parent_path_i = ".".join(path_split[0:parent_i_span])
3016
- relevant_par = relevant_data.get(parent_path_i)
3017
- par_is_set = relevant_par["is_set"]
3018
- if not par_is_set or any(not i for i in par_is_set):
3019
- val_cls_method = relevant_par["value_class_method"]
3020
- path_is_multi = relevant_par["is_multi"]
3021
- path_is_set = relevant_par["is_set"]
3022
- current_val = relevant_par["data"]
3023
- break
3024
-
3025
- # initialise objects
3026
- PV_cls = PV_classes[path]
3027
- if path_is_multi:
3028
- _current_val = []
3029
- for set_i, meth_i, val_i in zip(
3030
- path_is_set, val_cls_method, current_val
3031
- ):
3032
- if set_i and isinstance(val_i, dict):
3033
- method_i = getattr(PV_cls, meth_i) if meth_i else PV_cls
3034
- _cur_val_i = method_i(**val_i)
3178
+ else:
3179
+ # update group
3180
+ update_path = path_info["update_path"]
3181
+ if len(update_path) > 1:
3182
+ for idx, j in enumerate(data_i):
3183
+ set_in_container(
3184
+ cont=current_val,
3185
+ path=[*update_path[:1], idx, *update_path[1:]],
3186
+ value=j,
3187
+ ensure_path=True,
3188
+ cast_indices=True,
3189
+ )
3035
3190
  else:
3036
- _cur_val_i = None
3037
- _current_val.append(_cur_val_i)
3038
- current_val = _current_val
3039
- elif path_is_set and isinstance(current_val, dict):
3040
- method = getattr(PV_cls, val_cls_method) if val_cls_method else PV_cls
3041
- current_val = method(**current_val)
3191
+ for i, j in zip(current_val, data_i):
3192
+ set_in_container(
3193
+ cont=i,
3194
+ path=update_path,
3195
+ value=j,
3196
+ ensure_path=True,
3197
+ cast_indices=True,
3198
+ )
3042
3199
 
3043
- return current_val, all_multi_len
3200
+ else:
3201
+ set_in_container(
3202
+ current_val,
3203
+ path_info["update_path"],
3204
+ data_i,
3205
+ ensure_path=True,
3206
+ cast_indices=True,
3207
+ )
3208
+ if path in PV_classes:
3209
+ if path not in relevant_data:
3210
+ # requested data must be a sub-path of relevant data, so we can assume
3211
+ # path is set (if the parent was not set the sub-paths would be
3212
+ # removed in `__get_relevant_data`):
3213
+ path_is_set = path_is_set or True
3214
+
3215
+ if not assigned_from_parent:
3216
+ # search for unset parents in `relevant_data`:
3217
+ assert path is not None
3218
+ for parent_i_span in range(
3219
+ len(path_split := path.split(".")) - 1, 1, -1
3220
+ ):
3221
+ parent_path_i = ".".join(path_split[:parent_i_span])
3222
+ if not (relevant_par := relevant_data.get(parent_path_i)):
3223
+ continue
3224
+ if not (par_is_set := relevant_par["is_set"]) or not all(
3225
+ cast("list", par_is_set)
3226
+ ):
3227
+ val_cls_method = relevant_par["value_class_method"]
3228
+ path_is_multi = relevant_par["is_multi"]
3229
+ path_is_set = relevant_par["is_set"]
3230
+ current_val = relevant_par["data"]
3231
+ break
3044
3232
 
3045
- # TODO: custom exception?
3046
- missing_err = ValueError(f"Path {path!r} does not exist in the element data.")
3233
+ # initialise objects
3234
+ PV_cls = PV_classes[path]
3235
+ if path_is_multi:
3236
+ current_val = [
3237
+ (
3238
+ cls.__map_parameter_value(PV_cls, meth_i, val_i)
3239
+ if set_i and isinstance(val_i, dict)
3240
+ else None
3241
+ )
3242
+ for set_i, meth_i, val_i in zip(
3243
+ cast("list[bool]", path_is_set),
3244
+ cast("list[str|None]", val_cls_method),
3245
+ cast("list[Any]", current_val),
3246
+ )
3247
+ ]
3248
+ elif path_is_set and isinstance(current_val, dict):
3249
+ assert not isinstance(val_cls_method, list)
3250
+ current_val = cls.__map_parameter_value(
3251
+ PV_cls, val_cls_method, current_val
3252
+ )
3047
3253
 
3254
+ return current_val, all_multi_len
3255
+
3256
+ @staticmethod
3257
+ def __map_parameter_value(
3258
+ PV_cls: type[ParameterValue], meth: str | None, val: dict
3259
+ ) -> Any | ParameterValue:
3260
+ if meth:
3261
+ method: Callable = getattr(PV_cls, meth)
3262
+ return method(**val)
3263
+ else:
3264
+ return PV_cls(**val)
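
The `value_class_method` dispatch above simply chooses between the class constructor and a named alternative constructor; a small illustrative sketch (the `Orientation` class is invented):

class Orientation:                       # stand-in for a ParameterValue subclass
    def __init__(self, angles):
        self.angles = angles

    @classmethod
    def from_degrees(cls, degrees):
        return cls([d * 3.141592653589793 / 180.0 for d in degrees])

def map_parameter_value(PV_cls, meth, val):
    ctor = getattr(PV_cls, meth) if meth else PV_cls
    return ctor(**val)

o1 = map_parameter_value(Orientation, None, {"angles": [1.0]})
o2 = map_parameter_value(Orientation, "from_degrees", {"degrees": [180.0]})
assert o1.angles == [1.0] and abs(o2.angles[0] - 3.14159) < 1e-3
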
3265
+
3266
+ @TimeIt.decorator
3267
+ def _get_merged_parameter_data(
3268
+ self,
3269
+ data_index: Mapping[str, list[int] | int],
3270
+ path: str | None = None,
3271
+ *,
3272
+ raise_on_missing: bool = False,
3273
+ raise_on_unset: bool = False,
3274
+ default: Any | None = None,
3275
+ ):
3276
+ """Get element data from the persistent store."""
3048
3277
  path_split = [] if not path else path.split(".")
3049
3278
 
3050
- relevant_paths = _get_relevant_paths(data_index, path_split)
3051
- if not relevant_paths:
3279
+ if not (relevant_paths := self._get_relevant_paths(data_index, path_split)):
3052
3280
  if raise_on_missing:
3053
- raise missing_err
3281
+ # TODO: custom exception?
3282
+ raise ValueError(f"Path {path!r} does not exist in the element data.")
3054
3283
  return default
3055
3284
 
3056
3285
  relevant_data_idx = {k: v for k, v in data_index.items() if k in relevant_paths}
3057
- PV_cls_paths = list(relevant_paths.keys()) + ([path] if path else [])
3058
- PV_classes = self._paths_to_PV_classes(PV_cls_paths)
3059
- relevant_data = _get_relevant_data(relevant_data_idx, raise_on_unset, path)
3286
+
3287
+ cache = self.workflow._merged_parameters_cache
3288
+ use_cache = (
3289
+ self.workflow._use_merged_parameters_cache
3290
+ and raise_on_missing is False
3291
+ and raise_on_unset is False
3292
+ and default is None # cannot cache on default value, may not be hashable
3293
+ )
3294
+ add_to_cache = False
3295
+ if use_cache:
3296
+ # generate the key:
3297
+ dat_idx_cache: list[tuple[str, tuple[int, ...] | int]] = []
3298
+ for k, v in sorted(relevant_data_idx.items()):
3299
+ dat_idx_cache.append((k, tuple(v) if isinstance(v, list) else v))
3300
+ cache_key = (path, tuple(dat_idx_cache))
3301
+
3302
+ # check for cache hit:
3303
+ if cache_key in cache:
3304
+ self._app.logger.debug(
3305
+ f"_get_merged_parameter_data: cache hit with key: {cache_key}"
3306
+ )
3307
+ return cache[cache_key]
3308
+ else:
3309
+ add_to_cache = True
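
The cache key assembled above must be hashable and order-independent, which is why the data index is sorted and its list values converted to tuples; a stand-alone illustration (names invented):

def make_cache_key(path, data_idx):
    return (path, tuple(
        (k, tuple(v) if isinstance(v, list) else v)
        for k, v in sorted(data_idx.items())
    ))

cache = {}
cache[make_cache_key("inputs.p1", {"inputs.p1": [3, 7], "resources.any": 2})] = "merged"
# the same data index supplied in a different insertion order hits the same entry:
assert cache[make_cache_key("inputs.p1", {"resources.any": 2, "inputs.p1": [3, 7]})] == "merged"
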
3310
+
3311
+ PV_classes = self._paths_to_PV_classes(*relevant_paths, path)
3312
+ relevant_data = self.__get_relevant_data(relevant_data_idx, raise_on_unset, path)
3060
3313
 
3061
3314
  current_val = None
3062
3315
  is_assigned = False
3063
3316
  try:
3064
- current_val, _ = _merge_relevant_data(
3317
+ current_val, _ = self.__merge_relevant_data(
3065
3318
  relevant_data, relevant_paths, PV_classes, path, raise_on_missing
3066
3319
  )
3067
3320
  except MayNeedObjectError as err:
3068
3321
  path_to_init = err.path
3069
3322
  path_to_init_split = path_to_init.split(".")
3070
- relevant_paths = _get_relevant_paths(data_index, path_to_init_split)
3071
- PV_cls_paths = list(relevant_paths.keys()) + [path_to_init]
3072
- PV_classes = self._paths_to_PV_classes(PV_cls_paths)
3323
+ relevant_paths = self._get_relevant_paths(data_index, path_to_init_split)
3324
+ PV_classes = self._paths_to_PV_classes(*relevant_paths, path_to_init)
3073
3325
  relevant_data_idx = {
3074
3326
  k: v for k, v in data_index.items() if k in relevant_paths
3075
3327
  }
3076
- relevant_data = _get_relevant_data(relevant_data_idx, raise_on_unset, path)
3328
+ relevant_data = self.__get_relevant_data(
3329
+ relevant_data_idx, raise_on_unset, path
3330
+ )
3077
3331
  # merge the parent data
3078
- current_val, group_len = _merge_relevant_data(
3332
+ current_val, group_len = self.__merge_relevant_data(
3079
3333
  relevant_data, relevant_paths, PV_classes, path_to_init, raise_on_missing
3080
3334
  )
3081
3335
  # try to retrieve attributes via the initialised object:
@@ -3084,12 +3338,12 @@ class WorkflowTask:
3084
3338
  if group_len:
3085
3339
  current_val = [
3086
3340
  get_in_container(
3087
- cont=i,
3341
+ cont=item,
3088
3342
  path=rel_path_split,
3089
3343
  cast_indices=True,
3090
3344
  allow_getattr=True,
3091
3345
  )
3092
- for i in current_val
3346
+ for item in current_val
3093
3347
  ]
3094
3348
  else:
3095
3349
  current_val = get_in_container(
@@ -3110,9 +3364,18 @@ class WorkflowTask:
3110
3364
 
3111
3365
  if not is_assigned:
3112
3366
  if raise_on_missing:
3113
- raise missing_err
3367
+ # TODO: custom exception?
3368
+ raise ValueError(f"Path {path!r} does not exist in the element data.")
3114
3369
  current_val = default
3115
3370
 
3371
+ if add_to_cache:
3372
+ self._app.logger.debug(
3373
+ f"_get_merged_parameter_data: adding to cache with key: {cache_key!r}"
3374
+ )
3375
+ # tuple[str | None, tuple[tuple[str, tuple[int, ...] | int], ...]]
3376
+ # tuple[str | None, tuple[tuple[str, tuple[int, ...] | int], ...]] | None
3377
+ cache[cache_key] = current_val
3378
+
3116
3379
  return current_val
3117
3380
 
3118
3381
 
@@ -3128,7 +3391,7 @@ class Elements:
3128
3391
 
3129
3392
  __slots__ = ("_task",)
3130
3393
 
3131
- def __init__(self, task: app.WorkflowTask):
3394
+ def __init__(self, task: WorkflowTask):
3132
3395
  self._task = task
3133
3396
 
3134
3397
  # TODO: cache Element objects
@@ -3140,44 +3403,57 @@ class Elements:
3140
3403
  )
3141
3404
 
3142
3405
  @property
3143
- def task(self):
3406
+ def task(self) -> WorkflowTask:
3144
3407
  """
3145
3408
  The task this is the elements of.
3146
3409
  """
3147
3410
  return self._task
3148
3411
 
3149
3412
  @TimeIt.decorator
3150
- def _get_selection(self, selection: Union[int, slice, List[int]]) -> List[int]:
3413
+ def __get_selection(self, selection: int | slice | list[int]) -> list[int]:
3151
3414
  """Normalise an element selection into a list of element indices."""
3152
3415
  if isinstance(selection, int):
3153
- lst = [selection]
3416
+ return [selection]
3154
3417
 
3155
3418
  elif isinstance(selection, slice):
3156
- lst = list(range(*selection.indices(self.task.num_elements)))
3419
+ return list(range(*selection.indices(self.task.num_elements)))
3157
3420
 
3158
3421
  elif isinstance(selection, list):
3159
- lst = selection
3422
+ return selection
3160
3423
  else:
3161
3424
  raise RuntimeError(
3162
3425
  f"{self.__class__.__name__} selection must be an `int`, `slice` object, "
3163
3426
  f"or list of `int`s, but received type {type(selection)}."
3164
3427
  )
3165
- return lst
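
The selection normalisation above is a thin wrapper around `slice.indices`; a minimal stand-alone equivalent (illustrative name):

def normalise_selection(selection, num_elements):
    if isinstance(selection, int):
        return [selection]
    if isinstance(selection, slice):
        return list(range(*selection.indices(num_elements)))
    if isinstance(selection, list):
        return selection
    raise TypeError(f"unsupported selection type: {type(selection)}")

assert normalise_selection(slice(None, None, 2), 5) == [0, 2, 4]
assert normalise_selection(slice(-2, None), 5) == [3, 4]
assert normalise_selection(3, 5) == [3]
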
3166
3428
 
3167
- def __len__(self):
3429
+ def __len__(self) -> int:
3168
3430
  return self.task.num_elements
3169
3431
 
3170
- def __iter__(self):
3171
- for i in self.task.workflow.get_task_elements(self.task):
3172
- yield i
3432
+ def __iter__(self) -> Iterator[Element]:
3433
+ yield from self.task.workflow.get_task_elements(self.task)
3434
+
3435
+ @overload
3436
+ def __getitem__(
3437
+ self,
3438
+ selection: int,
3439
+ ) -> Element:
3440
+ ...
3441
+
3442
+ @overload
3443
+ def __getitem__(
3444
+ self,
3445
+ selection: slice | list[int],
3446
+ ) -> list[Element]:
3447
+ ...
3173
3448
 
3174
3449
  @TimeIt.decorator
3175
3450
  def __getitem__(
3176
3451
  self,
3177
- selection: Union[int, slice, List[int]],
3178
- ) -> Union[app.Element, List[app.Element]]:
3179
- idx_lst = self._get_selection(selection)
3180
- elements = self.task.workflow.get_task_elements(self.task, idx_lst)
3452
+ selection: int | slice | list[int],
3453
+ ) -> Element | list[Element]:
3454
+ elements = self.task.workflow.get_task_elements(
3455
+ self.task, self.__get_selection(selection)
3456
+ )
3181
3457
 
3182
3458
  if isinstance(selection, int):
3183
3459
  return elements[0]
@@ -3186,7 +3462,8 @@ class Elements:
3186
3462
 
3187
3463
 
3188
3464
  @dataclass
3189
- class Parameters:
3465
+ @hydrate
3466
+ class Parameters(AppAware):
3190
3467
  """
3191
3468
  The parameters of a (workflow-bound) task. Iterable.
3192
3469
 
@@ -3206,78 +3483,85 @@ class Parameters:
         A default value to use when the parameter is absent.
     """
 
-    _app_attr = "_app"
-
     #: The task these are the parameters of.
-    task: app.WorkflowTask
+    task: WorkflowTask
     #: The path to the parameter or parameters.
     path: str
     #: Whether to return element parameters.
     return_element_parameters: bool
     #: Whether to raise an exception on a missing parameter.
-    raise_on_missing: Optional[bool] = False
+    raise_on_missing: bool = False
     #: Whether to raise an exception on an unset parameter.
-    raise_on_unset: Optional[bool] = False
+    raise_on_unset: bool = False
     #: A default value to use when the parameter is absent.
-    default: Optional[Any] = None
+    default: Any | None = None
 
     @TimeIt.decorator
-    def _get_selection(self, selection: Union[int, slice, List[int]]) -> List[int]:
+    def __get_selection(
+        self, selection: int | slice | list[int] | tuple[int, ...]
+    ) -> list[int]:
         """Normalise an element selection into a list of element indices."""
         if isinstance(selection, int):
-            lst = [selection]
-
+            return [selection]
         elif isinstance(selection, slice):
-            lst = list(range(*selection.indices(self.task.num_elements)))
-
+            return list(range(*selection.indices(self.task.num_elements)))
         elif isinstance(selection, list):
-            lst = selection
+            return selection
+        elif isinstance(selection, tuple):
+            return list(selection)
         else:
             raise RuntimeError(
                 f"{self.__class__.__name__} selection must be an `int`, `slice` object, "
                 f"or list of `int`s, but received type {type(selection)}."
            )
-        return lst
 
-    def __iter__(self):
-        for i in self.__getitem__(slice(None)):
-            yield i
+    def __iter__(self) -> Iterator[Any | ElementParameter]:
+        yield from self.__getitem__(slice(None))
+
+    @overload
+    def __getitem__(self, selection: int) -> Any | ElementParameter:
+        ...
+
+    @overload
+    def __getitem__(self, selection: slice | list[int]) -> list[Any | ElementParameter]:
+        ...
 
     def __getitem__(
         self,
-        selection: Union[int, slice, List[int]],
-    ) -> Union[Any, List[Any]]:
-        idx_lst = self._get_selection(selection)
+        selection: int | slice | list[int],
+    ) -> Any | ElementParameter | list[Any | ElementParameter]:
+        idx_lst = self.__get_selection(selection)
         elements = self.task.workflow.get_task_elements(self.task, idx_lst)
         if self.return_element_parameters:
-            params = [
+            params = (
                 self._app.ElementParameter(
                     task=self.task,
                     path=self.path,
-                    parent=i,
-                    element=i,
+                    parent=elem,
+                    element=elem,
                 )
-                for i in elements
-            ]
+                for elem in elements
+            )
         else:
-            params = [
-                i.get(
+            params = (
+                elem.get(
                     path=self.path,
                     raise_on_missing=self.raise_on_missing,
                     raise_on_unset=self.raise_on_unset,
                     default=self.default,
                 )
-                for i in elements
-            ]
+                for elem in elements
+            )
 
         if isinstance(selection, int):
-            return params[0]
+            return next(iter(params))
         else:
-            return params
+            return list(params)
 
 
 @dataclass
-class TaskInputParameters:
+@hydrate
+class TaskInputParameters(AppAware):
     """
     For retrieving schema input parameters across all elements.
     Treat as an unmodifiable namespace.
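The selection-normalisation logic in `Parameters.__get_selection` above accepts an `int`, `slice`, `list`, or (newly) `tuple` of indices and always produces a list of element indices. A self-contained restatement of the same rule, for reference only (this is a sketch mirroring the diffed logic, not the package code itself):

def normalise_selection(selection, num_elements: int) -> list[int]:
    """Mirror of the normalisation rule used by Parameters.__get_selection."""
    if isinstance(selection, int):
        return [selection]
    elif isinstance(selection, slice):
        return list(range(*selection.indices(num_elements)))
    elif isinstance(selection, list):
        return selection
    elif isinstance(selection, tuple):
        return list(selection)
    raise RuntimeError(f"Unsupported selection type: {type(selection)}")

# A slice is expanded against the element count; a tuple is simply listified:
assert normalise_selection(slice(None, 4, 2), num_elements=10) == [0, 2]
assert normalise_selection((1, 3), num_elements=10) == [1, 3]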
@@ -3288,31 +3572,34 @@ class TaskInputParameters:
         The task that this represents the input parameters of.
     """
 
-    _app_attr = "_app"
-
     #: The task that this represents the input parameters of.
-    task: app.WorkflowTask
+    task: WorkflowTask
+    __input_names: frozenset[str] | None = field(default=None, init=False, compare=False)
 
-    def __getattr__(self, name):
-        if name not in self._get_input_names():
+    def __getattr__(self, name: str) -> Parameters:
+        if name not in self.__get_input_names():
            raise ValueError(f"No input named {name!r}.")
         return self._app.Parameters(self.task, f"inputs.{name}", True)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__}("
-            f"{', '.join(f'{i!r}' for i in self._get_input_names())})"
+            f"{', '.join(f'{name!r}' for name in sorted(self.__get_input_names()))})"
         )
 
-    def __dir__(self):
-        return super().__dir__() + self._get_input_names()
+    def __dir__(self) -> Iterator[str]:
+        yield from super().__dir__()
+        yield from sorted(self.__get_input_names())
 
-    def _get_input_names(self):
-        return sorted(self.task.template.all_schema_input_types)
+    def __get_input_names(self) -> frozenset[str]:
+        if self.__input_names is None:
+            self.__input_names = frozenset(self.task.template.all_schema_input_types)
+        return self.__input_names
 
 
 @dataclass
-class TaskOutputParameters:
+@hydrate
+class TaskOutputParameters(AppAware):
     """
     For retrieving schema output parameters across all elements.
     Treat as an unmodifiable namespace.
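`TaskInputParameters` exposes each schema input as an attribute returning a `Parameters` view (with `return_element_parameters=True`), and the reworked `__get_input_names` caches the name set as a `frozenset`. A hedged usage sketch, assuming a workflow `wk` whose task `t1` declares a schema input `p1` (illustrative names, not from this package):

# Illustrative only: attribute access is validated against the schema's input names.
inputs = wk.tasks.t1.inputs        # assumed to be a TaskInputParameters namespace
p1_per_element = inputs.p1[:]      # slice -> list of ElementParameter, one per element
print(repr(inputs))                # repr lists the sorted schema input names
try:
    inputs.not_an_input            # unknown name
except ValueError as err:
    print(err)                     # "No input named 'not_an_input'."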
@@ -3323,31 +3610,34 @@ class TaskOutputParameters:
         The task that this represents the output parameters of.
     """
 
-    _app_attr = "_app"
-
     #: The task that this represents the output parameters of.
-    task: app.WorkflowTask
+    task: WorkflowTask
+    __output_names: frozenset[str] | None = field(default=None, init=False, compare=False)
 
-    def __getattr__(self, name):
-        if name not in self._get_output_names():
+    def __getattr__(self, name: str) -> Parameters:
+        if name not in self.__get_output_names():
            raise ValueError(f"No output named {name!r}.")
         return self._app.Parameters(self.task, f"outputs.{name}", True)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__}("
-            f"{', '.join(f'{i!r}' for i in self._get_output_names())})"
+            f"{', '.join(map(repr, sorted(self.__get_output_names())))})"
         )
 
-    def __dir__(self):
-        return super().__dir__() + self._get_output_names()
+    def __dir__(self) -> Iterator[str]:
+        yield from super().__dir__()
+        yield from sorted(self.__get_output_names())
 
-    def _get_output_names(self):
-        return sorted(self.task.template.all_schema_output_types)
+    def __get_output_names(self) -> frozenset[str]:
+        if self.__output_names is None:
+            self.__output_names = frozenset(self.task.template.all_schema_output_types)
+        return self.__output_names
 
 
 @dataclass
-class ElementPropagation:
+@hydrate
+class ElementPropagation(AppAware):
     """
     Class to represent how a newly added element set should propagate to a given
     downstream task.
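`__get_output_names` (and its input counterpart above) now computes the name set once and stores it on the instance via an init-excluded dataclass `field`. This is a standard lazy-cache pattern on a dataclass; a generic, self-contained sketch of the same idea follows (not the package code):

from __future__ import annotations
from dataclasses import dataclass, field

@dataclass
class Names:
    source: list[str]
    _cache: frozenset[str] | None = field(default=None, init=False, compare=False)

    def names(self) -> frozenset[str]:
        if self._cache is None:            # computed at most once per instance
            self._cache = frozenset(self.source)
        return self._cache

n = Names(["out1", "out2", "out1"])
assert n.names() == frozenset({"out1", "out2"})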
@@ -3362,17 +3652,15 @@ class ElementPropagation:
         The input source information.
     """
 
-    _app_attr = "app"
-
     #: The task this is propagating to.
-    task: app.Task
+    task: WorkflowTask
     #: The nesting order information.
-    nesting_order: Optional[Dict] = None
+    nesting_order: dict[str, float] | None = None
     #: The input source information.
-    input_sources: Optional[Dict] = None
+    input_sources: dict[str, list[InputSource]] | None = None
 
     @property
-    def element_set(self):
+    def element_set(self) -> ElementSet:
         """
         The element set that this propagates from.
 
@@ -3383,25 +3671,47 @@ class ElementPropagation:
         # TEMP property; for now just use the first element set as the base:
         return self.task.template.element_sets[0]
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
         return self.__class__(
             task=self.task,
-            nesting_order=copy.deepcopy(self.nesting_order, memo),
+            nesting_order=copy.copy(self.nesting_order),
             input_sources=copy.deepcopy(self.input_sources, memo),
         )
 
     @classmethod
-    def _prepare_propagate_to_dict(cls, propagate_to, workflow):
-        propagate_to = copy.deepcopy(propagate_to)
+    def _prepare_propagate_to_dict(
+        cls,
+        propagate_to: (
+            list[ElementPropagation]
+            | Mapping[str, ElementPropagation | Mapping[str, Any]]
+            | None
+        ),
+        workflow: Workflow,
+    ) -> dict[str, ElementPropagation]:
         if not propagate_to:
-            propagate_to = {}
-        elif isinstance(propagate_to, list):
-            propagate_to = {i.task.unique_name: i for i in propagate_to}
-
-        for k, v in propagate_to.items():
-            if not isinstance(v, ElementPropagation):
-                propagate_to[k] = cls.app.ElementPropagation(
-                    task=workflow.tasks.get(unique_name=k),
-                    **v,
-                )
-        return propagate_to
+            return {}
+        propagate_to = copy.deepcopy(propagate_to)
+        if isinstance(propagate_to, list):
+            return {prop.task.unique_name: prop for prop in propagate_to}
+
+        return {
+            k: (
+                v
+                if isinstance(v, ElementPropagation)
+                else cls(task=workflow.tasks.get(unique_name=k), **v)
+            )
+            for k, v in propagate_to.items()
+        }
+
+
+#: A task used as a template for other tasks.
+TaskTemplate: TypeAlias = Task
+
+
+class MetaTask(JSONLike):
+    def __init__(self, schema: MetaTaskSchema, tasks: Sequence[Task]):
+        self.schema = schema
+        self.tasks = tasks
+
+        # TODO: validate schema's inputs and outputs are inputs and outputs of `tasks`
+        # schemas
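The reworked `_prepare_propagate_to_dict` accepts either a list of `ElementPropagation` objects or a mapping of task unique-names to `ElementPropagation` objects or keyword dictionaries, and always returns a `dict[str, ElementPropagation]`. A hedged sketch of the two call shapes it normalises, assuming an existing workflow `wk` with downstream tasks `t2` and `t3` and an app object `hf` (all illustrative names; the exact entry point taking `propagate_to`, e.g. an add-elements call, is assumed rather than confirmed by this diff):

# Illustrative only: both forms below should normalise to the same dict shape.
# 1) Mapping form: plain keyword dicts are expanded into ElementPropagation objects,
#    keyed by the downstream task's unique name.
propagate_to = {
    "t2": {"nesting_order": {"inputs.p1": 0.0}},
    "t3": {},
}
# 2) List form: pre-built ElementPropagation objects; keys are derived from each
#    object's task.unique_name.
# propagate_to = [hf.ElementPropagation(task=wk.tasks.t3)]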