hpcflow-new2 0.2.0a189__py3-none-any.whl → 0.2.0a190__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. hpcflow/__pyinstaller/hook-hpcflow.py +8 -6
  2. hpcflow/_version.py +1 -1
  3. hpcflow/app.py +1 -0
  4. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
  5. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
  6. hpcflow/sdk/__init__.py +21 -15
  7. hpcflow/sdk/app.py +2133 -770
  8. hpcflow/sdk/cli.py +281 -250
  9. hpcflow/sdk/cli_common.py +6 -2
  10. hpcflow/sdk/config/__init__.py +1 -1
  11. hpcflow/sdk/config/callbacks.py +77 -42
  12. hpcflow/sdk/config/cli.py +126 -103
  13. hpcflow/sdk/config/config.py +578 -311
  14. hpcflow/sdk/config/config_file.py +131 -95
  15. hpcflow/sdk/config/errors.py +112 -85
  16. hpcflow/sdk/config/types.py +145 -0
  17. hpcflow/sdk/core/actions.py +1054 -994
  18. hpcflow/sdk/core/app_aware.py +24 -0
  19. hpcflow/sdk/core/cache.py +81 -63
  20. hpcflow/sdk/core/command_files.py +275 -185
  21. hpcflow/sdk/core/commands.py +111 -107
  22. hpcflow/sdk/core/element.py +724 -503
  23. hpcflow/sdk/core/enums.py +192 -0
  24. hpcflow/sdk/core/environment.py +74 -93
  25. hpcflow/sdk/core/errors.py +398 -51
  26. hpcflow/sdk/core/json_like.py +540 -272
  27. hpcflow/sdk/core/loop.py +380 -334
  28. hpcflow/sdk/core/loop_cache.py +160 -43
  29. hpcflow/sdk/core/object_list.py +370 -207
  30. hpcflow/sdk/core/parameters.py +728 -600
  31. hpcflow/sdk/core/rule.py +59 -41
  32. hpcflow/sdk/core/run_dir_files.py +33 -22
  33. hpcflow/sdk/core/task.py +1546 -1325
  34. hpcflow/sdk/core/task_schema.py +240 -196
  35. hpcflow/sdk/core/test_utils.py +126 -88
  36. hpcflow/sdk/core/types.py +387 -0
  37. hpcflow/sdk/core/utils.py +410 -305
  38. hpcflow/sdk/core/validation.py +82 -9
  39. hpcflow/sdk/core/workflow.py +1192 -1028
  40. hpcflow/sdk/core/zarr_io.py +98 -137
  41. hpcflow/sdk/demo/cli.py +46 -33
  42. hpcflow/sdk/helper/cli.py +18 -16
  43. hpcflow/sdk/helper/helper.py +75 -63
  44. hpcflow/sdk/helper/watcher.py +61 -28
  45. hpcflow/sdk/log.py +83 -59
  46. hpcflow/sdk/persistence/__init__.py +8 -31
  47. hpcflow/sdk/persistence/base.py +988 -586
  48. hpcflow/sdk/persistence/defaults.py +6 -0
  49. hpcflow/sdk/persistence/discovery.py +38 -0
  50. hpcflow/sdk/persistence/json.py +408 -153
  51. hpcflow/sdk/persistence/pending.py +158 -123
  52. hpcflow/sdk/persistence/store_resource.py +37 -22
  53. hpcflow/sdk/persistence/types.py +307 -0
  54. hpcflow/sdk/persistence/utils.py +14 -11
  55. hpcflow/sdk/persistence/zarr.py +477 -420
  56. hpcflow/sdk/runtime.py +44 -41
  57. hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
  58. hpcflow/sdk/submission/jobscript.py +444 -404
  59. hpcflow/sdk/submission/schedulers/__init__.py +133 -40
  60. hpcflow/sdk/submission/schedulers/direct.py +97 -71
  61. hpcflow/sdk/submission/schedulers/sge.py +132 -126
  62. hpcflow/sdk/submission/schedulers/slurm.py +263 -268
  63. hpcflow/sdk/submission/schedulers/utils.py +7 -2
  64. hpcflow/sdk/submission/shells/__init__.py +14 -15
  65. hpcflow/sdk/submission/shells/base.py +102 -29
  66. hpcflow/sdk/submission/shells/bash.py +72 -55
  67. hpcflow/sdk/submission/shells/os_version.py +31 -30
  68. hpcflow/sdk/submission/shells/powershell.py +37 -29
  69. hpcflow/sdk/submission/submission.py +203 -257
  70. hpcflow/sdk/submission/types.py +143 -0
  71. hpcflow/sdk/typing.py +163 -12
  72. hpcflow/tests/conftest.py +8 -6
  73. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
  74. hpcflow/tests/scripts/test_main_scripts.py +60 -30
  75. hpcflow/tests/shells/wsl/test_wsl_submission.py +6 -4
  76. hpcflow/tests/unit/test_action.py +86 -75
  77. hpcflow/tests/unit/test_action_rule.py +9 -4
  78. hpcflow/tests/unit/test_app.py +13 -6
  79. hpcflow/tests/unit/test_cli.py +1 -1
  80. hpcflow/tests/unit/test_command.py +71 -54
  81. hpcflow/tests/unit/test_config.py +20 -15
  82. hpcflow/tests/unit/test_config_file.py +21 -18
  83. hpcflow/tests/unit/test_element.py +58 -62
  84. hpcflow/tests/unit/test_element_iteration.py +3 -1
  85. hpcflow/tests/unit/test_element_set.py +29 -19
  86. hpcflow/tests/unit/test_group.py +4 -2
  87. hpcflow/tests/unit/test_input_source.py +116 -93
  88. hpcflow/tests/unit/test_input_value.py +29 -24
  89. hpcflow/tests/unit/test_json_like.py +44 -35
  90. hpcflow/tests/unit/test_loop.py +65 -58
  91. hpcflow/tests/unit/test_object_list.py +17 -12
  92. hpcflow/tests/unit/test_parameter.py +16 -7
  93. hpcflow/tests/unit/test_persistence.py +48 -35
  94. hpcflow/tests/unit/test_resources.py +20 -18
  95. hpcflow/tests/unit/test_run.py +8 -3
  96. hpcflow/tests/unit/test_runtime.py +2 -1
  97. hpcflow/tests/unit/test_schema_input.py +23 -15
  98. hpcflow/tests/unit/test_shell.py +3 -2
  99. hpcflow/tests/unit/test_slurm.py +8 -7
  100. hpcflow/tests/unit/test_submission.py +39 -19
  101. hpcflow/tests/unit/test_task.py +352 -247
  102. hpcflow/tests/unit/test_task_schema.py +33 -20
  103. hpcflow/tests/unit/test_utils.py +9 -11
  104. hpcflow/tests/unit/test_value_sequence.py +15 -12
  105. hpcflow/tests/unit/test_workflow.py +114 -83
  106. hpcflow/tests/unit/test_workflow_template.py +0 -1
  107. hpcflow/tests/workflows/test_jobscript.py +2 -1
  108. hpcflow/tests/workflows/test_workflows.py +18 -13
  109. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/METADATA +2 -1
  110. hpcflow_new2-0.2.0a190.dist-info/RECORD +165 -0
  111. hpcflow/sdk/core/parallel.py +0 -21
  112. hpcflow_new2-0.2.0a189.dist-info/RECORD +0 -158
  113. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/LICENSE +0 -0
  114. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/WHEEL +0 -0
  115. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/entry_points.txt +0 -0
hpcflow/sdk/core/task.py CHANGED
@@ -5,19 +5,20 @@ Tasks are components of workflows.
 from __future__ import annotations
 from collections import defaultdict
 import copy
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from itertools import chain
 from pathlib import Path
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import NamedTuple, cast, overload, TYPE_CHECKING
+from typing_extensions import override

-from valida.rules import Rule
-
-
-from hpcflow.sdk import app
-from hpcflow.sdk.core.task_schema import TaskSchema
+from hpcflow.sdk.typing import hydrate
+from hpcflow.sdk.core.object_list import AppDataList
 from hpcflow.sdk.log import TimeIt
-from .json_like import ChildObjectSpec, JSONLike
-from .element import ElementGroup
-from .errors import (
+from hpcflow.sdk.core.app_aware import AppAware
+from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
+from hpcflow.sdk.core.element import ElementGroup
+from hpcflow.sdk.core.enums import InputSourceType, TaskSourceType
+from hpcflow.sdk.core.errors import (
     ContainerKeyError,
     ExtraInputs,
     InapplicableInputSourceElementIters,
@@ -37,8 +38,8 @@ from .errors import (
     UnrequiredInputSources,
     UnsetParameterDataError,
 )
-from .parameters import InputSourceType, ParameterValue, TaskSourceType
-from .utils import (
+from hpcflow.sdk.core.parameters import ParameterValue
+from hpcflow.sdk.core.utils import (
     get_duplicate_items,
     get_in_container,
     get_item_repeat_index,
@@ -48,8 +49,43 @@ from .utils import (
     split_param_label,
 )

+if TYPE_CHECKING:
+    from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
+    from typing import Any, ClassVar, Literal, TypeVar
+    from typing_extensions import Self, TypeAlias, TypeIs
+    from ..typing import DataIndex, ParamSource
+    from .actions import Action
+    from .command_files import InputFile
+    from .element import (
+        Element,
+        ElementIteration,
+        ElementFilter,
+        ElementParameter,
+        _ElementPrefixedParameter as EPP,
+    )
+    from .parameters import (
+        InputValue,
+        InputSource,
+        ValueSequence,
+        SchemaInput,
+        SchemaOutput,
+        ParameterPath,
+    )
+    from .rule import Rule
+    from .task_schema import TaskObjective, TaskSchema
+    from .types import (
+        MultiplicityDescriptor,
+        RelevantData,
+        RelevantPath,
+        Resources,
+        RepeatsDescriptor,
+    )
+    from .workflow import Workflow, WorkflowTemplate
+
+    StrSeq = TypeVar("StrSeq", bound=Sequence[str])
+

-INPUT_SOURCE_TYPES = ["local", "default", "task", "import"]
+INPUT_SOURCE_TYPES = ("local", "default", "task", "import")


 @dataclass
@@ -80,7 +116,7 @@ class InputStatus:
     is_provided: bool

     @property
-    def is_extra(self):
+    def is_extra(self) -> bool:
         """True if the input is provided but not required."""
         return self.is_provided and not self.is_required

@@ -124,7 +160,7 @@ class ElementSet(JSONLike):
     but not on subsequent re-initialisation from a persistent workflow.
     """

-    _child_objects = (
+    _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
         ChildObjectSpec(
             name="inputs",
             class_name="InputValue",
@@ -168,30 +204,30 @@ class ElementSet(JSONLike):

     def __init__(
         self,
-        inputs: Optional[List[app.InputValue]] = None,
-        input_files: Optional[List[app.InputFile]] = None,
-        sequences: Optional[List[app.ValueSequence]] = None,
-        resources: Optional[Dict[str, Dict]] = None,
-        repeats: Optional[List[Dict]] = None,
-        groups: Optional[List[app.ElementGroup]] = None,
-        input_sources: Optional[Dict[str, app.InputSource]] = None,
-        nesting_order: Optional[List] = None,
-        env_preset: Optional[str] = None,
-        environments: Optional[Dict[str, Dict[str, Any]]] = None,
-        sourceable_elem_iters: Optional[List[int]] = None,
-        allow_non_coincident_task_sources: Optional[bool] = False,
-        merge_envs: Optional[bool] = True,
+        inputs: list[InputValue] | dict[str, Any] | None = None,
+        input_files: list[InputFile] | None = None,
+        sequences: list[ValueSequence] | None = None,
+        resources: Resources = None,
+        repeats: list[RepeatsDescriptor] | int | None = None,
+        groups: list[ElementGroup] | None = None,
+        input_sources: dict[str, list[InputSource]] | None = None,
+        nesting_order: dict[str, float] | None = None,
+        env_preset: str | None = None,
+        environments: Mapping[str, Mapping[str, Any]] | None = None,
+        sourceable_elem_iters: list[int] | None = None,
+        allow_non_coincident_task_sources: bool = False,
+        merge_envs: bool = True,
     ):
         #: Inputs to the set of elements.
-        self.inputs = inputs or []
+        self.inputs = self.__decode_inputs(inputs or [])
         #: Input files to the set of elements.
         self.input_files = input_files or []
         #: Description of how to repeat the set of elements.
-        self.repeats = repeats or []
+        self.repeats = self.__decode_repeats(repeats or [])
         #: Groupings in the set of elements.
         self.groups = groups or []
         #: Resources to use for the set of elements.
-        self.resources = self.app.ResourceList.normalise(resources)
+        self.resources = self._app.ResourceList.normalise(resources)
         #: Input value sequences to parameterise over.
         self.sequences = sequences or []
         #: Input source descriptors.
@@ -211,27 +247,31 @@ class ElementSet(JSONLike):
         #: Whether to merge ``environments`` into ``resources`` using the "any" scope
         #: on first initialisation.
         self.merge_envs = merge_envs
+        self.original_input_sources: dict[str, list[InputSource]] | None = None
+        self.original_nesting_order: dict[str, float] | None = None

         self._validate()
         self._set_parent_refs()

-        self._task_template = None  # assigned by parent Task
-        self._defined_input_types = None  # assigned on _task_template assignment
-        self._element_local_idx_range = None  # assigned by WorkflowTask._add_element_set
+        # assigned by parent Task
+        self._task_template: Task | None = None
+        # assigned on _task_template assignment
+        self._defined_input_types: set[str] | None = None
+        # assigned by WorkflowTask._add_element_set
+        self._element_local_idx_range: list[int] | None = None

         # merge `environments` into element set resources (this mutates `resources`, and
         # should only happen on creation of the element set, not re-initialisation from a
         # persistent workflow):
         if self.environments and self.merge_envs:
-            envs_res = self.app.ResourceList(
-                [self.app.ResourceSpec(scope="any", environments=self.environments)]
+            self.resources.merge_one(
+                self._app.ResourceSpec(scope="any", environments=self.environments)
             )
-            self.resources.merge_other(envs_res)
             self.merge_envs = False

         # note: `env_preset` is merged into resources by the Task init.

-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
         dct = self.to_dict()
         orig_inp = dct.pop("original_input_sources", None)
         orig_nest = dct.pop("original_nesting_order", None)
@@ -240,19 +280,17 @@ class ElementSet(JSONLike):
         obj._task_template = self._task_template
         obj._defined_input_types = self._defined_input_types
         obj.original_input_sources = copy.deepcopy(orig_inp)
-        obj.original_nesting_order = copy.deepcopy(orig_nest)
-        obj._element_local_idx_range = copy.deepcopy(elem_local_idx_range)
+        obj.original_nesting_order = copy.copy(orig_nest)
+        obj._element_local_idx_range = copy.copy(elem_local_idx_range)
         return obj

-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        if self.to_dict() == other.to_dict():
-            return True
-        return False
+        return self.to_dict() == other.to_dict()

     @classmethod
-    def _json_like_constructor(cls, json_like):
+    def _json_like_constructor(cls, json_like) -> Self:
         """Invoked by `JSONLike.from_json_like` instead of `__init__`."""
         orig_inp = json_like.pop("original_input_sources", None)
         orig_nest = json_like.pop("original_nesting_order", None)
@@ -263,7 +301,7 @@ class ElementSet(JSONLike):
         obj._element_local_idx_range = elem_local_idx_range
         return obj

-    def prepare_persistent_copy(self):
+    def prepare_persistent_copy(self) -> Self:
         """Return a copy of self, which will then be made persistent, and save copies of
         attributes that may be changed during integration with the workflow."""
         obj = copy.deepcopy(self)
@@ -271,86 +309,98 @@ class ElementSet(JSONLike):
         obj.original_input_sources = self.input_sources
         return obj

-    def to_dict(self):
-        dct = super().to_dict()
+    @override
+    def _postprocess_to_dict(self, d: dict[str, Any]) -> dict[str, Any]:
+        dct = super()._postprocess_to_dict(d)
         del dct["_defined_input_types"]
         del dct["_task_template"]
         return dct

     @property
-    def task_template(self):
+    def task_template(self) -> Task:
         """
         The abstract task this was derived from.
         """
+        assert self._task_template is not None
         return self._task_template

     @task_template.setter
-    def task_template(self, value):
+    def task_template(self, value: Task) -> None:
         self._task_template = value
-        self._validate_against_template()
+        self.__validate_against_template()

     @property
-    def input_types(self):
+    def input_types(self) -> list[str]:
         """
         The input types of the inputs to this element set.
         """
-        return [i.labelled_type for i in self.inputs]
+        return [in_.labelled_type for in_ in self.inputs]

     @property
-    def element_local_idx_range(self):
+    def element_local_idx_range(self) -> tuple[int, ...]:
         """Indices of elements belonging to this element set."""
-        return tuple(self._element_local_idx_range)
+        return tuple(self._element_local_idx_range or ())

-    def _validate(self):
-        # support inputs passed as a dict:
-        _inputs = []
-        try:
-            for k, v in self.inputs.items():
+    @classmethod
+    def __decode_inputs(
+        cls, inputs: list[InputValue] | dict[str, Any]
+    ) -> list[InputValue]:
+        """support inputs passed as a dict"""
+        if isinstance(inputs, dict):
+            _inputs: list[InputValue] = []
+            for k, v in inputs.items():
                 param, label = split_param_label(k)
-                _inputs.append(self.app.InputValue(parameter=param, label=label, value=v))
-        except AttributeError:
-            pass
+                assert param is not None
+                _inputs.append(cls._app.InputValue(parameter=param, label=label, value=v))
+            return _inputs
         else:
-            self.inputs = _inputs
+            return inputs

+    @classmethod
+    def __decode_repeats(
+        cls, repeats: list[RepeatsDescriptor] | int
+    ) -> list[RepeatsDescriptor]:
         # support repeats as an int:
-        if isinstance(self.repeats, int):
-            self.repeats = [
+        if isinstance(repeats, int):
+            return [
                 {
                     "name": "",
-                    "number": self.repeats,
-                    "nesting_order": 0,
+                    "number": repeats,
+                    "nesting_order": 0.0,
                 }
             ]
+        else:
+            return repeats
+
+    _ALLOWED_NESTING_PATHS: ClassVar[frozenset[str]] = frozenset(
+        {"inputs", "resources", "repeats"}
+    )

+    def _validate(self) -> None:
         # check `nesting_order` paths:
-        allowed_nesting_paths = ("inputs", "resources", "repeats")
         for k in self.nesting_order:
-            if k.split(".")[0] not in allowed_nesting_paths:
-                raise MalformedNestingOrderPath(
-                    f"Element set: nesting order path {k!r} not understood. Each key in "
-                    f"`nesting_order` must be start with one of "
-                    f"{allowed_nesting_paths!r}."
-                )
+            if k.split(".")[0] not in self._ALLOWED_NESTING_PATHS:
+                raise MalformedNestingOrderPath(k, self._ALLOWED_NESTING_PATHS)

-        inp_paths = [i.normalised_inputs_path for i in self.inputs]
-        dup_inp_paths = get_duplicate_items(inp_paths)
-        if dup_inp_paths:
+        inp_paths = [in_.normalised_inputs_path for in_ in self.inputs]
+        if dup_paths := get_duplicate_items(inp_paths):
             raise TaskTemplateMultipleInputValues(
                 f"The following inputs parameters are associated with multiple input value "
-                f"definitions: {dup_inp_paths!r}."
+                f"definitions: {dup_paths!r}."
             )

-        inp_seq_paths = [i.normalised_inputs_path for i in self.sequences if i.input_type]
-        dup_inp_seq_paths = get_duplicate_items(inp_seq_paths)
-        if dup_inp_seq_paths:
+        inp_seq_paths = [
+            cast("str", seq.normalised_inputs_path)
+            for seq in self.sequences
+            if seq.input_type
+        ]
+        if dup_paths := get_duplicate_items(inp_seq_paths):
             raise TaskTemplateMultipleInputValues(
                 f"The following input parameters are associated with multiple sequence "
-                f"value definitions: {dup_inp_seq_paths!r}."
+                f"value definitions: {dup_paths!r}."
             )

-        inp_and_seq = set(inp_paths).intersection(inp_seq_paths)
-        if inp_and_seq:
+        if inp_and_seq := set(inp_paths).intersection(inp_seq_paths):
             raise TaskTemplateMultipleInputValues(
                 f"The following input parameters are specified in both the `inputs` and "
                 f"`sequences` lists: {list(inp_and_seq)!r}, but must be specified in at "
@@ -368,64 +418,49 @@ class ElementSet(JSONLike):
         if self.env_preset and self.environments:
             raise ValueError("Specify at most one of `env_preset` and `environments`.")

-    def _validate_against_template(self):
-        unexpected_types = (
-            set(self.input_types) - self.task_template.all_schema_input_types
-        )
-        if unexpected_types:
-            raise TaskTemplateUnexpectedInput(
-                f"The following input parameters are unexpected: {list(unexpected_types)!r}"
-            )
+    def __validate_against_template(self) -> None:
+        expected_types = self.task_template.all_schema_input_types
+        if unexpected_types := set(self.input_types) - expected_types:
+            raise TaskTemplateUnexpectedInput(unexpected_types)

-        seq_inp_types = []
-        for seq_i in self.sequences:
-            inp_type = seq_i.labelled_type
-            if inp_type:
-                bad_inp = {inp_type} - self.task_template.all_schema_input_types
-                allowed_str = ", ".join(
-                    f'"{i}"' for i in self.task_template.all_schema_input_types
-                )
-                if bad_inp:
+        defined_inp_types = set(self.input_types)
+        for seq in self.sequences:
+            if inp_type := seq.labelled_type:
+                if inp_type not in expected_types:
                     raise TaskTemplateUnexpectedSequenceInput(
-                        f"The input type {inp_type!r} specified in the following sequence"
-                        f" path is unexpected: {seq_i.path!r}. Available input types are: "
-                        f"{allowed_str}."
+                        inp_type, expected_types, seq
                     )
-                seq_inp_types.append(inp_type)
-            if seq_i.path not in self.nesting_order:
-                self.nesting_order.update({seq_i.path: seq_i.nesting_order})
+                defined_inp_types.add(inp_type)
+            if seq.path not in self.nesting_order and seq.nesting_order is not None:
+                self.nesting_order[seq.path] = seq.nesting_order

         for rep_spec in self.repeats:
-            reps_path_i = f'repeats.{rep_spec["name"]}'
-            if reps_path_i not in self.nesting_order:
+            if (reps_path_i := f'repeats.{rep_spec["name"]}') not in self.nesting_order:
                 self.nesting_order[reps_path_i] = rep_spec["nesting_order"]

         for k, v in self.nesting_order.items():
             if v < 0:
-                raise TaskTemplateInvalidNesting(
-                    f"`nesting_order` must be >=0 for all keys, but for key {k!r}, value "
-                    f"of {v!r} was specified."
-                )
+                raise TaskTemplateInvalidNesting(k, v)

-        self._defined_input_types = set(self.input_types + seq_inp_types)
+        self._defined_input_types = defined_inp_types

     @classmethod
     def ensure_element_sets(
         cls,
-        inputs=None,
-        input_files=None,
-        sequences=None,
-        resources=None,
-        repeats=None,
-        groups=None,
-        input_sources=None,
-        nesting_order=None,
-        env_preset=None,
-        environments=None,
-        allow_non_coincident_task_sources=None,
-        element_sets=None,
-        sourceable_elem_iters=None,
-    ):
+        inputs: list[InputValue] | dict[str, Any] | None = None,
+        input_files: list[InputFile] | None = None,
+        sequences: list[ValueSequence] | None = None,
+        resources: Resources = None,
+        repeats: list[RepeatsDescriptor] | int | None = None,
+        groups: list[ElementGroup] | None = None,
+        input_sources: dict[str, list[InputSource]] | None = None,
+        nesting_order: dict[str, float] | None = None,
+        env_preset: str | None = None,
+        environments: Mapping[str, Mapping[str, Any]] | None = None,
+        allow_non_coincident_task_sources: bool = False,
+        element_sets: list[Self] | None = None,
+        sourceable_elem_iters: list[int] | None = None,
+    ) -> list[Self]:
         """
         Make an instance after validating some argument combinations.
         """
@@ -441,21 +476,19 @@ class ElementSet(JSONLike):
             env_preset,
             environments,
         )
-        args_not_none = [i is not None for i in args]

-        if any(args_not_none):
+        if any(arg is not None for arg in args):
             if element_sets is not None:
                 raise ValueError(
                     "If providing an `element_set`, no other arguments are allowed."
                 )
-            else:
-                element_sets = [
-                    cls(
-                        *args,
-                        sourceable_elem_iters=sourceable_elem_iters,
-                        allow_non_coincident_task_sources=allow_non_coincident_task_sources,
-                    )
-                ]
+            element_sets = [
+                cls(
+                    *args,
+                    sourceable_elem_iters=sourceable_elem_iters,
+                    allow_non_coincident_task_sources=allow_non_coincident_task_sources,
+                )
+            ]
         else:
             if element_sets is None:
                 element_sets = [
@@ -469,127 +502,140 @@ class ElementSet(JSONLike):
         return element_sets

     @property
-    def defined_input_types(self):
+    def defined_input_types(self) -> set[str]:
         """
         The input types to this element set.
         """
+        assert self._defined_input_types is not None
         return self._defined_input_types

     @property
-    def undefined_input_types(self):
+    def undefined_input_types(self) -> set[str]:
         """
         The input types to the abstract task that aren't related to this element set.
         """
         return self.task_template.all_schema_input_types - self.defined_input_types

-    def get_sequence_from_path(self, sequence_path):
+    def get_sequence_from_path(self, sequence_path: str) -> ValueSequence | None:
         """
         Get the value sequence for the given path, if it exists.
         """
-        for seq in self.sequences:
-            if seq.path == sequence_path:
-                return seq
+        return next((seq for seq in self.sequences if seq.path == sequence_path), None)

-    def get_defined_parameter_types(self):
+    def get_defined_parameter_types(self) -> list[str]:
         """
         Get the parameter types of this element set.
         """
-        out = []
+        out: list[str] = []
         for inp in self.inputs:
             if not inp.is_sub_value:
                 out.append(inp.normalised_inputs_path)
         for seq in self.sequences:
             if seq.parameter and not seq.is_sub_value:  # ignore resource sequences
+                assert seq.normalised_inputs_path is not None
                 out.append(seq.normalised_inputs_path)
         return out

-    def get_defined_sub_parameter_types(self):
+    def get_defined_sub_parameter_types(self) -> list[str]:
         """
         Get the sub-parameter types of this element set.
         """
-        out = []
+        out: list[str] = []
         for inp in self.inputs:
             if inp.is_sub_value:
                 out.append(inp.normalised_inputs_path)
         for seq in self.sequences:
             if seq.parameter and seq.is_sub_value:  # ignore resource sequences
+                assert seq.normalised_inputs_path is not None
                 out.append(seq.normalised_inputs_path)
         return out

-    def get_locally_defined_inputs(self):
+    def get_locally_defined_inputs(self) -> list[str]:
         """
         Get the input types that this element set defines.
         """
         return self.get_defined_parameter_types() + self.get_defined_sub_parameter_types()

     @property
-    def index(self):
+    def index(self) -> int | None:
         """
         The index of this element set in its' template task's collection of sets.
         """
-        for idx, element_set in enumerate(self.task_template.element_sets):
-            if element_set is self:
-                return idx
+        return next(
+            (
+                idx
+                for idx, element_set in enumerate(self.task_template.element_sets)
+                if element_set is self
+            ),
+            None,
+        )

     @property
-    def task(self):
+    def task(self) -> WorkflowTask:
         """
         The concrete task corresponding to this element set.
         """
-        return self.task_template.workflow_template.workflow.tasks[
-            self.task_template.index
-        ]
+        t = self.task_template.workflow_template
+        assert t
+        w = t.workflow
+        assert w
+        i = self.task_template.index
+        assert i is not None
+        return w.tasks[i]

     @property
-    def elements(self):
+    def elements(self) -> list[Element]:
         """
         The elements in this element set.
         """
         return self.task.elements[slice(*self.element_local_idx_range)]

     @property
-    def element_iterations(self):
+    def element_iterations(self) -> list[ElementIteration]:
         """
         The iterations in this element set.
         """
-        return [j for i in self.elements for j in i.iterations]
+        return list(chain.from_iterable(elem.iterations for elem in self.elements))

     @property
-    def elem_iter_IDs(self):
+    def elem_iter_IDs(self) -> list[int]:
         """
         The IDs of the iterations in this element set.
         """
-        return [i.id_ for i in self.element_iterations]
+        return [it.id_ for it in self.element_iterations]
+
+    @overload
+    def get_task_dependencies(self, as_objects: Literal[False] = False) -> set[int]:
+        ...

-    def get_task_dependencies(self, as_objects=False):
+    @overload
+    def get_task_dependencies(self, as_objects: Literal[True]) -> list[WorkflowTask]:
+        ...
+
+    def get_task_dependencies(
+        self, as_objects: bool = False
+    ) -> list[WorkflowTask] | set[int]:
         """Get upstream tasks that this element set depends on."""
-        deps = []
+        deps: set[int] = set()
         for element in self.elements:
-            for dep_i in element.get_task_dependencies(as_objects=False):
-                if dep_i not in deps:
-                    deps.append(dep_i)
-        deps = sorted(deps)
+            deps.update(element.get_task_dependencies())
         if as_objects:
-            deps = [self.task.workflow.tasks.get(insert_ID=i) for i in deps]
-
+            return [self.task.workflow.tasks.get(insert_ID=id_) for id_ in sorted(deps)]
         return deps

     def is_input_type_provided(self, labelled_path: str) -> bool:
         """Check if an input is provided locally as an InputValue or a ValueSequence."""
-
-        for inp in self.inputs:
-            if labelled_path == inp.normalised_inputs_path:
-                return True
-
-        for seq in self.sequences:
-            if seq.parameter:
-                # i.e. not a resource:
-                if labelled_path == seq.normalised_inputs_path:
-                    return True
-
-        return False
+        return any(
+            labelled_path == inp.normalised_inputs_path for inp in self.inputs
+        ) or any(
+            seq.parameter
+            # i.e. not a resource:
+            and labelled_path == seq.normalised_inputs_path
+            for seq in self.sequences
+        )


+@hydrate
 class OutputLabel(JSONLike):
     """
     Schema input labels that should be applied to a subset of task outputs.
@@ -604,7 +650,7 @@ class OutputLabel(JSONLike):
         Optional filtering rule
     """

-    _child_objects = (
+    _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
         ChildObjectSpec(
             name="where",
             class_name="ElementFilter",
@@ -615,7 +661,7 @@ class OutputLabel(JSONLike):
         self,
         parameter: str,
         label: str,
-        where: Optional[List[app.ElementFilter]] = None,
+        where: Rule | None = None,
     ) -> None:
         #: Name of a parameter.
         self.parameter = parameter
@@ -625,6 +671,7 @@ class OutputLabel(JSONLike):
         self.where = where


+@hydrate
 class Task(JSONLike):
     """
     Parametrisation of an isolated task for which a subset of input values are given
@@ -663,7 +710,7 @@ class Task(JSONLike):
     re-initialisation from a persistent workflow.
     """

-    _child_objects = (
+    _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
         ChildObjectSpec(
             name="schema",
             class_name="TaskSchema",
@@ -685,24 +732,28 @@ class Task(JSONLike):
         ),
     )

+    @classmethod
+    def __is_TaskSchema(cls, value) -> TypeIs[TaskSchema]:
+        return isinstance(value, cls._app.TaskSchema)
+
     def __init__(
         self,
-        schema: Union[app.TaskSchema, str, List[app.TaskSchema], List[str]],
-        repeats: Optional[List[Dict]] = None,
-        groups: Optional[List[app.ElementGroup]] = None,
-        resources: Optional[Dict[str, Dict]] = None,
-        inputs: Optional[List[app.InputValue]] = None,
-        input_files: Optional[List[app.InputFile]] = None,
-        sequences: Optional[List[app.ValueSequence]] = None,
-        input_sources: Optional[Dict[str, app.InputSource]] = None,
-        nesting_order: Optional[List] = None,
-        env_preset: Optional[str] = None,
-        environments: Optional[Dict[str, Dict[str, Any]]] = None,
-        allow_non_coincident_task_sources: Optional[bool] = False,
-        element_sets: Optional[List[app.ElementSet]] = None,
-        output_labels: Optional[List[app.OutputLabel]] = None,
-        sourceable_elem_iters: Optional[List[int]] = None,
-        merge_envs: Optional[bool] = True,
+        schema: TaskSchema | str | list[TaskSchema] | list[str],
+        repeats: list[RepeatsDescriptor] | int | None = None,
+        groups: list[ElementGroup] | None = None,
+        resources: Resources = None,
+        inputs: list[InputValue] | dict[str, Any] | None = None,
+        input_files: list[InputFile] | None = None,
+        sequences: list[ValueSequence] | None = None,
+        input_sources: dict[str, list[InputSource]] | None = None,
+        nesting_order: dict[str, float] | None = None,
+        env_preset: str | None = None,
+        environments: Mapping[str, Mapping[str, Any]] | None = None,
+        allow_non_coincident_task_sources: bool = False,
+        element_sets: list[ElementSet] | None = None,
+        output_labels: list[OutputLabel] | None = None,
+        sourceable_elem_iters: list[int] | None = None,
+        merge_envs: bool = True,
     ):
         # TODO: allow init via specifying objective and/or method and/or implementation
         # (lists of) strs e.g.: Task(
@@ -718,25 +769,24 @@ class Task(JSONLike):
         # 'simulate_VE_loading_taylor_damask'
         # ])

-        if not isinstance(schema, list):
-            schema = [schema]
-
-        _schemas = []
-        for i in schema:
-            if isinstance(i, str):
+        _schemas: list[TaskSchema] = []
+        for item in schema if isinstance(schema, list) else [schema]:
+            if isinstance(item, str):
                 try:
-                    i = self.app.TaskSchema.get_by_key(
-                        i
+                    _schemas.append(
+                        self._app.TaskSchema.get_by_key(item)
                     )  # TODO: document that we need to use the actual app instance here?
+                    continue
                 except KeyError:
-                    raise KeyError(f"TaskSchema {i!r} not found.")
-            elif not isinstance(i, TaskSchema):
-                raise TypeError(f"Not a TaskSchema object: {i!r}")
-            _schemas.append(i)
+                    raise KeyError(f"TaskSchema {item!r} not found.")
+            elif self.__is_TaskSchema(item):
+                _schemas.append(item)
+            else:
+                raise TypeError(f"Not a TaskSchema object: {item!r}")

         self._schemas = _schemas

-        self._element_sets = self.app.ElementSet.ensure_element_sets(
+        self._element_sets = self._app.ElementSet.ensure_element_sets(
             inputs=inputs,
             input_files=input_files,
             sequences=sequences,
@@ -755,26 +805,31 @@ class Task(JSONLike):
         #: Whether to merge ``environments`` into ``resources`` using the "any" scope
         #: on first initialisation.
         self.merge_envs = merge_envs
+        self.__groups: AppDataList[ElementGroup] = AppDataList(
+            groups or [], access_attribute="name"
+        )

         # appended to when new element sets are added and reset on dump to disk:
-        self._pending_element_sets = []
+        self._pending_element_sets: list[ElementSet] = []

         self._validate()
-        self._name = self._get_name()
+        self._name = self.__get_name()

         #: The template workflow that this task is within.
-        self.workflow_template = None  # assigned by parent WorkflowTemplate
-        self._insert_ID = None
-        self._dir_name = None
+        self.workflow_template: WorkflowTemplate | None = (
+            None  # assigned by parent WorkflowTemplate
+        )
+        self._insert_ID: int | None = None
+        self._dir_name: str | None = None

         if self.merge_envs:
-            self._merge_envs_into_resources()
+            self.__merge_envs_into_resources()

         # TODO: consider adding a new element_set; will need to merge new environments?

         self._set_parent_refs({"schema": "schemas"})

-    def _merge_envs_into_resources(self):
+    def __merge_envs_into_resources(self) -> None:
         # for each element set, merge `env_preset` into `resources` (this mutates
         # `resources`, and should only happen on creation of the task, not
         # re-initialisation from a persistent workflow):
@@ -782,81 +837,71 @@ class Task(JSONLike):

         # TODO: required so we don't raise below; can be removed once we consider multiple
         # schemas:
-        has_presets = False
         for es in self.element_sets:
-            if es.env_preset:
-                has_presets = True
+            if es.env_preset or any(seq.path == "env_preset" for seq in es.sequences):
                 break
-            for seq in es.sequences:
-                if seq.path == "env_preset":
-                    has_presets = True
-                    break
-            if has_presets:
-                break
-
-        if not has_presets:
+        else:
+            # No presets
             return
+
         try:
             env_presets = self.schema.environment_presets
-        except ValueError:
+        except ValueError as e:
             # TODO: consider multiple schemas
             raise NotImplementedError(
                 "Cannot merge environment presets into a task with multiple schemas."
-            )
+            ) from e

         for es in self.element_sets:
             if es.env_preset:
                 # retrieve env specifiers from presets defined in the schema:
                 try:
-                    env_specs = env_presets[es.env_preset]
+                    env_specs = env_presets[es.env_preset]  # type: ignore[index]
                 except (TypeError, KeyError):
-                    raise UnknownEnvironmentPresetError(
-                        f"There is no environment preset named {es.env_preset!r} "
-                        f"defined in the task schema {self.schema.name}."
-                    )
-                envs_res = self.app.ResourceList(
-                    [self.app.ResourceSpec(scope="any", environments=env_specs)]
+                    raise UnknownEnvironmentPresetError(es.env_preset, self.schema.name)
+                es.resources.merge_one(
+                    self._app.ResourceSpec(scope="any", environments=env_specs)
                 )
-                es.resources.merge_other(envs_res)

             for seq in es.sequences:
                 if seq.path == "env_preset":
                     # change to a resources path:
-                    seq.path = f"resources.any.environments"
+                    seq.path = "resources.any.environments"
                     _values = []
-                    for i in seq.values:
+                    for val in seq.values or ():
                         try:
-                            _values.append(env_presets[i])
-                        except (TypeError, KeyError):
+                            _values.append(env_presets[val])  # type: ignore[index]
+                        except (TypeError, KeyError) as e:
                             raise UnknownEnvironmentPresetError(
-                                f"There is no environment preset named {i!r} defined "
-                                f"in the task schema {self.schema.name}."
-                            )
+                                val, self.schema.name
+                            ) from e
                     seq._values = _values

-    def _reset_pending_element_sets(self):
+    def _reset_pending_element_sets(self) -> None:
         self._pending_element_sets = []

-    def _accept_pending_element_sets(self):
+    def _accept_pending_element_sets(self) -> None:
         self._element_sets += self._pending_element_sets
         self._reset_pending_element_sets()

-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        if self.to_dict() == other.to_dict():
-            return True
-        return False
+        return self.to_dict() == other.to_dict()

-    def _add_element_set(self, element_set: app.ElementSet):
+    def _add_element_set(self, element_set: ElementSet):
         """Invoked by WorkflowTask._add_element_set."""
         self._pending_element_sets.append(element_set)
-        self.workflow_template.workflow._store.add_element_set(
-            self.insert_ID, element_set.to_json_like()[0]
+        wt = self.workflow_template
+        assert wt
+        w = wt.workflow
+        assert w
+        w._store.add_element_set(
+            self.insert_ID, cast("Mapping", element_set.to_json_like()[0])
         )

     @classmethod
-    def _json_like_constructor(cls, json_like):
+    def _json_like_constructor(cls, json_like: dict) -> Self:
         """Invoked by `JSONLike.from_json_like` instead of `__init__`."""
         insert_ID = json_like.pop("insert_ID", None)
         dir_name = json_like.pop("dir_name", None)
@@ -865,10 +910,10 @@ class Task(JSONLike):
         obj._dir_name = dir_name
         return obj

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f"{self.__class__.__name__}(name={self.name!r})"

-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
         kwargs = self.to_dict()
         _insert_ID = kwargs.pop("insert_ID")
         _dir_name = kwargs.pop("dir_name")
@@ -881,28 +926,34 @@ class Task(JSONLike):
         obj._pending_element_sets = self._pending_element_sets
         return obj

-    def to_persistent(self, workflow, insert_ID):
+    def to_persistent(
+        self, workflow: Workflow, insert_ID: int
+    ) -> tuple[Self, list[int | list[int]]]:
         """Return a copy where any schema input defaults are saved to a persistent
         workflow. Element set data is not made persistent."""

         obj = copy.deepcopy(self)
-        new_refs = []
-        source = {"type": "default_input", "task_insert_ID": insert_ID}
-        for schema in obj.schemas:
-            new_refs.extend(schema.make_persistent(workflow, source))
+        source: ParamSource = {"type": "default_input", "task_insert_ID": insert_ID}
+        new_refs = list(
+            chain.from_iterable(
+                schema.make_persistent(workflow, source) for schema in obj.schemas
+            )
+        )

         return obj, new_refs

-    def to_dict(self):
-        out = super().to_dict()
+    @override
+    def _postprocess_to_dict(self, d: dict[str, Any]) -> dict[str, Any]:
+        out = super()._postprocess_to_dict(d)
         out["_schema"] = out.pop("_schemas")
-        return {
+        res = {
             k.lstrip("_"): v
             for k, v in out.items()
-            if k not in ("_name", "_pending_element_sets")
+            if k not in ("_name", "_pending_element_sets", "_Task__groups")
         }
+        return res

-    def set_sequence_parameters(self, element_set):
+    def set_sequence_parameters(self, element_set: ElementSet) -> None:
         """
         Set up parameters parsed by value sequences.
         """
@@ -914,18 +965,14 @@ class Task(JSONLike):
                     if inp_j.typ == seq.input_type:
                         seq._parameter = inp_j.parameter

-    def _validate(self):
+    def _validate(self) -> None:
         # TODO: check a nesting order specified for each sequence?

-        names = set(i.objective.name for i in self.schemas)
-        if len(names) > 1:
-            raise TaskTemplateMultipleSchemaObjectives(
-                f"All task schemas used within a task must have the same "
-                f"objective, but found multiple objectives: {list(names)!r}"
-            )
+        if len(names := set(schema.objective.name for schema in self.schemas)) > 1:
+            raise TaskTemplateMultipleSchemaObjectives(names)

-    def _get_name(self):
-        out = f"{self.objective.name}"
+    def __get_name(self) -> str:
+        out = self.objective.name
         for idx, schema_i in enumerate(self.schemas, start=1):
             need_and = idx < len(self.schemas) and (
                 self.schemas[idx].method or self.schemas[idx].implementation
@@ -938,13 +985,12 @@ class Task(JSONLike):
         return out

     @staticmethod
-    def get_task_unique_names(tasks: List[app.Task]):
+    def get_task_unique_names(tasks: list[Task]) -> Sequence[str]:
         """Get the unique name of each in a list of tasks.

         Returns
         -------
         list of str
-
         """

         task_name_rep_idx = get_item_repeat_index(
@@ -953,72 +999,75 @@ class Task(JSONLike):
             distinguish_singular=True,
         )

-        names = []
-        for idx, task in enumerate(tasks):
-            add_rep = f"_{task_name_rep_idx[idx]}" if task_name_rep_idx[idx] > 0 else ""
-            names.append(f"{task.name}{add_rep}")
-
-        return names
+        return [
+            f"{task.name}_{task_name_rep_idx[idx]}"
+            if task_name_rep_idx[idx] > 0
+            else task.name
+            for idx, task in enumerate(tasks)
+        ]

     @TimeIt.decorator
-    def _prepare_persistent_outputs(self, workflow, local_element_idx_range):
+    def _prepare_persistent_outputs(
+        self, workflow: Workflow, local_element_idx_range: Sequence[int]
+    ) -> Mapping[str, Sequence[int]]:
         # TODO: check that schema is present when adding task? (should this be here?)

         # allocate schema-level output parameter; precise EAR index will not be known
         # until we initialise EARs:
-        output_data_indices = {}
+        output_data_indices: dict[str, list[int]] = {}
         for schema in self.schemas:
             for output in schema.outputs:
                 # TODO: consider multiple schemas in action index?

                 path = f"outputs.{output.typ}"
-                output_data_indices[path] = []
-                for idx in range(*local_element_idx_range):
+                output_data_indices[path] = [
                     # iteration_idx, action_idx, and EAR_idx are not known until
                     # `initialise_EARs`:
-                    param_src = {
-                        "type": "EAR_output",
-                        # "task_insert_ID": self.insert_ID,
-                        # "element_idx": idx,
-                        # "run_idx": 0,
-                    }
-                    data_ref = workflow._add_unset_parameter_data(param_src)
-                    output_data_indices[path].append(data_ref)
+                    workflow._add_unset_parameter_data(
+                        {
+                            "type": "EAR_output",
+                            # "task_insert_ID": self.insert_ID,
+                            # "element_idx": idx,
+                            # "run_idx": 0,
+                        }
+                    )
+                    for idx in range(*local_element_idx_range)
+                ]

         return output_data_indices

-    def prepare_element_resolution(self, element_set, input_data_indices):
+    def prepare_element_resolution(
+        self, element_set: ElementSet, input_data_indices: Mapping[str, Sequence]
+    ) -> list[MultiplicityDescriptor]:
         """
         Set up the resolution of details of elements
         (especially multiplicities and how iterations are nested)
         within an element set.
         """
-        multiplicities = []
-        for path_i, inp_idx_i in input_data_indices.items():
-            multiplicities.append(
-                {
-                    "multiplicity": len(inp_idx_i),
-                    "nesting_order": element_set.nesting_order.get(path_i, -1),
-                    "path": path_i,
-                }
-            )
+        multiplicities: list[MultiplicityDescriptor] = [
+            {
+                "multiplicity": len(inp_idx_i),
+                "nesting_order": element_set.nesting_order.get(path_i, -1.0),
+                "path": path_i,
+            }
+            for path_i, inp_idx_i in input_data_indices.items()
+        ]

         # if all inputs with non-unit multiplicity have the same multiplicity and a
         # default nesting order of -1 or 0 (which will have probably been set by a
         # `ValueSequence` default), set the non-unit multiplicity inputs to a nesting
         # order of zero:
-        non_unit_multis = {}
-        unit_multis = []
+        non_unit_multis: dict[int, int] = {}
+        unit_multis: list[int] = []
         change = True
-        for idx, i in enumerate(multiplicities):
-            if i["multiplicity"] == 1:
+        for idx, descriptor in enumerate(multiplicities):
+            if descriptor["multiplicity"] == 1:
                 unit_multis.append(idx)
+            elif descriptor["nesting_order"] in (-1.0, 0.0):
+                non_unit_multis[idx] = descriptor["multiplicity"]
             else:
-                if i["nesting_order"] in (-1, 0):
-                    non_unit_multis[idx] = i["multiplicity"]
-                else:
-                    change = False
-                    break
+                change = False
+                break

         if change and len(set(non_unit_multis.values())) == 1:
             for i_idx in non_unit_multis:
@@ -1027,7 +1076,7 @@ class Task(JSONLike):
         return multiplicities

     @property
-    def index(self):
+    def index(self) -> int | None:
         """
         The index of this task within the workflow's tasks.
         """
@@ -1037,19 +1086,26 @@ class Task(JSONLike):
         return None

     @property
-    def output_labels(self):
+    def output_labels(self) -> Sequence[OutputLabel]:
         """
         The labels on the outputs of the task.
         """
         return self._output_labels

     @property
-    def _element_indices(self):
-        return self.workflow_template.workflow.tasks[self.index].element_indices
+    def _element_indices(self) -> list[int] | None:
+        if (
+            self.workflow_template
+            and self.workflow_template.workflow
+            and self.index is not None
+        ):
+            task = self.workflow_template.workflow.tasks[self.index]
+            return [element._index for element in task.elements]
+        return None

-    def _get_task_source_element_iters(
-        self, in_or_out: str, src_task, labelled_path, element_set
-    ) -> List[int]:
+    def __get_task_source_element_iters(
+        self, in_or_out: str, src_task: Task, labelled_path: str, element_set: ElementSet
+    ) -> list[int]:
         """Get a sorted list of element iteration IDs that provide either inputs or
         outputs from the provided source task."""

@@ -1061,12 +1117,16 @@ class Task(JSONLike):
            es_idx = src_task.get_param_provided_element_sets(labelled_path)
            for es_i in src_task.element_sets:
                # add any element set that has task sources for this parameter
-                for inp_src_i in es_i.input_sources.get(labelled_path, []):
-                    if inp_src_i.source_type is InputSourceType.TASK:
-                        if es_i.index not in es_idx:
-                            es_idx.append(es_i.index)
-                        break
-
+                es_i_idx = es_i.index
+                if (
+                    es_i_idx is not None
+                    and es_i_idx not in es_idx
+                    and any(
+                        inp_src_i.source_type is InputSourceType.TASK
+                        for inp_src_i in es_i.input_sources.get(labelled_path, ())
+                    )
+                ):
+                    es_idx.append(es_i_idx)
        else:
            # outputs are always available, so consider all source task
            # element sets:
@@ -1074,11 +1134,11 @@ class Task(JSONLike):
        if not es_idx:
            raise NoAvailableElementSetsError()
-        else:
-            src_elem_iters = []
-            for es_idx_i in es_idx:
-                es_i = src_task.element_sets[es_idx_i]
-                src_elem_iters += es_i.elem_iter_IDs  # should be sorted already
+
+        src_elem_iters: list[int] = []
+        for es_i_idx in es_idx:
+            es_i = src_task.element_sets[es_i_idx]
+            src_elem_iters.extend(es_i.elem_iter_IDs)  # should be sorted already

        if element_set.sourceable_elem_iters is not None:
            # can only use a subset of element iterations (this is the
            # case where this element set is generated from an upstream
@@ -1087,36 +1147,73 @@ class Task(JSONLike):
            # added upstream elements when adding elements from this
            # element set):
            src_elem_iters = sorted(
-                list(set(element_set.sourceable_elem_iters) & set(src_elem_iters))
+                set(element_set.sourceable_elem_iters).intersection(src_elem_iters)
            )

        return src_elem_iters

+    @staticmethod
+    def __get_common_path(labelled_path: str, inputs_path: str) -> str | None:
+        lab_s = labelled_path.split(".")
+        inp_s = inputs_path.split(".")
+        try:
+            get_relative_path(lab_s, inp_s)
+            return labelled_path
+        except ValueError:
+            pass
+        try:
+            get_relative_path(inp_s, lab_s)
+            return inputs_path
+        except ValueError:
+            # no intersection between paths
+            return None
+
+    @staticmethod
+    def __filtered_iters(wk_task: WorkflowTask, where: Rule) -> list[int]:
+        param_path = cast("str", where.path)
+        param_prefix, param_name, *param_tail = param_path.split(".")
+        src_elem_iters: list[int] = []
+
+        for elem in wk_task.elements:
+            params: EPP = getattr(elem, param_prefix)
+            param: ElementParameter = getattr(params, param_name)
+            param_dat = param.value
+
+            # for remaining paths components try both getattr and
+            # getitem:
+            for path_k in param_tail:
+                try:
+                    param_dat = param_dat[path_k]
+                except TypeError:
+                    param_dat = getattr(param_dat, path_k)
+
+            if where._valida_check(param_dat):
+                src_elem_iters.append(elem.iterations[0].id_)
+
+        return src_elem_iters
+
     def get_available_task_input_sources(
         self,
-        element_set: app.ElementSet,
-        source_tasks: Optional[List[app.WorkflowTask]] = None,
-    ) -> List[app.InputSource]:
+        element_set: ElementSet,
+        source_tasks: Sequence[WorkflowTask] = (),
+    ) -> Mapping[str, Sequence[InputSource]]:
         """For each input parameter of this task, generate a list of possible input sources
         that derive from inputs or outputs of this and other provided tasks.

         Note this only produces a subset of available input sources for each input
         parameter; other available input sources may exist from workflow imports."""

-        if source_tasks:
-            # ensure parameters provided by later tasks are added to the available sources
-            # list first, meaning they take precedence when choosing an input source:
-            source_tasks = sorted(source_tasks, key=lambda x: x.index, reverse=True)
-        else:
-            source_tasks = []
+        # ensure parameters provided by later tasks are added to the available sources
+        # list first, meaning they take precedence when choosing an input source:
+        source_tasks = sorted(source_tasks, key=lambda x: x.index, reverse=True)

-        available = {}
+        available: dict[str, list[InputSource]] = {}
         for inputs_path, inp_status in self.get_input_statuses(element_set).items():
             # local specification takes precedence:
             if inputs_path in element_set.get_locally_defined_inputs():
-                if inputs_path not in available:
-                    available[inputs_path] = []
-                available[inputs_path].append(self.app.InputSource.local())
+                available.setdefault(inputs_path, []).append(
+                    self._app.InputSource.local()
+                )

             # search for task sources:
             for src_wk_task_i in source_tasks:
@@ -1129,77 +1226,47 @@ class Task(JSONLike):
                    key=lambda x: x[0],
                    reverse=True,
                ):
-                    src_elem_iters = []
-                    lab_s = labelled_path.split(".")
-                    inp_s = inputs_path.split(".")
-                    try:
-                        get_relative_path(lab_s, inp_s)
-                    except ValueError:
+                    src_elem_iters: list[int] = []
+                    common = self.__get_common_path(labelled_path, inputs_path)
+                    if common is not None:
+                        avail_src_path = common
+                    else:
+                        # no intersection between paths
+                        inputs_path_label = None
+                        out_label = None
+                        unlabelled, inputs_path_label = split_param_label(inputs_path)
+                        if unlabelled is None:
+                            continue
                        try:
-                            get_relative_path(inp_s, lab_s)
+                            get_relative_path(
+                                unlabelled.split("."), labelled_path.split(".")
+                            )
+                            avail_src_path = inputs_path
                        except ValueError:
-                            # no intersection between paths
-                            inputs_path_label = None
-                            out_label = None
-                            unlabelled, inputs_path_label = split_param_label(inputs_path)
-                            try:
-                                get_relative_path(unlabelled.split("."), lab_s)
-                                avail_src_path = inputs_path
-                            except ValueError:
-                                continue
-
-                            if inputs_path_label:
-                                for out_lab_i in src_task_i.output_labels:
-                                    if out_lab_i.label == inputs_path_label:
-                                        out_label = out_lab_i
+                            continue
+                        if not inputs_path_label:
+                            continue
+                        for out_lab_i in src_task_i.output_labels:
+                            if out_lab_i.label == inputs_path_label:
+                                out_label = out_lab_i
+
+                        # consider output labels
+                        if out_label and in_or_out == "output":
+                            # find element iteration IDs that match the output label
+                            # filter:
+                            if out_label.where:
+                                src_elem_iters = self.__filtered_iters(
+                                    src_wk_task_i, out_label.where
+                                )
                            else:
-                                continue
-
-                            # consider output labels
-                            if out_label and in_or_out == "output":
-                                # find element iteration IDs that match the output label
-                                # filter:
-                                if out_label.where:
-                                    param_path_split = out_label.where.path.split(".")
-
-                                    for elem_i in src_wk_task_i.elements:
-                                        params = getattr(elem_i, param_path_split[0])
-                                        param_dat = getattr(
-                                            params, param_path_split[1]
-                                        ).value
-
-                                        # for remaining paths components try both getattr and
-                                        # getitem:
-                                        for path_k in param_path_split[2:]:
-                                            try:
-                                                param_dat = param_dat[path_k]
-                                            except TypeError:
-                                                param_dat = getattr(param_dat, path_k)
-
-                                        rule = Rule(
-                                            path=[0],
-                                            condition=out_label.where.condition,
-                                            cast=out_label.where.cast,
-                                        )
-                                        if rule.test([param_dat]).is_valid:
-                                            src_elem_iters.append(
-                                                elem_i.iterations[0].id_
-                                            )
-                                else:
-                                    src_elem_iters = [
-                                        elem_i.iterations[0].id_
-                                        for elem_i in src_wk_task_i.elements
-                                    ]
-
-                            else:
-                                avail_src_path = inputs_path
-
-                        else:
-                            avail_src_path = labelled_path
+                                src_elem_iters = [
+                                    elem_i.iterations[0].id_
+                                    for elem_i in src_wk_task_i.elements
+                                ]

                    if not src_elem_iters:
                        try:
-                            src_elem_iters = self._get_task_source_element_iters(
+                            src_elem_iters = self.__get_task_source_element_iters(
                                in_or_out=in_or_out,
                                src_task=src_task_i,
                                labelled_path=labelled_path,
@@ -1207,36 +1274,32 @@ class Task(JSONLike):
                             )
                         except NoAvailableElementSetsError:
                             continue
+                        if not src_elem_iters:
+                            continue

-                    if not src_elem_iters:
-                        continue
-
-                    task_source = self.app.InputSource.task(
-                        task_ref=src_task_i.insert_ID,
-                        task_source_type=in_or_out,
-                        element_iters=src_elem_iters,
+                    available.setdefault(avail_src_path, []).append(
+                        self._app.InputSource.task(
+                            task_ref=src_task_i.insert_ID,
+                            task_source_type=in_or_out,
+                            element_iters=src_elem_iters,
+                        )
                     )

-                    if avail_src_path not in available:
-                        available[avail_src_path] = []
-
-                    available[avail_src_path].append(task_source)
-
             if inp_status.has_default:
-                if inputs_path not in available:
-                    available[inputs_path] = []
-                available[inputs_path].append(self.app.InputSource.default())
+                available.setdefault(inputs_path, []).append(
+                    self._app.InputSource.default()
+                )

         return available

     @property
-    def schemas(self) -> List[app.TaskSchema]:
+    def schemas(self) -> list[TaskSchema]:
         """
         All the task schemas.
         """
         return self._schemas

     @property
-    def schema(self) -> app.TaskSchema:
+    def schema(self) -> TaskSchema:
         """The single task schema, if only one, else raises."""
         if len(self._schemas) == 1:
             return self._schemas[0]
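
Several rewrites in the hunk above collapse the `if key not in d: d[key] = []` boilerplate into `dict.setdefault`, which inserts the default only when the key is absent and returns the stored value either way. A small sketch of the pattern (sample data is illustrative):

    available: dict[str, list[str]] = {}

    # setdefault returns the existing list, or inserts and returns the
    # empty one, so the append can be chained directly:
    available.setdefault("inputs.p1", []).append("local")
    available.setdefault("inputs.p1", []).append("task")
    assert available == {"inputs.p1": ["local", "task"]}
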
@@ -1247,79 +1310,86 @@ class Task(JSONLike):
             )

     @property
-    def element_sets(self):
+    def element_sets(self) -> list[ElementSet]:
         """
         The element sets.
         """
         return self._element_sets + self._pending_element_sets

     @property
-    def num_element_sets(self):
+    def num_element_sets(self) -> int:
         """
         The number of element sets.
         """
-        return len(self.element_sets)
+        return len(self._element_sets) + len(self._pending_element_sets)

     @property
-    def insert_ID(self):
+    def insert_ID(self) -> int:
         """
         Insertion ID.
         """
+        assert self._insert_ID is not None
         return self._insert_ID

     @property
-    def dir_name(self):
+    def dir_name(self) -> str:
         """
         Artefact directory name.
         """
+        assert self._dir_name is not None
         return self._dir_name

     @property
-    def name(self):
+    def name(self) -> str:
         """
         Task name.
         """
         return self._name

     @property
-    def objective(self):
+    def objective(self) -> TaskObjective:
         """
         The goal of this task.
         """
-        return self.schemas[0].objective
+        obj = self.schemas[0].objective
+        return obj

     @property
-    def all_schema_inputs(self) -> Tuple[app.SchemaInput]:
+    def all_schema_inputs(self) -> tuple[SchemaInput, ...]:
         """
         The inputs to this task's schemas.
         """
         return tuple(inp_j for schema_i in self.schemas for inp_j in schema_i.inputs)

     @property
-    def all_schema_outputs(self) -> Tuple[app.SchemaOutput]:
+    def all_schema_outputs(self) -> tuple[SchemaOutput, ...]:
         """
         The outputs from this task's schemas.
         """
         return tuple(inp_j for schema_i in self.schemas for inp_j in schema_i.outputs)

     @property
-    def all_schema_input_types(self):
-        """The set of all schema input types (over all specified schemas)."""
+    def all_schema_input_types(self) -> set[str]:
+        """
+        The set of all schema input types (over all specified schemas).
+        """
         return {inp_j for schema_i in self.schemas for inp_j in schema_i.input_types}

     @property
-    def all_schema_input_normalised_paths(self):
+    def all_schema_input_normalised_paths(self) -> set[str]:
         """
         Normalised paths for all schema input types.
         """
-        return {f"inputs.{i}" for i in self.all_schema_input_types}
+        return {f"inputs.{typ}" for typ in self.all_schema_input_types}

     @property
-    def all_schema_output_types(self):
-        """The set of all schema output types (over all specified schemas)."""
+    def all_schema_output_types(self) -> set[str]:
+        """
+        The set of all schema output types (over all specified schemas).
+        """
         return {out_j for schema_i in self.schemas for out_j in schema_i.output_types}

-    def get_schema_action(self, idx):
+    def get_schema_action(self, idx: int) -> Action:
         """
         Get the schema action at the given index.
         """
@@ -1331,7 +1401,7 @@ class Task(JSONLike):
             _idx += 1
         raise ValueError(f"No action in task {self.name!r} with index {idx!r}.")

-    def all_schema_actions(self) -> Iterator[Tuple[int, app.Action]]:
+    def all_schema_actions(self) -> Iterator[tuple[int, Action]]:
         """
         Get all the schema actions and their indices.
         """
@@ -1346,54 +1416,56 @@ class Task(JSONLike):
         """
         The total number of schema actions.
         """
-        num = 0
-        for schema in self.schemas:
-            for _ in schema.actions:
-                num += 1
-        return num
+        return sum(len(schema.actions) for schema in self.schemas)

     @property
-    def all_sourced_normalised_paths(self):
+    def all_sourced_normalised_paths(self) -> set[str]:
         """
         All the sourced normalised paths, including of sub-values.
         """
-        sourced_input_types = []
+        sourced_input_types: set[str] = set()
         for elem_set in self.element_sets:
-            for inp in elem_set.inputs:
-                if inp.is_sub_value:
-                    sourced_input_types.append(inp.normalised_path)
-            for seq in elem_set.sequences:
-                if seq.is_sub_value:
-                    sourced_input_types.append(seq.normalised_path)
-        return set(sourced_input_types) | self.all_schema_input_normalised_paths
-
-    def is_input_type_required(self, typ: str, element_set: app.ElementSet) -> bool:
+            sourced_input_types.update(
+                inp.normalised_path for inp in elem_set.inputs if inp.is_sub_value
+            )
+            sourced_input_types.update(
+                seq.normalised_path for seq in elem_set.sequences if seq.is_sub_value
+            )
+        return sourced_input_types | self.all_schema_input_normalised_paths
+
+    def is_input_type_required(self, typ: str, element_set: ElementSet) -> bool:
         """Check if a given input type must be specified in the parametrisation of this
         element set.

         A schema input need not be specified if it is only required to generate an input
         file, and that input file is passed directly."""

-        provided_files = [i.file for i in element_set.input_files]
+        provided_files = {in_file.file for in_file in element_set.input_files}
         for schema in self.schemas:
             if not schema.actions:
                 return True  # for empty tasks that are used merely for defining inputs
-            for act in schema.actions:
-                if act.is_input_type_required(typ, provided_files):
-                    return True
+            if any(
+                act.is_input_type_required(typ, provided_files) for act in schema.actions
+            ):
+                return True

         return False

-    def get_param_provided_element_sets(self, labelled_path: str) -> List[int]:
+    def get_param_provided_element_sets(self, labelled_path: str) -> list[int]:
         """Get the element set indices of this task for which a specified parameter type
-        is locally provided."""
-        es_idx = []
-        for idx, src_es in enumerate(self.element_sets):
-            if src_es.is_input_type_provided(labelled_path):
-                es_idx.append(idx)
-        return es_idx
-
-    def get_input_statuses(self, elem_set: app.ElementSet) -> Dict[str, InputStatus]:
+        is locally provided.
+
+        Note
+        ----
+        Caller may freely modify this result.
+        """
+        return [
+            idx
+            for idx, src_es in enumerate(self.element_sets)
+            if src_es.is_input_type_provided(labelled_path)
+        ]
+
+    def get_input_statuses(self, elem_set: ElementSet) -> Mapping[str, InputStatus]:
         """Get a dict whose keys are normalised input paths (without the "inputs" prefix),
         and whose values are InputStatus objects.

@@ -1401,10 +1473,9 @@ class Task(JSONLike):
         ----------
         elem_set
             The element set for which input statuses should be returned.
-
         """

-        status = {}
+        status: dict[str, InputStatus] = {}
         for schema_input in self.all_schema_inputs:
             for lab_info in schema_input.labelled_info():
                 labelled_type = lab_info["labelled_type"]
@@ -1427,31 +1498,36 @@ class Task(JSONLike):
         return status

     @property
-    def universal_input_types(self):
+    def universal_input_types(self) -> set[str]:
         """Get input types that are associated with all schemas"""
         raise NotImplementedError()

     @property
-    def non_universal_input_types(self):
+    def non_universal_input_types(self) -> set[str]:
         """Get input types for each schema that are non-universal."""
         raise NotImplementedError()

     @property
-    def defined_input_types(self):
+    def defined_input_types(self) -> set[str]:
         """
-        The input types defined by this task.
+        The input types defined by this task, being the input types defined by any of
+        its element sets.
         """
-        return self._defined_input_types
+        dit: set[str] = set()
+        for es in self.element_sets:
+            dit.update(es.defined_input_types)
+        return dit
+        # TODO: Is this right?

     @property
-    def undefined_input_types(self):
+    def undefined_input_types(self) -> set[str]:
         """
         The schema's input types that this task doesn't define.
         """
         return self.all_schema_input_types - self.defined_input_types

     @property
-    def undefined_inputs(self):
+    def undefined_inputs(self) -> list[SchemaInput]:
         """
         The task's inputs that are undefined.
         """
@@ -1462,41 +1538,36 @@ class Task(JSONLike):
             if inp_j.typ in self.undefined_input_types
         ]

-    def provides_parameters(self) -> Tuple[Tuple[str, str]]:
+    def provides_parameters(self) -> tuple[tuple[str, str], ...]:
         """Get all provided parameter labelled types and whether they are inputs and
         outputs, considering all element sets.

         """
-        out = []
+        out: dict[tuple[str, str], None] = {}
         for schema in self.schemas:
-            for in_or_out, labelled_type in schema.provides_parameters:
-                out.append((in_or_out, labelled_type))
+            out.update(dict.fromkeys(schema.provides_parameters))

         # add sub-parameter input values and sequences:
         for es_i in self.element_sets:
             for inp_j in es_i.inputs:
                 if inp_j.is_sub_value:
-                    val_j = ("input", inp_j.normalised_inputs_path)
-                    if val_j not in out:
-                        out.append(val_j)
+                    out["input", inp_j.normalised_inputs_path] = None
             for seq_j in es_i.sequences:
-                if seq_j.is_sub_value:
-                    val_j = ("input", seq_j.normalised_inputs_path)
-                    if val_j not in out:
-                        out.append(val_j)
+                if seq_j.is_sub_value and (path := seq_j.normalised_inputs_path):
+                    out["input", path] = None

         return tuple(out)

     def add_group(
-        self, name: str, where: app.ElementFilter, group_by_distinct: app.ParameterPath
+        self, name: str, where: ElementFilter, group_by_distinct: ParameterPath
     ):
         """
         Add an element group to this task.
         """
         group = ElementGroup(name=name, where=where, group_by_distinct=group_by_distinct)
-        self.groups.add_object(group)
+        self.__groups.add_object(group)

-    def _get_single_label_lookup(self, prefix="") -> Dict[str, str]:
+    def _get_single_label_lookup(self, prefix: str = "") -> Mapping[str, str]:
         """Get a mapping between schema input types that have a single label (i.e.
         labelled but with `multiple=False`) and the non-labelled type string.

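
`provides_parameters` above now accumulates into a `dict[tuple[str, str], None]` instead of a list guarded by membership tests: dict keys preserve insertion order (guaranteed since Python 3.7), so `dict.fromkeys` deduplicates in O(1) per item while keeping first-seen order. A quick sketch with illustrative data:

    pairs = [("input", "p1"), ("output", "p2"), ("input", "p1")]

    # keys act as an insertion-ordered set; the duplicate collapses onto
    # its first occurrence:
    ordered: dict[tuple[str, str], None] = dict.fromkeys(pairs)
    ordered["input", "p3"] = None  # append-like insert

    assert tuple(ordered) == (("input", "p1"), ("output", "p2"), ("input", "p3"))
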
@@ -1509,13 +1580,18 @@ class Task(JSONLike):
        `{"inputs.p1[one]": "inputs.p1"}`.

        """
-        lookup = {}
-        for i in self.schemas:
-            lookup.update(i._get_single_label_lookup(prefix=prefix))
+        lookup: dict[str, str] = {}
+        for schema in self.schemas:
+            lookup.update(schema._get_single_label_lookup(prefix=prefix))
         return lookup


-class WorkflowTask:
+class _ESIdx(NamedTuple):
+    ordered: list[int]
+    uniq: frozenset[int]
+
+
+class WorkflowTask(AppAware):
     """
     Represents a :py:class:`Task` that is bound to a :py:class:`Workflow`.

@@ -1531,14 +1607,12 @@ class WorkflowTask:
        The IDs of the elements of this task.
     """

-    _app_attr = "app"
-
     def __init__(
         self,
-        workflow: app.Workflow,
-        template: app.Task,
+        workflow: Workflow,
+        template: Task,
         index: int,
-        element_IDs: List[int],
+        element_IDs: list[int],
     ):
         self._workflow = workflow
         self._template = template
@@ -1546,9 +1620,9 @@ class WorkflowTask:
         self._element_IDs = element_IDs

         # appended to when new elements are added and reset on dump to disk:
-        self._pending_element_IDs = []
+        self._pending_element_IDs: list[int] = []

-        self._elements = None  # assigned on `elements` first access
+        self._elements: Elements | None = None  # assigned on `elements` first access

     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(name={self.unique_name!r})"
@@ -1561,7 +1635,7 @@ class WorkflowTask:
         self._reset_pending_element_IDs()

     @classmethod
-    def new_empty_task(cls, workflow: app.Workflow, template: app.Task, index: int):
+    def new_empty_task(cls, workflow: Workflow, template: Task, index: int) -> Self:
         """
         Make a new instance without any elements set up yet.

@@ -1574,86 +1648,87 @@ class WorkflowTask:
        index:
            Where in the workflow's list of tasks is this one.
        """
-        obj = cls(
+        return cls(
             workflow=workflow,
             template=template,
             index=index,
             element_IDs=[],
         )
-        return obj

     @property
-    def workflow(self):
+    def workflow(self) -> Workflow:
         """
         The workflow this task is bound to.
         """
         return self._workflow

     @property
-    def template(self):
+    def template(self) -> Task:
         """
         The template for this task.
         """
         return self._template

     @property
-    def index(self):
+    def index(self) -> int:
         """
         The index of this task within its workflow.
         """
         return self._index

     @property
-    def element_IDs(self):
+    def element_IDs(self) -> list[int]:
         """
         The IDs of elements associated with this task.
         """
         return self._element_IDs + self._pending_element_IDs

     @property
-    def num_elements(self):
+    def num_elements(self) -> int:
         """
         The number of elements associated with this task.
         """
-        return len(self.element_IDs)
+        return len(self._element_IDs) + len(self._pending_element_IDs)

     @property
-    def num_actions(self):
+    def num_actions(self) -> int:
         """
         The number of actions in this task.
         """
         return self.template.num_all_schema_actions

     @property
-    def name(self):
+    def name(self) -> str:
         """
         The name of this task based on its template.
         """
         return self.template.name

     @property
-    def unique_name(self):
+    def unique_name(self) -> str:
         """
         The unique name for this task specifically.
         """
         return self.workflow.get_task_unique_names()[self.index]

     @property
-    def insert_ID(self):
+    def insert_ID(self) -> int:
         """
         The insertion ID of the template task.
         """
         return self.template.insert_ID

     @property
-    def dir_name(self):
+    def dir_name(self) -> str:
         """
         The name of the directory for the task's temporary files.
         """
-        return self.template.dir_name
+        dn = self.template.dir_name
+        assert dn is not None
+        return dn

     @property
-    def num_element_sets(self):
+    def num_element_sets(self) -> int:
         """
         The number of element sets associated with this task.
@@ -1661,15 +1736,15 @@ class WorkflowTask:

     @property
     @TimeIt.decorator
-    def elements(self):
+    def elements(self) -> Elements:
         """
         The elements associated with this task.
         """
         if self._elements is None:
-            self._elements = self.app.Elements(self)
+            self._elements = Elements(self)
         return self._elements

-    def get_dir_name(self, loop_idx: Dict[str, int] = None) -> str:
+    def get_dir_name(self, loop_idx: Mapping[str, int] | None = None) -> str:
         """
         Get the directory name for a particular iteration.
         """
@@ -1677,15 +1752,102 @@ class WorkflowTask:
             return self.dir_name
         return self.dir_name + "_" + "_".join((f"{k}-{v}" for k, v in loop_idx.items()))

-    def get_all_element_iterations(self) -> Dict[int, app.ElementIteration]:
+    def get_all_element_iterations(self) -> Mapping[int, ElementIteration]:
         """
         Get the iterations known by the task's elements.
         """
-        return {j.id_: j for i in self.elements for j in i.iterations}
+        return {itr.id_: itr for elem in self.elements for itr in elem.iterations}

-    def _make_new_elements_persistent(
-        self, element_set, element_set_idx, padded_elem_iters
-    ):
+    @staticmethod
+    def __get_src_elem_iters(
+        src_task: WorkflowTask, inp_src: InputSource
+    ) -> tuple[Iterable[ElementIteration], list[int]]:
+        src_iters = src_task.get_all_element_iterations()
+
+        if inp_src.element_iters:
+            # only include "sourceable" element iterations:
+            src_iters_list = [src_iters[itr_id] for itr_id in inp_src.element_iters]
+            set_indices = [el.element.element_set_idx for el in src_iters.values()]
+            return src_iters_list, set_indices
+        return src_iters.values(), []
+
+    def __get_task_group_index(
+        self,
+        labelled_path_i: str,
+        inp_src: InputSource,
+        padded_elem_iters: Mapping[str, Sequence[int]],
+        inp_group_name: str | None,
+    ) -> None | Sequence[int | list[int]]:
+        src_task = inp_src.get_task(self.workflow)
+        assert src_task
+        src_elem_iters, src_elem_set_idx = self.__get_src_elem_iters(src_task, inp_src)
+
+        if not src_elem_iters:
+            return None
+
+        task_source_type = inp_src.task_source_type
+        assert task_source_type is not None
+        if task_source_type == TaskSourceType.OUTPUT and "[" in labelled_path_i:
+            src_key = f"{task_source_type.name.lower()}s.{labelled_path_i.split('[')[0]}"
+        else:
+            src_key = f"{task_source_type.name.lower()}s.{labelled_path_i}"
+
+        padded_iters = padded_elem_iters.get(labelled_path_i, [])
+        grp_idx = [
+            (itr.get_data_idx()[src_key] if itr_idx not in padded_iters else -1)
+            for itr_idx, itr in enumerate(src_elem_iters)
+        ]
+
+        if not inp_group_name:
+            return grp_idx
+
+        group_dat_idx: list[int | list[int]] = []
+        element_sets = src_task.template.element_sets
+        for dat_idx, src_set_idx, src_iter in zip(
+            grp_idx, src_elem_set_idx, src_elem_iters
+        ):
+            src_es = element_sets[src_set_idx]
+            if any(inp_group_name == grp.name for grp in src_es.groups):
+                group_dat_idx.append(dat_idx)
+                continue
+            # if for any recursive iteration dependency, this group is
+            # defined, assign:
+            src_iter_deps = self.workflow.get_element_iterations_from_IDs(
+                src_iter.get_element_iteration_dependencies(),
+            )
+
+            if any(
+                inp_group_name == grp.name
+                for el_iter in src_iter_deps
+                for grp in el_iter.element.element_set.groups
+            ):
+                group_dat_idx.append(dat_idx)
+                continue
+
+            # also check input dependencies
+            for p_src in src_iter.element.get_input_dependencies().values():
+                k_es = self.workflow.tasks.get(
+                    insert_ID=p_src["task_insert_ID"]
+                ).template.element_sets[p_src["element_set_idx"]]
+                if any(inp_group_name == grp.name for grp in k_es.groups):
+                    group_dat_idx.append(dat_idx)
+                    break
+
+            # TODO: this only goes to one level of dependency
+
+        if not group_dat_idx:
+            raise MissingElementGroup(self.unique_name, inp_group_name, labelled_path_i)
+
+        return [cast("int", group_dat_idx)]  # TODO: generalise to multiple groups
+
+    def __make_new_elements_persistent(
+        self,
+        element_set: ElementSet,
+        element_set_idx: int,
+        padded_elem_iters: Mapping[str, Sequence[int]],
+    ) -> tuple[
+        dict[str, list[int | list[int]]], dict[str, Sequence[int]], dict[str, list[int]]
+    ]:
         """Save parameter data to the persistent workflow."""

         # TODO: rewrite. This method is a little hard to follow and results in somewhat
@@ -1693,25 +1855,25 @@ class WorkflowTask:
         # given input, the local source element(s) will always come first, regardless of
         # the ordering in element_set.input_sources.

-        input_data_idx = {}
-        sequence_idx = {}
-        source_idx = {}
+        input_data_idx: dict[str, list[int | list[int]]] = {}
+        sequence_idx: dict[str, Sequence[int]] = {}
+        source_idx: dict[str, list[int]] = {}

         # Assign first assuming all locally defined values are to be used:
-        param_src = {
+        param_src: ParamSource = {
             "type": "local_input",
             "task_insert_ID": self.insert_ID,
             "element_set_idx": element_set_idx,
         }
-        loc_inp_src = self.app.InputSource.local()
+        loc_inp_src = self._app.InputSource.local()
         for res_i in element_set.resources:
             key, dat_ref, _ = res_i.make_persistent(self.workflow, param_src)
-            input_data_idx[key] = dat_ref
+            input_data_idx[key] = list(dat_ref)

         for inp_i in element_set.inputs:
             key, dat_ref, _ = inp_i.make_persistent(self.workflow, param_src)
-            input_data_idx[key] = dat_ref
-            key_ = key.split("inputs.")[1]
+            input_data_idx[key] = list(dat_ref)
+            key_ = key.removeprefix("inputs.")
             try:
                 # TODO: wouldn't need to do this if we raise when an InputValue is
                 # provided for a parameter whose inputs sources do not include the local
@@ -1721,15 +1883,15 @@ class WorkflowTask:
                 pass

         for inp_file_i in element_set.input_files:
-            key, dat_ref, _ = inp_file_i.make_persistent(self.workflow, param_src)
-            input_data_idx[key] = dat_ref
+            key, input_dat_ref, _ = inp_file_i.make_persistent(self.workflow, param_src)
+            input_data_idx[key] = list(input_dat_ref)

         for seq_i in element_set.sequences:
-            key, dat_ref, _ = seq_i.make_persistent(self.workflow, param_src)
-            input_data_idx[key] = dat_ref
-            sequence_idx[key] = list(range(len(dat_ref)))
+            key, seq_dat_ref, _ = seq_i.make_persistent(self.workflow, param_src)
+            input_data_idx[key] = list(seq_dat_ref)
+            sequence_idx[key] = list(range(len(seq_dat_ref)))
             try:
-                key_ = key.split("inputs.")[1]
+                key_ = key.removeprefix("inputs.")
             except IndexError:
                 pass
             try:
@@ -1738,22 +1900,20 @@ class WorkflowTask:
                 # value.
                 source_idx[key] = [
                     element_set.input_sources[key_].index(loc_inp_src)
-                ] * len(dat_ref)
+                ] * len(seq_dat_ref)
             except ValueError:
                 pass

         for rep_spec in element_set.repeats:
             seq_key = f"repeats.{rep_spec['name']}"
-            num_range = list(range(rep_spec["number"]))
-            input_data_idx[seq_key] = num_range
+            num_range = range(rep_spec["number"])
+            input_data_idx[seq_key] = list(num_range)
             sequence_idx[seq_key] = num_range

         # Now check for task- and default-sources and overwrite or append to local sources:
         inp_stats = self.template.get_input_statuses(element_set)
         for labelled_path_i, sources_i in element_set.input_sources.items():
-            path_i_split = labelled_path_i.split(".")
-            is_path_i_sub = len(path_i_split) > 1
-            if is_path_i_sub:
+            if len(path_i_split := labelled_path_i.split(".")) > 1:
                 path_i_root = path_i_split[0]
             else:
                 path_i_root = labelled_path_i
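
Note that the `key.split("inputs.")[1]` → `key.removeprefix("inputs.")` changes in the hunks above are not strictly behaviour-preserving: the old idiom raises `IndexError` when the prefix is absent (which the surrounding `try`/`except IndexError` relied on), whereas `str.removeprefix` (Python 3.9+) returns the string unchanged. A quick sketch of the difference:

    key = "resources.any"  # no "inputs." prefix

    # removeprefix never raises; the original string comes back:
    assert key.removeprefix("inputs.") == "resources.any"

    # the old idiom raised instead:
    try:
        _ = key.split("inputs.")[1]
    except IndexError:
        print("no 'inputs.' prefix present")

    assert "inputs.p1".removeprefix("inputs.") == "p1"
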
@@ -1773,114 +1933,36 @@ class WorkflowTask:

             for inp_src_idx, inp_src in enumerate(sources_i):
                 if inp_src.source_type is InputSourceType.TASK:
-                    src_task = inp_src.get_task(self.workflow)
-                    src_elem_iters = src_task.get_all_element_iterations()
-
-                    if inp_src.element_iters:
-                        # only include "sourceable" element iterations:
-                        src_elem_iters = [
-                            src_elem_iters[i] for i in inp_src.element_iters
-                        ]
-                        src_elem_set_idx = [
-                            i.element.element_set_idx for i in src_elem_iters
-                        ]
-
-                    if not src_elem_iters:
+                    grp_idx = self.__get_task_group_index(
+                        labelled_path_i, inp_src, padded_elem_iters, inp_group_name
+                    )
+                    if grp_idx is None:
                         continue

-                    task_source_type = inp_src.task_source_type.name.lower()
-                    if task_source_type == "output" and "[" in labelled_path_i:
-                        src_key = f"{task_source_type}s.{labelled_path_i.split('[')[0]}"
-                    else:
-                        src_key = f"{task_source_type}s.{labelled_path_i}"
-
-                    padded_iters = padded_elem_iters.get(labelled_path_i, [])
-                    grp_idx = [
-                        (
-                            iter_i.get_data_idx()[src_key]
-                            if iter_i_idx not in padded_iters
-                            else -1
-                        )
-                        for iter_i_idx, iter_i in enumerate(src_elem_iters)
-                    ]
-
-                    if inp_group_name:
-                        group_dat_idx = []
-                        for dat_idx_i, src_set_idx_i, src_iter in zip(
-                            grp_idx, src_elem_set_idx, src_elem_iters
-                        ):
-                            src_es = src_task.template.element_sets[src_set_idx_i]
-                            if inp_group_name in [i.name for i in src_es.groups or []]:
-                                group_dat_idx.append(dat_idx_i)
-                            else:
-                                # if for any recursive iteration dependency, this group is
-                                # defined, assign:
-                                src_iter_i = src_iter
-                                src_iter_deps = (
-                                    self.workflow.get_element_iterations_from_IDs(
-                                        src_iter_i.get_element_iteration_dependencies(),
-                                    )
-                                )
-
-                                src_iter_deps_groups = [
-                                    j
-                                    for i in src_iter_deps
-                                    for j in i.element.element_set.groups
-                                ]
-
-                                if inp_group_name in [
-                                    i.name for i in src_iter_deps_groups or []
-                                ]:
-                                    group_dat_idx.append(dat_idx_i)
-
-                                # also check input dependencies
-                                for (
-                                    k,
-                                    v,
-                                ) in src_iter.element.get_input_dependencies().items():
-                                    k_es_idx = v["element_set_idx"]
-                                    k_task_iID = v["task_insert_ID"]
-                                    k_es = self.workflow.tasks.get(
-                                        insert_ID=k_task_iID
-                                    ).template.element_sets[k_es_idx]
-                                    if inp_group_name in [
-                                        i.name for i in k_es.groups or []
-                                    ]:
-                                        group_dat_idx.append(dat_idx_i)
-
-                                # TODO: this only goes to one level of dependency
-
-                        if not group_dat_idx:
-                            raise MissingElementGroup(
-                                f"Adding elements to task {self.unique_name!r}: no "
-                                f"element group named {inp_group_name!r} found for input "
-                                f"{labelled_path_i!r}."
-                            )
-
-                        grp_idx = [group_dat_idx]  # TODO: generalise to multiple groups
-
-                    if self.app.InputSource.local() in sources_i:
+                    if self._app.InputSource.local() in sources_i:
                         # add task source to existing local source:
-                        input_data_idx[key] += grp_idx
-                        source_idx[key] += [inp_src_idx] * len(grp_idx)
+                        input_data_idx[key].extend(grp_idx)
+                        source_idx[key].extend([inp_src_idx] * len(grp_idx))

                     else:  # BUG: doesn't work for multiple task inputs sources
                         # overwrite existing local source (if it exists):
-                        input_data_idx[key] = grp_idx
+                        input_data_idx[key] = list(grp_idx)
                         source_idx[key] = [inp_src_idx] * len(grp_idx)
                         if key in sequence_idx:
                             sequence_idx.pop(key)
-                            seq = element_set.get_sequence_from_path(key)
+                            # TODO: Use the value retrieved below?
+                            _ = element_set.get_sequence_from_path(key)

                 elif inp_src.source_type is InputSourceType.DEFAULT:
-                    grp_idx = [def_val._value_group_idx]
-                    if self.app.InputSource.local() in sources_i:
-                        input_data_idx[key] += grp_idx
-                        source_idx[key] += [inp_src_idx] * len(grp_idx)
-
+                    assert def_val is not None
+                    assert def_val._value_group_idx is not None
+                    grp_idx_ = def_val._value_group_idx
+                    if self._app.InputSource.local() in sources_i:
+                        input_data_idx[key].append(grp_idx_)
+                        source_idx[key].append(inp_src_idx)
                     else:
-                        input_data_idx[key] = grp_idx
-                        source_idx[key] = [inp_src_idx] * len(grp_idx)
+                        input_data_idx[key] = [grp_idx_]
+                        source_idx[key] = [inp_src_idx]

         # sort smallest to largest path, so more-specific items overwrite less-specific
         # items in parameter retrieval in `WorkflowTask._get_merged_parameter_data`:
@@ -1888,7 +1970,9 @@ class WorkflowTask:

         return (input_data_idx, sequence_idx, source_idx)

-    def ensure_input_sources(self, element_set) -> Dict[str, List[int]]:
+    def ensure_input_sources(
+        self, element_set: ElementSet
+    ) -> Mapping[str, Sequence[int]]:
         """Check valid input sources are specified for a new task to be added to the
         workflow in a given position. If none are specified, set them according to the
         default behaviour.
@@ -1903,18 +1987,8 @@ class WorkflowTask:
             source_tasks=self.workflow.tasks[: self.index],
         )

-        unreq_sources = set(element_set.input_sources.keys()) - set(
-            available_sources.keys()
-        )
-        if unreq_sources:
-            unreq_src_str = ", ".join(f"{i!r}" for i in unreq_sources)
-            raise UnrequiredInputSources(
-                message=(
-                    f"The following input sources are not required but have been "
-                    f"specified: {unreq_src_str}."
-                ),
-                unrequired_sources=unreq_sources,
-            )
+        if unreq := set(element_set.input_sources).difference(available_sources):
+            raise UnrequiredInputSources(unreq)

         # TODO: get available input sources from workflow imports

@@ -1929,11 +2003,8 @@ class WorkflowTask:
         for path_i, avail_i in available_sources.items():
             # for each sub-path in available sources, if the "root-path" source is
             # required, then add the sub-path source to `req_types` as well:
-            path_i_split = path_i.split(".")
-            is_path_i_sub = len(path_i_split) > 1
-            if is_path_i_sub:
-                path_i_root = path_i_split[0]
-                if path_i_root in req_types:
+            if len(path_i_split := path_i.split(".")) > 1:
+                if path_i_split[0] in req_types:
                     req_types.add(path_i)

         for s_idx, specified_source in enumerate(
@@ -1943,35 +2014,36 @@ class WorkflowTask:
                specified_source, self.unique_name
            )
            avail_idx = specified_source.is_in(avail_i)
+            if avail_idx is None:
+                raise UnavailableInputSource(specified_source, path_i, avail_i)
+            available_source: InputSource
            try:
                available_source = avail_i[avail_idx]
            except TypeError:
                raise UnavailableInputSource(
-                    f"The input source {specified_source.to_string()!r} is not "
-                    f"available for input path {path_i!r}. Available "
-                    f"input sources are: {[i.to_string() for i in avail_i]}."
+                    specified_source, path_i, avail_i
                ) from None

            elem_iters_IDs = available_source.element_iters
            if specified_source.element_iters:
                # user-specified iter IDs; these must be a subset of available
                # element_iters:
-                if not set(specified_source.element_iters).issubset(elem_iters_IDs):
+                if not set(specified_source.element_iters).issubset(
+                    elem_iters_IDs or ()
+                ):
                    raise InapplicableInputSourceElementIters(
-                        f"The specified `element_iters` for input source "
-                        f"{specified_source.to_string()!r} are not all applicable. "
-                        f"Applicable element iteration IDs for this input source "
-                        f"are: {elem_iters_IDs!r}."
+                        specified_source, elem_iters_IDs
                    )
                elem_iters_IDs = specified_source.element_iters

            if specified_source.where:
                # filter iter IDs by user-specified rules, maintaining order:
                elem_iters = self.workflow.get_element_iterations_from_IDs(
-                    elem_iters_IDs
+                    elem_iters_IDs or ()
                )
-                filtered = specified_source.where.filter(elem_iters)
-                elem_iters_IDs = [i.id_ for i in filtered]
+                elem_iters_IDs = [
+                    ei.id_ for ei in specified_source.where.filter(elem_iters)
+                ]

            available_source.element_iters = elem_iters_IDs
            element_set.input_sources[path_i][s_idx] = available_source
@@ -1979,24 +2051,16 @@ class WorkflowTask:
         # sorting ensures that root parameters come before sub-parameters, which is
         # necessary when considering if we want to include a sub-parameter, when setting
         # missing sources below:
-        unsourced_inputs = sorted(req_types - set(element_set.input_sources.keys()))
-
-        extra_types = set(k for k, v in all_stats.items() if v.is_extra)
-        if extra_types:
-            extra_str = ", ".join(f"{i!r}" for i in extra_types)
-            raise ExtraInputs(
-                message=(
-                    f"The following inputs are not required, but have been passed: "
-                    f"{extra_str}."
-                ),
-                extra_inputs=extra_types,
-            )
+        unsourced_inputs = sorted(req_types.difference(element_set.input_sources))
+
+        if extra_types := {k for k, v in all_stats.items() if v.is_extra}:
+            raise ExtraInputs(extra_types)

         # set source for any unsourced inputs:
-        missing = []
+        missing: list[str] = []
         # track which root params we have set according to default behaviour (not
         # specified by user):
-        set_root_params = []
+        set_root_params: set[str] = set()
         for input_type in unsourced_inputs:
             input_split = input_type.split(".")
             has_root_param = input_split[0] if len(input_split) > 1 else None
@@ -2028,89 +2092,85 @@ class WorkflowTask:
            ):
                continue

-            element_set.input_sources.update({input_type: [source]})
+            element_set.input_sources[input_type] = [source]
             if not has_root_param:
-                set_root_params.append(input_type)
+                set_root_params.add(input_type)

         # for task sources that span multiple element sets, pad out sub-parameter
         # `element_iters` to include the element iterations from other element sets in
         # which the "root" parameter is defined:
-        sources_by_task = defaultdict(dict)
-        elem_iter_by_task = defaultdict(dict)
-        all_elem_iters = set()
+        sources_by_task: dict[int, dict[str, InputSource]] = defaultdict(dict)
+        elem_iter_by_task: dict[int, dict[str, list[int]]] = defaultdict(dict)
+        all_elem_iters: set[int] = set()
         for inp_type, sources in element_set.input_sources.items():
             source = sources[0]
             if source.source_type is InputSourceType.TASK:
+                assert source.task_ref is not None
+                assert source.element_iters is not None
                 sources_by_task[source.task_ref][inp_type] = source
                 all_elem_iters.update(source.element_iters)
                 elem_iter_by_task[source.task_ref][inp_type] = source.element_iters

-        all_elem_iter_objs = self.workflow.get_element_iterations_from_IDs(all_elem_iters)
-        all_elem_iters_by_ID = {i.id_: i for i in all_elem_iter_objs}
+        all_elem_iters_by_ID = {
+            el_iter.id_: el_iter
+            for el_iter in self.workflow.get_element_iterations_from_IDs(all_elem_iters)
+        }

         # element set indices:
         padded_elem_iters = defaultdict(list)
-        es_idx_by_task = defaultdict(dict)
+        es_idx_by_task: dict[int, dict[str, _ESIdx]] = defaultdict(dict)
         for task_ref, task_iters in elem_iter_by_task.items():
             for inp_type, inp_iters in task_iters.items():
                 es_indices = [
-                    all_elem_iters_by_ID[i].element.element_set_idx for i in inp_iters
+                    all_elem_iters_by_ID[id_].element.element_set_idx for id_ in inp_iters
                 ]
-                es_idx_by_task[task_ref][inp_type] = (es_indices, set(es_indices))
-            root_params = {k for k in task_iters if "." not in k}
-            root_param_nesting = {
-                k: element_set.nesting_order.get(f"inputs.{k}", None) for k in root_params
-            }
-            for root_param_i in root_params:
-                sub_params = {
-                    k
-                    for k in task_iters
-                    if k.split(".")[0] == root_param_i and k != root_param_i
-                }
-                rp_elem_sets = es_idx_by_task[task_ref][root_param_i][0]
-                rp_elem_sets_uniq = es_idx_by_task[task_ref][root_param_i][1]
-
-                for sub_param_j in sub_params:
+                es_idx_by_task[task_ref][inp_type] = _ESIdx(
+                    es_indices, frozenset(es_indices)
+                )
+            for root_param in {k for k in task_iters if "." not in k}:
+                rp_nesting = element_set.nesting_order.get(f"inputs.{root_param}", None)
+                rp_elem_sets, rp_elem_sets_uniq = es_idx_by_task[task_ref][root_param]
+
+                root_param_prefix = f"{root_param}."
+                for sub_param_j in {
+                    k for k in task_iters if k.startswith(root_param_prefix)
+                }:
                     sub_param_nesting = element_set.nesting_order.get(
                         f"inputs.{sub_param_j}", None
                     )
-                    if sub_param_nesting == root_param_nesting[root_param_i]:
-
-                        sp_elem_sets_uniq = es_idx_by_task[task_ref][sub_param_j][1]
+                    if sub_param_nesting == rp_nesting:
+                        sp_elem_sets_uniq = es_idx_by_task[task_ref][sub_param_j].uniq

                         if sp_elem_sets_uniq != rp_elem_sets_uniq:
-
                             # replace elem_iters in sub-param sequence with those from the
                             # root parameter, but re-order the elem iters to match their
                             # original order:
-                            iters_copy = elem_iter_by_task[task_ref][root_param_i][:]
+                            iters = elem_iter_by_task[task_ref][root_param]

                             # "mask" iter IDs corresponding to the sub-parameter's element
                             # sets, and keep track of the extra indices so they can be
                             # ignored later:
-                            sp_iters_new = []
-                            for idx, (i, j) in enumerate(zip(iters_copy, rp_elem_sets)):
-                                if j in sp_elem_sets_uniq:
+                            sp_iters_new: list[int | None] = []
+                            for idx, (it_id, es_idx) in enumerate(
+                                zip(iters, rp_elem_sets)
+                            ):
+                                if es_idx in sp_elem_sets_uniq:
                                     sp_iters_new.append(None)
                                 else:
-                                    sp_iters_new.append(i)
+                                    sp_iters_new.append(it_id)
                                     padded_elem_iters[sub_param_j].append(idx)

-                            # fill in sub-param elem_iters in their specified order
-                            sub_iters_it = iter(elem_iter_by_task[task_ref][sub_param_j])
-                            sp_iters_new = [
-                                i if i is not None else next(sub_iters_it)
-                                for i in sp_iters_new
-                            ]
-
                             # update sub-parameter element iters:
-                            for src_idx, src in enumerate(
-                                element_set.input_sources[sub_param_j]
-                            ):
+                            for src in element_set.input_sources[sub_param_j]:
                                 if src.source_type is InputSourceType.TASK:
-                                    element_set.input_sources[sub_param_j][
-                                        src_idx
-                                    ].element_iters = sp_iters_new
+                                    # fill in sub-param elem_iters in their specified order
+                                    sub_iters_it = iter(
+                                        elem_iter_by_task[task_ref][sub_param_j]
+                                    )
+                                    src.element_iters = [
+                                        it_id if it_id is not None else next(sub_iters_it)
+                                        for it_id in sp_iters_new
+                                    ]
                                     # assumes only a single task-type source for this
                                     # parameter
                                     break
@@ -2120,131 +2180,125 @@ class WorkflowTask:
        # results in no available elements due to `allow_non_coincident_task_sources`.

        if not element_set.allow_non_coincident_task_sources:
-            # if multiple parameters are sourced from the same upstream task, only use
-            # element iterations for which all parameters are available (the set
-            # intersection):
-            for task_ref, sources in sources_by_task.items():
-                # if a parameter has multiple labels, disregard from this by removing all
-                # parameters:
-                seen_labelled = {}
-                for src_i in sources.keys():
-                    if "[" in src_i:
-                        unlabelled, _ = split_param_label(src_i)
-                        if unlabelled not in seen_labelled:
-                            seen_labelled[unlabelled] = 1
-                        else:
-                            seen_labelled[unlabelled] += 1
-
-                for unlabelled, count in seen_labelled.items():
-                    if count > 1:
-                        # remove:
-                        sources = {
-                            k: v
-                            for k, v in sources.items()
-                            if not k.startswith(unlabelled)
-                        }
+            self.__enforce_some_sanity(sources_by_task, element_set)

-                if len(sources) < 2:
-                    continue
-
-                first_src = next(iter(sources.values()))
-                intersect_task_i = set(first_src.element_iters)
-                for src_i in sources.values():
-                    intersect_task_i.intersection_update(src_i.element_iters)
-                if not intersect_task_i:
-                    raise NoCoincidentInputSources(
-                        f"Task {self.name!r}: input sources from task {task_ref!r} have "
-                        f"no coincident applicable element iterations. Consider setting "
-                        f"the element set (or task) argument "
-                        f"`allow_non_coincident_task_sources` to `True`, which will "
-                        f"allow for input sources from the same task to use different "
-                        f"(non-coinciding) subsets of element iterations from the "
-                        f"source task."
-                    )
+        if missing:
+            raise MissingInputs(missing)
+        return padded_elem_iters

-                # now change elements for the affected input sources.
-                # sort by original order of first_src.element_iters
-                int_task_i_lst = [
-                    i for i in first_src.element_iters if i in intersect_task_i
-                ]
-                for inp_type in sources.keys():
-                    element_set.input_sources[inp_type][0].element_iters = int_task_i_lst
+    def __enforce_some_sanity(
+        self, sources_by_task: dict[int, dict[str, InputSource]], element_set: ElementSet
+    ) -> None:
+        """
+        If multiple parameters are sourced from the same upstream task, only use
+        element iterations for which all parameters are available (the set
+        intersection).
+        """
+        for task_ref, sources in sources_by_task.items():
+            # if a parameter has multiple labels, disregard from this by removing all
+            # parameters:
+            seen_labelled: dict[str, int] = defaultdict(int)
+            for src_i in sources:
+                if "[" in src_i:
+                    unlabelled, _ = split_param_label(src_i)
+                    assert unlabelled is not None
+                    seen_labelled[unlabelled] += 1

-        if missing:
-            missing_str = ", ".join(f"{i!r}" for i in missing)
-            raise MissingInputs(
-                message=f"The following inputs have no sources: {missing_str}.",
-                missing_inputs=missing,
-            )
+            for prefix, count in seen_labelled.items():
+                if count > 1:
+                    # remove:
+                    sources = {
+                        k: v for k, v in sources.items() if not k.startswith(prefix)
+                    }
+
+            if len(sources) < 2:
+                continue

-        return padded_elem_iters
+            first_src = next(iter(sources.values()))
+            intersect_task_i = set(first_src.element_iters or ())
+            for inp_src in sources.values():
+                intersect_task_i.intersection_update(inp_src.element_iters or ())
+            if not intersect_task_i:
+                raise NoCoincidentInputSources(self.name, task_ref)
+
+            # now change elements for the affected input sources.
+            # sort by original order of first_src.element_iters
+            int_task_i_lst = [
+                i for i in first_src.element_iters or () if i in intersect_task_i
+            ]
+            for inp_type in sources:
+                element_set.input_sources[inp_type][0].element_iters = int_task_i_lst

     def generate_new_elements(
         self,
-        input_data_indices,
-        output_data_indices,
-        element_data_indices,
-        sequence_indices,
-        source_indices,
-    ):
+        input_data_indices: Mapping[str, Sequence[int | list[int]]],
+        output_data_indices: Mapping[str, Sequence[int]],
+        element_data_indices: Sequence[Mapping[str, int]],
+        sequence_indices: Mapping[str, Sequence[int]],
+        source_indices: Mapping[str, Sequence[int]],
+    ) -> tuple[
+        Sequence[DataIndex], Mapping[str, Sequence[int]], Mapping[str, Sequence[int]]
+    ]:
         """
         Create information about new elements in this task.
         """
-        new_elements = []
-        element_sequence_indices = {}
-        element_src_indices = {}
-        for i_idx, i in enumerate(element_data_indices):
+        new_elements: list[DataIndex] = []
+        element_sequence_indices: dict[str, list[int]] = {}
+        element_src_indices: dict[str, list[int]] = {}
+        for i_idx, data_idx in enumerate(element_data_indices):
             elem_i = {
                 k: input_data_indices[k][v]
-                for k, v in i.items()
+                for k, v in data_idx.items()
                 if input_data_indices[k][v] != -1
             }
-            elem_i.update({k: v[i_idx] for k, v in output_data_indices.items()})
+            elem_i.update((k, v2[i_idx]) for k, v2 in output_data_indices.items())
             new_elements.append(elem_i)

-            for k, v in i.items():
+            for k, v3 in data_idx.items():
                 # track which sequence value indices (if any) are used for each new
                 # element:
                 if k in sequence_indices:
-                    if k not in element_sequence_indices:
-                        element_sequence_indices[k] = []
-                    element_sequence_indices[k].append(sequence_indices[k][v])
+                    element_sequence_indices.setdefault(k, []).append(
+                        sequence_indices[k][v3]
+                    )

                 # track original InputSource associated with each new element:
                 if k in source_indices:
-                    if k not in element_src_indices:
-                        element_src_indices[k] = []
-                    if input_data_indices[k][v] != -1:
-                        src_idx_k = source_indices[k][v]
+                    if input_data_indices[k][v3] != -1:
+                        src_idx_k = source_indices[k][v3]
                     else:
                         src_idx_k = -1
-                    element_src_indices[k].append(src_idx_k)
+                    element_src_indices.setdefault(k, []).append(src_idx_k)

         return new_elements, element_sequence_indices, element_src_indices

     @property
-    def upstream_tasks(self):
+    def upstream_tasks(self) -> Iterator[WorkflowTask]:
         """All workflow tasks that are upstream from this task."""
-        return [task for task in self.workflow.tasks[: self.index]]
+        tasks = self.workflow.tasks
+        for idx in range(0, self.index):
+            yield tasks[idx]

     @property
-    def downstream_tasks(self):
+    def downstream_tasks(self) -> Iterator[WorkflowTask]:
         """All workflow tasks that are downstream from this task."""
-        return [task for task in self.workflow.tasks[self.index + 1 :]]
+        tasks = self.workflow.tasks
+        for idx in range(self.index + 1, len(tasks)):
+            yield tasks[idx]

     @staticmethod
-    def resolve_element_data_indices(multiplicities):
-        """Find the index of the Zarr parameter group index list corresponding to each
+    def resolve_element_data_indices(
+        multiplicities: list[MultiplicityDescriptor],
+    ) -> Sequence[Mapping[str, int]]:
+        """Find the index of the parameter group index list corresponding to each
         input data for all elements.

-        # TODO: update docstring; shouldn't reference Zarr.
-
         Parameters
         ----------
-        multiplicities : list of dict
+        multiplicities : list of MultiplicityDescriptor
            Each list item represents a sequence of values with keys:
                multiplicity: int
-                nesting_order: int
+                nesting_order: float
                path : str

         Returns
@@ -2254,37 +2308,40 @@ class WorkflowTask:
            input data paths and whose values are indices that index the values of the
            dict returned by the `task.make_persistent` method.

+        Note
+        ----
+        Non-integer nesting orders result in the dot product of that sequence with
+        all the current sequences, rather than with just the other sequences at the
+        same nesting order (and a cross product with sequences at other nesting
+        orders).
         """

         # order by nesting order (lower nesting orders will be slowest-varying):
         multi_srt = sorted(multiplicities, key=lambda x: x["nesting_order"])
-        multi_srt_grp = group_by_dict_key_values(
-            multi_srt, "nesting_order"
-        )  # TODO: is tested?
+        multi_srt_grp = group_by_dict_key_values(multi_srt, "nesting_order")

-        element_dat_idx = [{}]
-        last_nest_ord = None
+        element_dat_idx: list[dict[str, int]] = [{}]
+        last_nest_ord: int | None = None
         for para_sequences in multi_srt_grp:
             # check all equivalent nesting_orders have equivalent multiplicities
-            all_multis = {i["multiplicity"] for i in para_sequences}
+            all_multis = {md["multiplicity"] for md in para_sequences}
             if len(all_multis) > 1:
                 raise ValueError(
                     f"All inputs with the same `nesting_order` must have the same "
                     f"multiplicity, but for paths "
-                    f"{[i['path'] for i in para_sequences]} with "
+                    f"{[md['path'] for md in para_sequences]} with "
                     f"`nesting_order` {para_sequences[0]['nesting_order']} found "
-                    f"multiplicities {[i['multiplicity'] for i in para_sequences]}."
+                    f"multiplicities {[md['multiplicity'] for md in para_sequences]}."
                 )

-            cur_nest_ord = para_sequences[0]["nesting_order"]
-            new_elements = []
+            cur_nest_ord = int(para_sequences[0]["nesting_order"])
+            new_elements: list[dict[str, int]] = []
             for elem_idx, element in enumerate(element_dat_idx):
-                if last_nest_ord is not None and int(cur_nest_ord) == int(last_nest_ord):
+                if last_nest_ord is not None and cur_nest_ord == last_nest_ord:
                     # merge in parallel with existing elements:
                     new_elements.append(
                         {
                             **element,
-                            **{i["path"]: elem_idx for i in para_sequences},
+                            **{md["path"]: elem_idx for md in para_sequences},
                         }
                     )
                 else:
@@ -2293,7 +2350,7 @@ class WorkflowTask:
                        new_elements.append(
                            {
                                **element,
-                                **{i["path"]: val_idx for i in para_sequences},
+                                **{md["path"]: val_idx for md in para_sequences},
                            }
                        )
             element_dat_idx = new_elements
@@ -2302,7 +2359,7 @@ class WorkflowTask:
         return element_dat_idx

     @TimeIt.decorator
-    def initialise_EARs(self, iter_IDs: Optional[List[int]] = None) -> List[int]:
+    def initialise_EARs(self, iter_IDs: list[int] | None = None) -> Sequence[int]:
         """Try to initialise any uninitialised EARs of this task."""
         if iter_IDs:
             iters = self.workflow.get_element_iterations_from_IDs(iter_IDs)
@@ -2317,16 +2374,16 @@ class WorkflowTask:
            # objects.
            iters.extend(element.iterations)
 
-        initialised = []
+        initialised: list[int] = []
        for iter_i in iters:
            if not iter_i.EARs_initialised:
                try:
-                    self._initialise_element_iter_EARs(iter_i)
+                    self.__initialise_element_iter_EARs(iter_i)
                    initialised.append(iter_i.id_)
                except UnsetParameterDataError:
                    # raised by `Action.test_rules`; cannot yet initialise EARs
-                    self.app.logger.debug(
-                        f"UnsetParameterDataError raised: cannot yet initialise runs."
+                    self._app.logger.debug(
+                        "UnsetParameterDataError raised: cannot yet initialise runs."
                    )
                    pass
                else:
@@ -2335,13 +2392,13 @@ class WorkflowTask:
        return initialised
 
    @TimeIt.decorator
-    def _initialise_element_iter_EARs(self, element_iter: app.ElementIteration) -> None:
+    def __initialise_element_iter_EARs(self, element_iter: ElementIteration) -> None:
        # keys are (act_idx, EAR_idx):
-        all_data_idx = {}
-        action_runs = {}
+        all_data_idx: dict[tuple[int, int], DataIndex] = {}
+        action_runs: dict[tuple[int, int], dict[str, Any]] = {}
 
        # keys are parameter indices, values are EAR_IDs to update those sources to
-        param_src_updates = {}
+        param_src_updates: dict[int, ParamSource] = {}
 
        count = 0
        for act_idx, action in self.template.all_schema_actions():
@@ -2355,9 +2412,9 @@ class WorkflowTask:
            # it previously failed)
            act_valid, cmds_idx = action.test_rules(element_iter=element_iter)
            if act_valid:
-                self.app.logger.info(f"All action rules evaluated to true {log_common}")
+                self._app.logger.info(f"All action rules evaluated to true {log_common}")
                EAR_ID = self.workflow.num_EARs + count
-                param_source = {
+                param_source: ParamSource = {
                    "type": "EAR_output",
                    "EAR_ID": EAR_ID,
                }
@@ -2374,17 +2431,19 @@ class WorkflowTask:
                # with EARs initialised, we can update the pre-allocated schema-level
                # parameters with the correct EAR reference:
                for i in psrc_update:
-                    param_src_updates[i] = {"EAR_ID": EAR_ID}
+                    param_src_updates[cast("int", i)] = {"EAR_ID": EAR_ID}
                run_0 = {
                    "elem_iter_ID": element_iter.id_,
                    "action_idx": act_idx,
                    "commands_idx": cmds_idx,
                    "metadata": {},
                }
-                action_runs[(act_idx, EAR_ID)] = run_0
+                action_runs[act_idx, EAR_ID] = run_0
                count += 1
            else:
-                self.app.logger.info(f"Some action rules evaluated to false {log_common}")
+                self._app.logger.info(
+                    f"Some action rules evaluated to false {log_common}"
+                )
 
        # `generate_data_index` can modify data index for previous actions, so only assign
        # this at the end:
@@ -2393,14 +2452,13 @@ class WorkflowTask:
                elem_iter_ID=element_iter.id_,
                action_idx=act_idx,
                commands_idx=run["commands_idx"],
-                data_idx=all_data_idx[(act_idx, EAR_ID_i)],
-                metadata={},
+                data_idx=all_data_idx[act_idx, EAR_ID_i],
            )
 
        self.workflow._store.update_param_source(param_src_updates)
 
    @TimeIt.decorator
-    def _add_element_set(self, element_set):
+    def _add_element_set(self, element_set: ElementSet) -> list[int]:
        """
        Returns
        -------
@@ -2414,7 +2472,7 @@ class WorkflowTask:
        # may modify element_set.input_sources:
        padded_elem_iters = self.ensure_input_sources(element_set)
 
-        (input_data_idx, seq_idx, src_idx) = self._make_new_elements_persistent(
+        (input_data_idx, seq_idx, src_idx) = self.__make_new_elements_persistent(
            element_set=element_set,
            element_set_idx=self.num_element_sets,
            padded_elem_iters=padded_elem_iters,
@@ -2448,10 +2506,10 @@ class WorkflowTask:
            src_idx,
        )
 
-        iter_IDs = []
-        elem_IDs = []
+        iter_IDs: list[int] = []
+        elem_IDs: list[int] = []
        for elem_idx, data_idx in enumerate(element_data_idx):
-            schema_params = set(i for i in data_idx.keys() if len(i.split(".")) == 2)
+            schema_params = set(i for i in data_idx if len(i.split(".")) == 2)
            elem_ID_i = self.workflow._store.add_element(
                task_ID=self.insert_ID,
                es_idx=self.num_element_sets - 1,
@@ -2471,21 +2529,72 @@ class WorkflowTask:
        return iter_IDs
 
 
+    @overload
+    def add_elements(
+        self,
+        *,
+        base_element: Element | None = None,
+        inputs: list[InputValue] | dict[str, Any] | None = None,
+        input_files: list[InputFile] | None = None,
+        sequences: list[ValueSequence] | None = None,
+        resources: Resources = None,
+        repeats: list[RepeatsDescriptor] | int | None = None,
+        input_sources: dict[str, list[InputSource]] | None = None,
+        nesting_order: dict[str, float] | None = None,
+        element_sets: list[ElementSet] | None = None,
+        sourceable_elem_iters: list[int] | None = None,
+        propagate_to: (
+            list[ElementPropagation]
+            | Mapping[str, ElementPropagation | Mapping[str, Any]]
+            | None
+        ) = None,
+        return_indices: Literal[True],
+    ) -> list[int]:
+        ...
+
+    @overload
+    def add_elements(
+        self,
+        *,
+        base_element: Element | None = None,
+        inputs: list[InputValue] | dict[str, Any] | None = None,
+        input_files: list[InputFile] | None = None,
+        sequences: list[ValueSequence] | None = None,
+        resources: Resources = None,
+        repeats: list[RepeatsDescriptor] | int | None = None,
+        input_sources: dict[str, list[InputSource]] | None = None,
+        nesting_order: dict[str, float] | None = None,
+        element_sets: list[ElementSet] | None = None,
+        sourceable_elem_iters: list[int] | None = None,
+        propagate_to: (
+            list[ElementPropagation]
+            | Mapping[str, ElementPropagation | Mapping[str, Any]]
+            | None
+        ) = None,
+        return_indices: Literal[False] = False,
+    ) -> None:
+        ...
+
    def add_elements(
        self,
-        base_element=None,
-        inputs=None,
-        input_files=None,
-        sequences=None,
-        resources=None,
-        repeats=None,
-        input_sources=None,
-        nesting_order=None,
-        element_sets=None,
-        sourceable_elem_iters=None,
-        propagate_to=None,
+        *,
+        base_element: Element | None = None,
+        inputs: list[InputValue] | dict[str, Any] | None = None,
+        input_files: list[InputFile] | None = None,
+        sequences: list[ValueSequence] | None = None,
+        resources: Resources = None,
+        repeats: list[RepeatsDescriptor] | int | None = None,
+        input_sources: dict[str, list[InputSource]] | None = None,
+        nesting_order: dict[str, float] | None = None,
+        element_sets: list[ElementSet] | None = None,
+        sourceable_elem_iters: list[int] | None = None,
+        propagate_to: (
+            list[ElementPropagation]
+            | Mapping[str, ElementPropagation | Mapping[str, Any]]
+            | None
+        ) = None,
        return_indices=False,
-    ):
+    ) -> list[int] | None:
        """
        Add elements to this task.
 
@@ -2502,11 +2611,11 @@ class WorkflowTask:
            default.
 
        """
-        propagate_to = self.app.ElementPropagation._prepare_propagate_to_dict(
+        real_propagate_to = self._app.ElementPropagation._prepare_propagate_to_dict(
            propagate_to, self.workflow
        )
        with self.workflow.batch_update():
-            return self._add_elements(
+            indices = self._add_elements(
                base_element=base_element,
                inputs=inputs,
                input_files=input_files,
@@ -2517,27 +2626,37 @@ class WorkflowTask:
                nesting_order=nesting_order,
                element_sets=element_sets,
                sourceable_elem_iters=sourceable_elem_iters,
-                propagate_to=propagate_to,
-                return_indices=return_indices,
+                propagate_to=real_propagate_to,
            )
+        return indices if return_indices else None
 
    @TimeIt.decorator
    def _add_elements(
        self,
-        base_element=None,
-        inputs=None,
-        input_files=None,
-        sequences=None,
-        resources=None,
-        repeats=None,
-        input_sources=None,
-        nesting_order=None,
-        element_sets=None,
-        sourceable_elem_iters=None,
-        propagate_to: Dict[str, app.ElementPropagation] = None,
-        return_indices: bool = False,
-    ):
-        """Add more elements to this task."""
+        *,
+        base_element: Element | None = None,
+        inputs: list[InputValue] | dict[str, Any] | None = None,
+        input_files: list[InputFile] | None = None,
+        sequences: list[ValueSequence] | None = None,
+        resources: Resources = None,
+        repeats: list[RepeatsDescriptor] | int | None = None,
+        input_sources: dict[str, list[InputSource]] | None = None,
+        nesting_order: dict[str, float] | None = None,
+        element_sets: list[ElementSet] | None = None,
+        sourceable_elem_iters: list[int] | None = None,
+        propagate_to: dict[str, ElementPropagation],
+    ) -> list[int] | None:
+        """Add more elements to this task.
+
+        Parameters
+        ----------
+        sourceable_elem_iters : list[int]
+            If specified, a list of global element iteration indices from which inputs
+            may be sourced. If not specified, all workflow element iterations are
+            considered sourceable.
+        propagate_to : dict[str, ElementPropagation]
+            Propagate the new elements downstream to the specified tasks.
+        """
 
        if base_element is not None:
            if base_element.task is not self:
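
Reviewer note: the keyword-only signature plus the `return_indices` overloads above change how callers invoke this method. A hypothetical usage sketch (the workflow `wk` and the task/input names are invented; only the keyword-argument shapes come from the signature above):

    # All arguments are now keyword-only; a plain dict is accepted for `inputs`,
    # and a dict-of-dicts is accepted for `propagate_to`.
    task = wk.tasks.simulate
    new_iter_ids = task.add_elements(
        inputs={"p1": 101},
        nesting_order={"inputs.p1": 0.0},
        propagate_to={"analyse": {}},   # normalised to ElementPropagation
        return_indices=True,            # typed as list[int] via the overload
    )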
@@ -2546,7 +2665,7 @@ class WorkflowTask:
            inputs = inputs or b_inputs
            resources = resources or b_resources
 
-        element_sets = self.app.ElementSet.ensure_element_sets(
+        element_sets = self._app.ElementSet.ensure_element_sets(
            inputs=inputs,
            input_files=input_files,
            sequences=sequences,
@@ -2558,24 +2677,22 @@ class WorkflowTask:
            sourceable_elem_iters=sourceable_elem_iters,
        )
 
-        elem_idx = []
+        elem_idx: list[int] = []
        for elem_set_i in element_sets:
-            # copy:
-            elem_set_i = elem_set_i.prepare_persistent_copy()
+            # copy and add the new element set:
+            elem_idx.extend(self._add_element_set(elem_set_i.prepare_persistent_copy()))
 
-            # add the new element set:
-            elem_idx += self._add_element_set(elem_set_i)
+        if not propagate_to:
+            return elem_idx
 
        for task in self.get_dependent_tasks(as_objects=True):
-            elem_prop = propagate_to.get(task.unique_name)
-            if elem_prop is None:
+            if (elem_prop := propagate_to.get(task.unique_name)) is None:
                continue
 
-            task_dep_names = [
-                i.unique_name
-                for i in elem_prop.element_set.get_task_dependencies(as_objects=True)
-            ]
-            if self.unique_name not in task_dep_names:
+            if all(
+                self.unique_name != task.unique_name
+                for task in elem_prop.element_set.get_task_dependencies(as_objects=True)
+            ):
                # TODO: why can't we just do
                # `if self in not elem_propagate.element_set.task_dependencies:`?
                continue
@@ -2584,11 +2701,11 @@ class WorkflowTask:
            # Assume for now we use a single base element set.
            # Later, allow combining multiple element sets.
            src_elem_iters = elem_idx + [
-                j for i in element_sets for j in i.sourceable_elem_iters or []
+                j for el_set in element_sets for j in el_set.sourceable_elem_iters or ()
            ]
 
            # note we must pass `resources` as a list since it is already persistent:
-            elem_set_i = self.app.ElementSet(
+            elem_set_i = self._app.ElementSet(
                inputs=elem_prop.element_set.inputs,
                input_files=elem_prop.element_set.input_files,
                sequences=elem_prop.element_set.sequences,
@@ -2602,40 +2719,57 @@ class WorkflowTask:
            del propagate_to[task.unique_name]
            prop_elem_idx = task._add_elements(
                element_sets=[elem_set_i],
-                return_indices=True,
                propagate_to=propagate_to,
            )
-            elem_idx.extend(prop_elem_idx)
+            elem_idx.extend(prop_elem_idx or ())
 
-        if return_indices:
-            return elem_idx
+        return elem_idx
+
+    @overload
+    def get_element_dependencies(
+        self,
+        as_objects: Literal[False] = False,
+    ) -> set[int]:
+        ...
+
+    @overload
+    def get_element_dependencies(
+        self,
+        as_objects: Literal[True],
+    ) -> list[Element]:
+        ...
 
    def get_element_dependencies(
        self,
        as_objects: bool = False,
-    ) -> List[Union[int, app.Element]]:
+    ) -> set[int] | list[Element]:
        """Get elements from upstream tasks that this task depends on."""
 
-        deps = []
-        for element in self.elements[:]:
+        deps: set[int] = set()
+        for element in self.elements:
            for iter_i in element.iterations:
-                for dep_elem_i in iter_i.get_element_dependencies(as_objects=True):
-                    if (
-                        dep_elem_i.task.insert_ID != self.insert_ID
-                        and dep_elem_i not in deps
-                    ):
-                        deps.append(dep_elem_i.id_)
+                deps.update(
+                    dep_elem_i.id_
+                    for dep_elem_i in iter_i.get_element_dependencies(as_objects=True)
+                    if dep_elem_i.task.insert_ID != self.insert_ID
+                )
 
-        deps = sorted(deps)
        if as_objects:
-            deps = self.workflow.get_elements_from_IDs(deps)
-
+            return self.workflow.get_elements_from_IDs(sorted(deps))
        return deps
 
+    @overload
+    def get_task_dependencies(self, as_objects: Literal[False] = False) -> set[int]:
+        ...
+
+    @overload
+    def get_task_dependencies(self, as_objects: Literal[True]) -> list[WorkflowTask]:
+        ...
+
    def get_task_dependencies(
        self,
        as_objects: bool = False,
-    ) -> List[Union[int, app.WorkflowTask]]:
+    ) -> set[int] | list[WorkflowTask]:
        """Get tasks (insert ID or WorkflowTask objects) that this task depends on.
 
        Dependencies may come from either elements from upstream tasks, or from locally
@@ -2645,87 +2779,106 @@ class WorkflowTask:
        # new "task_iteration" input source type, which may take precedence over any
        # other input source types.
 
-        deps = []
+        deps: set[int] = set()
        for element_set in self.template.element_sets:
            for sources in element_set.input_sources.values():
-                for src in sources:
+                deps.update(
+                    src.task_ref
+                    for src in sources
                    if (
                        src.source_type is InputSourceType.TASK
-                        and src.task_ref not in deps
-                    ):
-                        deps.append(src.task_ref)
+                        and src.task_ref is not None
+                    )
+                )
 
-        deps = sorted(deps)
        if as_objects:
-            deps = [self.workflow.tasks.get(insert_ID=i) for i in deps]
-
+            return [self.workflow.tasks.get(insert_ID=id_) for id_ in sorted(deps)]
        return deps
 
+    @overload
+    def get_dependent_elements(
+        self,
+        as_objects: Literal[False] = False,
+    ) -> set[int]:
+        ...
+
+    @overload
+    def get_dependent_elements(self, as_objects: Literal[True]) -> list[Element]:
+        ...
+
    def get_dependent_elements(
        self,
        as_objects: bool = False,
-    ) -> List[Union[int, app.Element]]:
+    ) -> set[int] | list[Element]:
        """Get elements from downstream tasks that depend on this task."""
-        deps = []
+        deps: set[int] = set()
        for task in self.downstream_tasks:
-            for element in task.elements[:]:
-                for iter_i in element.iterations:
-                    for dep_i in iter_i.get_task_dependencies(as_objects=False):
-                        if dep_i == self.insert_ID and element.id_ not in deps:
-                            deps.append(element.id_)
+            deps.update(
+                element.id_
+                for element in task.elements
+                if any(
+                    self.insert_ID in iter_i.get_task_dependencies()
+                    for iter_i in element.iterations
+                )
+            )
 
-        deps = sorted(deps)
        if as_objects:
-            deps = self.workflow.get_elements_from_IDs(deps)
-
+            return self.workflow.get_elements_from_IDs(sorted(deps))
        return deps
 
+    @overload
+    def get_dependent_tasks(self, as_objects: Literal[False] = False) -> set[int]:
+        ...
+
+    @overload
+    def get_dependent_tasks(self, as_objects: Literal[True]) -> list[WorkflowTask]:
+        ...
+
    @TimeIt.decorator
    def get_dependent_tasks(
        self,
        as_objects: bool = False,
-    ) -> List[Union[int, app.WorkflowTask]]:
+    ) -> set[int] | list[WorkflowTask]:
        """Get tasks (insert ID or WorkflowTask objects) that depend on this task."""
 
        # TODO: this method might become insufficient if/when we start considering a
        # new "task_iteration" input source type, which may take precedence over any
        # other input source types.
 
-        deps = []
+        deps: set[int] = set()
        for task in self.downstream_tasks:
-            for element_set in task.template.element_sets:
-                for sources in element_set.input_sources.values():
-                    for src in sources:
-                        if (
-                            src.source_type is InputSourceType.TASK
-                            and src.task_ref == self.insert_ID
-                            and task.insert_ID not in deps
-                        ):
-                            deps.append(task.insert_ID)
-        deps = sorted(deps)
+            if task.insert_ID not in deps and any(
+                src.source_type is InputSourceType.TASK and src.task_ref == self.insert_ID
+                for element_set in task.template.element_sets
+                for sources in element_set.input_sources.values()
+                for src in sources
+            ):
+                deps.add(task.insert_ID)
        if as_objects:
-            deps = [self.workflow.tasks.get(insert_ID=i) for i in deps]
+            return [self.workflow.tasks.get(insert_ID=id_) for id_ in sorted(deps)]
        return deps
 
    @property
-    def inputs(self):
+    def inputs(self) -> TaskInputParameters:
        """
        Inputs to this task.
        """
-        return self.app.TaskInputParameters(self)
+        return self._app.TaskInputParameters(self)
 
    @property
-    def outputs(self):
+    def outputs(self) -> TaskOutputParameters:
        """
        Outputs from this task.
        """
-        return self.app.TaskOutputParameters(self)
+        return self._app.TaskOutputParameters(self)
 
-    def get(self, path, raise_on_missing=False, default=None):
+    def get(
+        self, path: str, *, raise_on_missing=False, default: Any | None = None
+    ) -> Parameters:
        """
        Get a parameter known to this task by its path.
        """
-        return self.app.Parameters(
+        return self._app.Parameters(
            self,
            path=path,
            return_element_parameters=False,
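
Reviewer note: the `inputs`/`outputs` properties and `get` in this hunk are the user-facing accessors over these typed namespaces. A hypothetical REPL-style sketch (task and parameter names invented; the accessor shapes follow the classes defined later in this diff):

    task.inputs                      # e.g. TaskInputParameters('p1', 'p2')
    vals = task.inputs.p1[:]         # per-element parameters for input "p1"
    first = task.outputs.energy[0]   # first element's "energy" output
    task.get("inputs.p1")            # dotted-path lookup across all elements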
@@ -2733,12 +2886,15 @@ class WorkflowTask:
            default=default,
        )
 
-    def _paths_to_PV_classes(self, paths: Iterable[str]) -> Dict:
+    def _paths_to_PV_classes(self, *paths: str | None) -> dict[str, type[ParameterValue]]:
        """Return a dict mapping dot-delimited string input paths to `ParameterValue`
        classes."""
 
-        params = {}
+        params: dict[str, type[ParameterValue]] = {}
        for path in paths:
+            if not path:
+                # Skip None/empty
+                continue
            path_split = path.split(".")
            if len(path_split) == 1 or path_split[0] not in ("inputs", "outputs"):
                continue
@@ -2751,22 +2907,22 @@ class WorkflowTask:
                path_1, _ = split_param_label(
                    path_split[1]
                )  # remove label if present
-                for i in self.template.schemas:
-                    for j in i.inputs:
-                        if j.parameter.typ == path_1 and j.parameter._value_class:
-                            params[key_0] = j.parameter._value_class
+                for schema in self.template.schemas:
+                    for inp in schema.inputs:
+                        if inp.parameter.typ == path_1 and inp.parameter._value_class:
+                            params[key_0] = inp.parameter._value_class
 
            elif path_split[0] == "outputs":
-                for i in self.template.schemas:
-                    for j in i.outputs:
+                for schema in self.template.schemas:
+                    for out in schema.outputs:
                        if (
-                            j.parameter.typ == path_split[1]
-                            and j.parameter._value_class
+                            out.parameter.typ == path_split[1]
+                            and out.parameter._value_class
                        ):
-                            params[key_0] = j.parameter._value_class
+                            params[key_0] = out.parameter._value_class
 
            if path_split[2:]:
-                pv_classes = ParameterValue.__subclasses__()
+                pv_classes = {cls._typ: cls for cls in ParameterValue.__subclasses__()}
 
                # now proceed by searching for sub-parameters in each ParameterValue
                # sub-class:
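
Reviewer note: the change just above replaces a linear scan over `ParameterValue.__subclasses__()` with a one-off registry keyed on each subclass's `_typ`, so every sub-parameter lookup becomes a dict access. The same pattern in isolation (class names invented; note it only sees direct subclasses, as in the original):

    class ParamBase:
        _typ: str = "base"

    class Orientation(ParamBase):
        _typ = "orientation"

    class Strain(ParamBase):
        _typ = "strain"

    # Build the registry once, then each lookup is O(1) instead of a scan:
    registry = {cls._typ: cls for cls in ParamBase.__subclasses__()}
    assert registry["strain"] is Strain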
@@ -2776,306 +2932,332 @@ class WorkflowTask:
                    key_i = ".".join(child)
                    if key_i in params:
                        continue
-                    parent_param = params.get(".".join(parent))
-                    if parent_param:
+                    if parent_param := params.get(".".join(parent)):
                        for attr_name, sub_type in parent_param._sub_parameters.items():
                            if part_i == attr_name:
                                # find the class with this `typ` attribute:
-                                for cls_i in pv_classes:
-                                    if cls_i._typ == sub_type:
-                                        params[key_i] = cls_i
-                                        break
+                                if cls := pv_classes.get(sub_type):
+                                    params[key_i] = cls
 
        return params
 
-    @TimeIt.decorator
-    def _get_merged_parameter_data(
-        self,
-        data_index,
-        path=None,
-        raise_on_missing=False,
-        raise_on_unset=False,
-        default: Any = None,
-    ):
-        """Get element data from the persistent store."""
-
-        def _get_relevant_paths(
-            data_index: Dict, path: List[str], children_of: str = None
-        ):
-            relevant_paths = {}
-            # first extract out relevant paths in `data_index`:
-            for path_i in data_index:
-                path_i_split = path_i.split(".")
+    @staticmethod
+    def __get_relevant_paths(
+        data_index: Mapping[str, Any], path: list[str], children_of: str | None = None
+    ) -> Mapping[str, RelevantPath]:
+        relevant_paths: dict[str, RelevantPath] = {}
+        # first extract out relevant paths in `data_index`:
+        for path_i in data_index:
+            path_i_split = path_i.split(".")
+            try:
+                rel_path = get_relative_path(path, path_i_split)
+                relevant_paths[path_i] = {"type": "parent", "relative_path": rel_path}
+            except ValueError:
                try:
-                    rel_path = get_relative_path(path, path_i_split)
-                    relevant_paths[path_i] = {"type": "parent", "relative_path": rel_path}
+                    update_path = get_relative_path(path_i_split, path)
+                    relevant_paths[path_i] = {
+                        "type": "update",
+                        "update_path": update_path,
+                    }
                except ValueError:
-                    try:
-                        update_path = get_relative_path(path_i_split, path)
-                        relevant_paths[path_i] = {
-                            "type": "update",
-                            "update_path": update_path,
-                        }
-                    except ValueError:
-                        # no intersection between paths
-                        if children_of and path_i.startswith(children_of):
-                            relevant_paths[path_i] = {"type": "sibling"}
-                        continue
-
-            return relevant_paths
-
-        def _get_relevant_data(relevant_data_idx: Dict, raise_on_unset: bool, path: str):
-            relevant_data = {}
-            for path_i, data_idx_i in relevant_data_idx.items():
-                is_multi = isinstance(data_idx_i, list)
-                if not is_multi:
-                    data_idx_i = [data_idx_i]
-
-                data_i = []
-                methods_i = []
-                is_param_set_i = []
-                for data_idx_ij in data_idx_i:
-                    meth_i = None
-                    is_set_i = True
-                    if path_i[0] == "repeats":
-                        # data is an integer repeats index, rather than a parameter ID:
-                        data_j = data_idx_ij
-                    else:
-                        param_j = self.workflow.get_parameter(data_idx_ij)
-                        is_set_i = param_j.is_set
-                        if param_j.file:
-                            if param_j.file["store_contents"]:
-                                data_j = Path(self.workflow.path) / param_j.file["path"]
-                            else:
-                                data_j = Path(param_j.file["path"])
-                            data_j = data_j.as_posix()
-                        else:
-                            meth_i = param_j.source.get("value_class_method")
-                            if param_j.is_pending:
-                                # if pending, we need to convert `ParameterValue` objects
-                                # to their dict representation, so they can be merged with
-                                # other data:
-                                try:
-                                    data_j = param_j.data.to_dict()
-                                except AttributeError:
-                                    data_j = param_j.data
-                            else:
-                                # if not pending, data will be the result of an encode-
-                                # decode cycle, and it will not be initialised as an
-                                # object if the parameter is associated with a
-                                # `ParameterValue` class.
-                                data_j = param_j.data
-                    if raise_on_unset and not is_set_i:
-                        raise UnsetParameterDataError(
-                            f"Element data path {path!r} resolves to unset data for "
-                            f"(at least) data-index path: {path_i!r}."
-                        )
-                    if is_multi:
-                        data_i.append(data_j)
-                        methods_i.append(meth_i)
-                        is_param_set_i.append(is_set_i)
-                    else:
-                        data_i = data_j
-                        methods_i = meth_i
-                        is_param_set_i = is_set_i
+                    # no intersection between paths
+                    if children_of and path_i.startswith(children_of):
+                        relevant_paths[path_i] = {"type": "sibling"}
+                    continue
 
+        return relevant_paths
+
+    def __get_relevant_data_item(
+        self, path: str | None, path_i: str, data_idx_ij: int, raise_on_unset: bool
+    ) -> tuple[Any, bool, str | None]:
+        if path_i.startswith("repeats."):
+            # data is an integer repeats index, rather than a parameter ID:
+            return data_idx_ij, True, None
+
+        meth_i: str | None = None
+        data_j: Any
+        param_j = self.workflow.get_parameter(data_idx_ij)
+        is_set_i = param_j.is_set
+        if param_j.file:
+            if param_j.file["store_contents"]:
+                file_j = Path(self.workflow.path) / param_j.file["path"]
+            else:
+                file_j = Path(param_j.file["path"])
+            data_j = file_j.as_posix()
+        else:
+            meth_i = param_j.source.get("value_class_method")
+            if param_j.is_pending:
+                # if pending, we need to convert `ParameterValue` objects
+                # to their dict representation, so they can be merged with
+                # other data:
+                try:
+                    data_j = cast("ParameterValue", param_j.data).to_dict()
+                except AttributeError:
+                    data_j = param_j.data
+            else:
+                # if not pending, data will be the result of an encode-
+                # decode cycle, and it will not be initialised as an
+                # object if the parameter is associated with a
+                # `ParameterValue` class.
+                data_j = param_j.data
+        if raise_on_unset and not is_set_i:
+            raise UnsetParameterDataError(path, path_i)
+        return data_j, is_set_i, meth_i
+
+    def __get_relevant_data(
+        self,
+        relevant_data_idx: Mapping[str, list[int] | int],
+        raise_on_unset: bool,
+        path: str | None,
+    ) -> Mapping[str, RelevantData]:
+        relevant_data: dict[str, RelevantData] = {}
+        for path_i, data_idx_i in relevant_data_idx.items():
+            if not isinstance(data_idx_i, list):
+                data, is_set, meth = self.__get_relevant_data_item(
+                    path, path_i, data_idx_i, raise_on_unset
+                )
                relevant_data[path_i] = {
-                    "data": data_i,
-                    "value_class_method": methods_i,
-                    "is_set": is_param_set_i,
-                    "is_multi": is_multi,
-                }
-            if not raise_on_unset:
-                to_remove = []
-                for key, dat_info in relevant_data.items():
-                    if not dat_info["is_set"] and ((path and path in key) or not path):
-                        # remove sub-paths, as they cannot be merged with this parent
-                        to_remove.extend(
-                            k for k in relevant_data if k != key and k.startswith(key)
-                        )
-                relevant_data = {
-                    k: v for k, v in relevant_data.items() if k not in to_remove
+                    "data": data,
+                    "value_class_method": meth,
+                    "is_set": is_set,
+                    "is_multi": False,
                }
+                continue
 
-            return relevant_data
+            data_i: list[Any] = []
+            methods_i: list[str | None] = []
+            is_param_set_i: list[bool] = []
+            for data_idx_ij in data_idx_i:
+                data_j, is_set_i, meth_i = self.__get_relevant_data_item(
+                    path, path_i, data_idx_ij, raise_on_unset
+                )
+                data_i.append(data_j)
+                methods_i.append(meth_i)
+                is_param_set_i.append(is_set_i)
+
+            relevant_data[path_i] = {
+                "data": data_i,
+                "value_class_method": methods_i,
+                "is_set": is_param_set_i,
+                "is_multi": True,
+            }
+        if not raise_on_unset:
+            to_remove: set[str] = set()
+            for key, dat_info in relevant_data.items():
+                if not dat_info["is_set"] and (not path or path in key):
+                    # remove sub-paths, as they cannot be merged with this parent
+                    prefix = f"{key}."
+                    to_remove.update(k for k in relevant_data if k.startswith(prefix))
+            for key in to_remove:
+                relevant_data.pop(key, None)
+
+        return relevant_data
 
-        def _merge_relevant_data(
-            relevant_data, relevant_paths, PV_classes, path: str, raise_on_missing: bool
-        ):
-            current_val = None
-            assigned_from_parent = False
-            val_cls_method = None
-            path_is_multi = False
-            path_is_set = False
-            all_multi_len = None
-            for path_i, data_info_i in relevant_data.items():
-                data_i = data_info_i["data"]
-                if path_i == path:
-                    val_cls_method = data_info_i["value_class_method"]
-                    path_is_multi = data_info_i["is_multi"]
-                    path_is_set = data_info_i["is_set"]
-
-                if data_info_i["is_multi"]:
-                    if all_multi_len:
-                        if len(data_i) != all_multi_len:
-                            raise RuntimeError(
-                                f"Cannot merge group values of different lengths."
-                            )
-                    else:
-                        # keep track of group lengths, only merge equal-length groups;
-                        all_multi_len = len(data_i)
-
-                path_info = relevant_paths[path_i]
-                path_type = path_info["type"]
-                if path_type == "parent":
-                    try:
-                        if data_info_i["is_multi"]:
-                            current_val = [
-                                get_in_container(
-                                    i,
-                                    path_info["relative_path"],
-                                    cast_indices=True,
-                                )
-                                for i in data_i
-                            ]
-                            path_is_multi = True
-                            path_is_set = data_info_i["is_set"]
-                            val_cls_method = data_info_i["value_class_method"]
-                        else:
-                            current_val = get_in_container(
-                                data_i,
+    @classmethod
+    def __merge_relevant_data(
+        cls,
+        relevant_data: Mapping[str, RelevantData],
+        relevant_paths: Mapping[str, RelevantPath],
+        PV_classes,
+        path: str | None,
+        raise_on_missing: bool,
+    ):
+        current_val: list | dict | Any | None = None
+        assigned_from_parent = False
+        val_cls_method: str | None | list[str | None] = None
+        path_is_multi = False
+        path_is_set: bool | list[bool] = False
+        all_multi_len: int | None = None
+        for path_i, data_info_i in relevant_data.items():
+            data_i = data_info_i["data"]
+            if path_i == path:
+                val_cls_method = data_info_i["value_class_method"]
+                path_is_multi = data_info_i["is_multi"]
+                path_is_set = data_info_i["is_set"]
+
+            if data_info_i["is_multi"]:
+                if all_multi_len:
+                    if len(data_i) != all_multi_len:
+                        raise RuntimeError(
+                            "Cannot merge group values of different lengths."
+                        )
+                else:
+                    # keep track of group lengths, only merge equal-length groups;
+                    all_multi_len = len(data_i)
+
+            path_info = relevant_paths[path_i]
+            if path_info["type"] == "parent":
+                try:
+                    if data_info_i["is_multi"]:
+                        current_val = [
+                            get_in_container(
+                                item,
                                path_info["relative_path"],
                                cast_indices=True,
                            )
-                    except ContainerKeyError as err:
-                        if path_i in PV_classes:
-                            err_path = ".".join([path_i] + err.path[:-1])
-                            raise MayNeedObjectError(path=err_path)
-                        continue
-                    except (IndexError, ValueError) as err:
-                        if raise_on_missing:
-                            raise err
-                        continue
-                    else:
-                        assigned_from_parent = True
-                elif path_type == "update":
-                    current_val = current_val or {}
-                    if all_multi_len:
-                        if len(path_i.split(".")) == 2:
-                            # groups can only be "created" at the parameter level
-                            set_in_container(
-                                cont=current_val,
-                                path=path_info["update_path"],
-                                value=data_i,
-                                ensure_path=True,
-                                cast_indices=True,
-                            )
-                        else:
-                            # update group
-                            update_path = path_info["update_path"]
-                            if len(update_path) > 1:
-                                for idx, j in enumerate(data_i):
-                                    path_ij = update_path[0:1] + [idx] + update_path[1:]
-                                    set_in_container(
-                                        cont=current_val,
-                                        path=path_ij,
-                                        value=j,
-                                        ensure_path=True,
-                                        cast_indices=True,
-                                    )
-                            else:
-                                for idx, (i, j) in enumerate(zip(current_val, data_i)):
-                                    set_in_container(
-                                        cont=i,
-                                        path=update_path,
-                                        value=j,
-                                        ensure_path=True,
-                                        cast_indices=True,
-                                    )
-
+                            for item in data_i
+                        ]
+                        path_is_multi = True
+                        path_is_set = data_info_i["is_set"]
+                        val_cls_method = data_info_i["value_class_method"]
                    else:
-                        set_in_container(
-                            current_val,
-                            path_info["update_path"],
+                        current_val = get_in_container(
                            data_i,
+                            path_info["relative_path"],
+                            cast_indices=True,
+                        )
+                except ContainerKeyError as err:
+                    if path_i in PV_classes:
+                        raise MayNeedObjectError(path=".".join([path_i, *err.path[:-1]]))
+                    continue
+                except (IndexError, ValueError) as err:
+                    if raise_on_missing:
+                        raise err
+                    continue
+                else:
+                    assigned_from_parent = True
+            elif path_info["type"] == "update":
+                current_val = current_val or {}
+                if all_multi_len:
+                    if len(path_i.split(".")) == 2:
+                        # groups can only be "created" at the parameter level
+                        set_in_container(
+                            cont=current_val,
+                            path=path_info["update_path"],
+                            value=data_i,
                            ensure_path=True,
                            cast_indices=True,
                        )
-        if path in PV_classes:
-            if path not in relevant_data:
-                # requested data must be a sub-path of relevant data, so we can assume
-                # path is set (if the parent was not set the sub-paths would be
-                # removed in `_get_relevant_data`):
-                path_is_set = path_is_set or True
-
-            if not assigned_from_parent:
-                # search for unset parents in `relevant_data`:
-                path_split = path.split(".")
-                for parent_i_span in range(len(path_split) - 1, 1, -1):
-                    parent_path_i = ".".join(path_split[0:parent_i_span])
-                    relevant_par = relevant_data.get(parent_path_i)
-                    par_is_set = relevant_par["is_set"]
-                    if not par_is_set or any(not i for i in par_is_set):
-                        val_cls_method = relevant_par["value_class_method"]
-                        path_is_multi = relevant_par["is_multi"]
-                        path_is_set = relevant_par["is_set"]
-                        current_val = relevant_par["data"]
-                        break
-
-            # initialise objects
-            PV_cls = PV_classes[path]
-            if path_is_multi:
-                _current_val = []
-                for set_i, meth_i, val_i in zip(
-                    path_is_set, val_cls_method, current_val
-                ):
-                    if set_i and isinstance(val_i, dict):
-                        method_i = getattr(PV_cls, meth_i) if meth_i else PV_cls
-                        _cur_val_i = method_i(**val_i)
+                    else:
+                        # update group
+                        update_path = path_info["update_path"]
+                        if len(update_path) > 1:
+                            for idx, j in enumerate(data_i):
+                                set_in_container(
+                                    cont=current_val,
+                                    path=[*update_path[:1], idx, *update_path[1:]],
+                                    value=j,
+                                    ensure_path=True,
+                                    cast_indices=True,
+                                )
                    else:
-                        _cur_val_i = None
-                    _current_val.append(_cur_val_i)
-                current_val = _current_val
-            elif path_is_set and isinstance(current_val, dict):
-                method = getattr(PV_cls, val_cls_method) if val_cls_method else PV_cls
-                current_val = method(**current_val)
+                        for i, j in zip(current_val, data_i):
+                            set_in_container(
+                                cont=i,
+                                path=update_path,
+                                value=j,
+                                ensure_path=True,
+                                cast_indices=True,
+                            )
 
-            return current_val, all_multi_len
+                else:
+                    set_in_container(
+                        current_val,
+                        path_info["update_path"],
+                        data_i,
+                        ensure_path=True,
+                        cast_indices=True,
+                    )
+        if path in PV_classes:
+            if path not in relevant_data:
+                # requested data must be a sub-path of relevant data, so we can assume
+                # path is set (if the parent was not set the sub-paths would be
+                # removed in `__get_relevant_data`):
+                path_is_set = path_is_set or True
+
+            if not assigned_from_parent:
+                # search for unset parents in `relevant_data`:
+                assert path is not None
+                for parent_i_span in range(
+                    len(path_split := path.split(".")) - 1, 1, -1
+                ):
+                    parent_path_i = ".".join(path_split[:parent_i_span])
+                    if not (relevant_par := relevant_data.get(parent_path_i)):
+                        continue
+                    if not (par_is_set := relevant_par["is_set"]) or not all(
+                        cast("list", par_is_set)
+                    ):
+                        val_cls_method = relevant_par["value_class_method"]
+                        path_is_multi = relevant_par["is_multi"]
+                        path_is_set = relevant_par["is_set"]
+                        current_val = relevant_par["data"]
+                        break
+
+            # initialise objects
+            PV_cls = PV_classes[path]
+            if path_is_multi:
+                current_val = [
+                    (
+                        cls.__map_parameter_value(PV_cls, meth_i, val_i)
+                        if set_i and isinstance(val_i, dict)
+                        else None
+                    )
+                    for set_i, meth_i, val_i in zip(
+                        cast("list[bool]", path_is_set),
+                        cast("list[str|None]", val_cls_method),
+                        cast("list[Any]", current_val),
+                    )
+                ]
+            elif path_is_set and isinstance(current_val, dict):
+                assert not isinstance(val_cls_method, list)
+                current_val = cls.__map_parameter_value(
+                    PV_cls, val_cls_method, current_val
+                )
+
+        return current_val, all_multi_len
 
-        # TODO: custom exception?
-        missing_err = ValueError(f"Path {path!r} does not exist in the element data.")
+    @staticmethod
+    def __map_parameter_value(
+        PV_cls: type[ParameterValue], meth: str | None, val: dict
+    ) -> Any | ParameterValue:
+        if meth:
+            method: Callable = getattr(PV_cls, meth)
+            return method(**val)
+        else:
+            return PV_cls(**val)
 
+    @TimeIt.decorator
+    def _get_merged_parameter_data(
+        self,
+        data_index: Mapping[str, list[int] | int],
+        path: str | None = None,
+        *,
+        raise_on_missing: bool = False,
+        raise_on_unset: bool = False,
+        default: Any | None = None,
+    ):
+        """Get element data from the persistent store."""
        path_split = [] if not path else path.split(".")
 
-        relevant_paths = _get_relevant_paths(data_index, path_split)
-        if not relevant_paths:
+        if not (relevant_paths := self.__get_relevant_paths(data_index, path_split)):
            if raise_on_missing:
-                raise missing_err
+                # TODO: custom exception?
+                raise ValueError(f"Path {path!r} does not exist in the element data.")
            return default
 
        relevant_data_idx = {k: v for k, v in data_index.items() if k in relevant_paths}
-        PV_cls_paths = list(relevant_paths.keys()) + ([path] if path else [])
-        PV_classes = self._paths_to_PV_classes(PV_cls_paths)
-        relevant_data = _get_relevant_data(relevant_data_idx, raise_on_unset, path)
+        PV_classes = self._paths_to_PV_classes(*relevant_paths, path)
+        relevant_data = self.__get_relevant_data(relevant_data_idx, raise_on_unset, path)
 
        current_val = None
        is_assigned = False
        try:
-            current_val, _ = _merge_relevant_data(
+            current_val, _ = self.__merge_relevant_data(
                relevant_data, relevant_paths, PV_classes, path, raise_on_missing
            )
        except MayNeedObjectError as err:
            path_to_init = err.path
            path_to_init_split = path_to_init.split(".")
-            relevant_paths = _get_relevant_paths(data_index, path_to_init_split)
-            PV_cls_paths = list(relevant_paths.keys()) + [path_to_init]
-            PV_classes = self._paths_to_PV_classes(PV_cls_paths)
+            relevant_paths = self.__get_relevant_paths(data_index, path_to_init_split)
+            PV_classes = self._paths_to_PV_classes(*relevant_paths, path_to_init)
            relevant_data_idx = {
                k: v for k, v in data_index.items() if k in relevant_paths
            }
-            relevant_data = _get_relevant_data(relevant_data_idx, raise_on_unset, path)
+            relevant_data = self.__get_relevant_data(
+                relevant_data_idx, raise_on_unset, path
+            )
            # merge the parent data
-            current_val, group_len = _merge_relevant_data(
+            current_val, group_len = self.__merge_relevant_data(
                relevant_data, relevant_paths, PV_classes, path_to_init, raise_on_missing
            )
            # try to retrieve attributes via the initialised object:
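
Reviewer note: the refactor above hinges on classifying every data-index key relative to the requested path as a "parent" (the key is an ancestor of the request), an "update" (the key sits beneath it), or a "sibling". A reduced sketch of that classification, with a local `relative_path` helper standing in for hpcflow's `get_relative_path` and the `classify` name invented:

    def relative_path(parent: list[str], child: list[str]) -> list[str]:
        # raise ValueError unless `child` extends (or equals) `parent`
        if child[: len(parent)] != parent:
            raise ValueError
        return child[len(parent):]

    def classify(data_index: list[str], path: str) -> dict[str, dict]:
        out: dict[str, dict] = {}
        target = path.split(".")
        for key in data_index:
            key_split = key.split(".")
            try:
                # key is an ancestor of the requested path
                out[key] = {"type": "parent",
                            "relative_path": relative_path(key_split, target)}
            except ValueError:
                try:
                    # key lies beneath the requested path
                    out[key] = {"type": "update",
                                "update_path": relative_path(target, key_split)}
                except ValueError:
                    pass  # no overlap at all
        return out

    classify(["inputs.p1", "inputs.p1.a.b", "inputs.p2"], "inputs.p1.a")
    # -> inputs.p1: parent (['a']); inputs.p1.a.b: update (['b']); p2 dropped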
@@ -3084,12 +3266,12 @@ class WorkflowTask:
            if group_len:
                current_val = [
                    get_in_container(
-                        cont=i,
+                        cont=item,
                        path=rel_path_split,
                        cast_indices=True,
                        allow_getattr=True,
                    )
-                    for i in current_val
+                    for item in current_val
                ]
            else:
                current_val = get_in_container(
@@ -3110,7 +3292,8 @@ class WorkflowTask:
 
        if not is_assigned:
            if raise_on_missing:
-                raise missing_err
+                # TODO: custom exception?
+                raise ValueError(f"Path {path!r} does not exist in the element data.")
            current_val = default
 
        return current_val
@@ -3128,7 +3311,7 @@ class Elements:
 
    __slots__ = ("_task",)
 
-    def __init__(self, task: app.WorkflowTask):
+    def __init__(self, task: WorkflowTask):
        self._task = task
 
    # TODO: cache Element objects
@@ -3140,44 +3323,57 @@ class Elements:
        )
 
    @property
-    def task(self):
+    def task(self) -> WorkflowTask:
        """
        The task this is the elements of.
        """
        return self._task
 
    @TimeIt.decorator
-    def _get_selection(self, selection: Union[int, slice, List[int]]) -> List[int]:
+    def __get_selection(self, selection: int | slice | list[int]) -> list[int]:
        """Normalise an element selection into a list of element indices."""
        if isinstance(selection, int):
-            lst = [selection]
+            return [selection]
 
        elif isinstance(selection, slice):
-            lst = list(range(*selection.indices(self.task.num_elements)))
+            return list(range(*selection.indices(self.task.num_elements)))
 
        elif isinstance(selection, list):
-            lst = selection
+            return selection
        else:
            raise RuntimeError(
                f"{self.__class__.__name__} selection must be an `int`, `slice` object, "
                f"or list of `int`s, but received type {type(selection)}."
            )
-        return lst
 
-    def __len__(self):
+    def __len__(self) -> int:
        return self.task.num_elements
 
-    def __iter__(self):
-        for i in self.task.workflow.get_task_elements(self.task):
-            yield i
+    def __iter__(self) -> Iterator[Element]:
+        yield from self.task.workflow.get_task_elements(self.task)
+
+    @overload
+    def __getitem__(
+        self,
+        selection: int,
+    ) -> Element:
+        ...
+
+    @overload
+    def __getitem__(
+        self,
+        selection: slice | list[int],
+    ) -> list[Element]:
+        ...
 
    @TimeIt.decorator
    def __getitem__(
        self,
-        selection: Union[int, slice, List[int]],
-    ) -> Union[app.Element, List[app.Element]]:
-        idx_lst = self._get_selection(selection)
-        elements = self.task.workflow.get_task_elements(self.task, idx_lst)
+        selection: int | slice | list[int],
+    ) -> Element | list[Element]:
+        elements = self.task.workflow.get_task_elements(
+            self.task, self.__get_selection(selection)
+        )
 
        if isinstance(selection, int):
            return elements[0]
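
Reviewer note: the `__getitem__` overloads above make the scalar-versus-list split explicit to type-checkers. Hypothetical usage against a workflow-bound task (names invented):

    elem = task.elements[0]          # Element (int selection)
    some = task.elements[1:4]        # list[Element] (slice selection)
    picked = task.elements[[0, 2]]   # list[Element] (explicit index list)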
@@ -3186,7 +3382,8 @@ class Elements:
        )
 
 
 @dataclass
-class Parameters:
+@hydrate
+class Parameters(AppAware):
    """
    The parameters of a (workflow-bound) task. Iterable.
 
@@ -3206,78 +3403,85 @@ class Parameters:
        A default value to use when the parameter is absent.
    """
 
-    _app_attr = "_app"
-
    #: The task these are the parameters of.
-    task: app.WorkflowTask
+    task: WorkflowTask
    #: The path to the parameter or parameters.
    path: str
    #: Whether to return element parameters.
    return_element_parameters: bool
    #: Whether to raise an exception on a missing parameter.
-    raise_on_missing: Optional[bool] = False
+    raise_on_missing: bool = False
    #: Whether to raise an exception on an unset parameter.
-    raise_on_unset: Optional[bool] = False
+    raise_on_unset: bool = False
    #: A default value to use when the parameter is absent.
-    default: Optional[Any] = None
+    default: Any | None = None
 
    @TimeIt.decorator
-    def _get_selection(self, selection: Union[int, slice, List[int]]) -> List[int]:
+    def __get_selection(
+        self, selection: int | slice | list[int] | tuple[int, ...]
+    ) -> list[int]:
        """Normalise an element selection into a list of element indices."""
        if isinstance(selection, int):
-            lst = [selection]
-
+            return [selection]
        elif isinstance(selection, slice):
-            lst = list(range(*selection.indices(self.task.num_elements)))
-
+            return list(range(*selection.indices(self.task.num_elements)))
        elif isinstance(selection, list):
-            lst = selection
+            return selection
+        elif isinstance(selection, tuple):
+            return list(selection)
        else:
            raise RuntimeError(
                f"{self.__class__.__name__} selection must be an `int`, `slice` object, "
                f"or list of `int`s, but received type {type(selection)}."
            )
-        return lst
 
-    def __iter__(self):
-        for i in self.__getitem__(slice(None)):
-            yield i
+    def __iter__(self) -> Iterator[Any | ElementParameter]:
+        yield from self.__getitem__(slice(None))
+
+    @overload
+    def __getitem__(self, selection: int) -> Any | ElementParameter:
+        ...
+
+    @overload
+    def __getitem__(self, selection: slice | list[int]) -> list[Any | ElementParameter]:
+        ...
 
    def __getitem__(
        self,
-        selection: Union[int, slice, List[int]],
-    ) -> Union[Any, List[Any]]:
-        idx_lst = self._get_selection(selection)
+        selection: int | slice | list[int],
+    ) -> Any | ElementParameter | list[Any | ElementParameter]:
+        idx_lst = self.__get_selection(selection)
        elements = self.task.workflow.get_task_elements(self.task, idx_lst)
        if self.return_element_parameters:
-            params = [
+            params = (
                self._app.ElementParameter(
                    task=self.task,
                    path=self.path,
-                    parent=i,
-                    element=i,
+                    parent=elem,
+                    element=elem,
                )
-                for i in elements
-            ]
+                for elem in elements
+            )
        else:
-            params = [
-                i.get(
+            params = (
+                elem.get(
                    path=self.path,
                    raise_on_missing=self.raise_on_missing,
                    raise_on_unset=self.raise_on_unset,
                    default=self.default,
                )
-                for i in elements
-            ]
+                for elem in elements
+            )
 
        if isinstance(selection, int):
-            return params[0]
+            return next(iter(params))
        else:
-            return params
+            return list(params)
 
 
 @dataclass
-class TaskInputParameters:
+@hydrate
+class TaskInputParameters(AppAware):
    """
    For retrieving schema input parameters across all elements.
    Treat as an unmodifiable namespace.
@@ -3288,31 +3492,34 @@ class TaskInputParameters:
        The task that this represents the input parameters of.
    """
 
-    _app_attr = "_app"
-
    #: The task that this represents the input parameters of.
-    task: app.WorkflowTask
+    task: WorkflowTask
+    __input_names: frozenset[str] | None = field(default=None, init=False, compare=False)
 
-    def __getattr__(self, name):
-        if name not in self._get_input_names():
+    def __getattr__(self, name: str) -> Parameters:
+        if name not in self.__get_input_names():
            raise ValueError(f"No input named {name!r}.")
        return self._app.Parameters(self.task, f"inputs.{name}", True)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
-            f"{', '.join(f'{i!r}' for i in self._get_input_names())})"
+            f"{', '.join(f'{name!r}' for name in sorted(self.__get_input_names()))})"
        )
 
-    def __dir__(self):
-        return super().__dir__() + self._get_input_names()
+    def __dir__(self) -> Iterator[str]:
+        yield from super().__dir__()
+        yield from sorted(self.__get_input_names())
 
-    def _get_input_names(self):
-        return sorted(self.task.template.all_schema_input_types)
+    def __get_input_names(self) -> frozenset[str]:
+        if self.__input_names is None:
+            self.__input_names = frozenset(self.task.template.all_schema_input_types)
+        return self.__input_names
 
 
 @dataclass
-class TaskOutputParameters:
+@hydrate
+class TaskOutputParameters(AppAware):
    """
    For retrieving schema output parameters across all elements.
    Treat as an unmodifiable namespace.
@@ -3323,31 +3530,34 @@ class TaskOutputParameters:
        The task that this represents the output parameters of.
    """
 
-    _app_attr = "_app"
-
    #: The task that this represents the output parameters of.
-    task: app.WorkflowTask
+    task: WorkflowTask
+    __output_names: frozenset[str] | None = field(default=None, init=False, compare=False)
 
-    def __getattr__(self, name):
-        if name not in self._get_output_names():
+    def __getattr__(self, name: str) -> Parameters:
+        if name not in self.__get_output_names():
            raise ValueError(f"No output named {name!r}.")
        return self._app.Parameters(self.task, f"outputs.{name}", True)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
-            f"{', '.join(f'{i!r}' for i in self._get_output_names())})"
+            f"{', '.join(map(repr, sorted(self.__get_output_names())))})"
        )
 
-    def __dir__(self):
-        return super().__dir__() + self._get_output_names()
+    def __dir__(self) -> Iterator[str]:
+        yield from super().__dir__()
+        yield from sorted(self.__get_output_names())
 
-    def _get_output_names(self):
-        return sorted(self.task.template.all_schema_output_types)
+    def __get_output_names(self) -> frozenset[str]:
+        if self.__output_names is None:
+            self.__output_names = frozenset(self.task.template.all_schema_output_types)
+        return self.__output_names
 
 
 @dataclass
-class ElementPropagation:
+@hydrate
+class ElementPropagation(AppAware):
    """
    Class to represent how a newly added element set should propagate to a given
    downstream task.
@@ -3362,17 +3572,15 @@ class ElementPropagation:
        The input source information.
    """
 
-    _app_attr = "app"
-
    #: The task this is propagating to.
-    task: app.Task
+    task: WorkflowTask
    #: The nesting order information.
-    nesting_order: Optional[Dict] = None
+    nesting_order: dict[str, float] | None = None
    #: The input source information.
-    input_sources: Optional[Dict] = None
+    input_sources: dict[str, list[InputSource]] | None = None
 
    @property
-    def element_set(self):
+    def element_set(self) -> ElementSet:
        """
        The element set that this propagates from.
 
@@ -3383,25 +3591,38 @@ class ElementPropagation:
        # TEMP property; for now just use the first element set as the base:
        return self.task.template.element_sets[0]
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
        return self.__class__(
            task=self.task,
-            nesting_order=copy.deepcopy(self.nesting_order, memo),
+            nesting_order=copy.copy(self.nesting_order),
            input_sources=copy.deepcopy(self.input_sources, memo),
        )
 
    @classmethod
-    def _prepare_propagate_to_dict(cls, propagate_to, workflow):
-        propagate_to = copy.deepcopy(propagate_to)
+    def _prepare_propagate_to_dict(
+        cls,
+        propagate_to: (
+            list[ElementPropagation]
+            | Mapping[str, ElementPropagation | Mapping[str, Any]]
+            | None
+        ),
+        workflow: Workflow,
+    ) -> dict[str, ElementPropagation]:
        if not propagate_to:
-            propagate_to = {}
-        elif isinstance(propagate_to, list):
-            propagate_to = {i.task.unique_name: i for i in propagate_to}
-
-        for k, v in propagate_to.items():
-            if not isinstance(v, ElementPropagation):
-                propagate_to[k] = cls.app.ElementPropagation(
-                    task=workflow.tasks.get(unique_name=k),
-                    **v,
-                )
-        return propagate_to
+            return {}
+        propagate_to = copy.deepcopy(propagate_to)
+        if isinstance(propagate_to, list):
+            return {prop.task.unique_name: prop for prop in propagate_to}
+
+        return {
+            k: (
+                v
+                if isinstance(v, ElementPropagation)
+                else cls(task=workflow.tasks.get(unique_name=k), **v)
+            )
+            for k, v in propagate_to.items()
+        }
+
+
+#: A task used as a template for other tasks.
+TaskTemplate: TypeAlias = Task
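
Reviewer note: `_prepare_propagate_to_dict` above accepts either a list of `ElementPropagation` objects or a mapping whose values may be plain dicts, and normalises both to `dict[str, ElementPropagation]`. A reduced, runnable sketch of that normalisation — `Prop` and `prepare` are invented stand-ins for the real class and classmethod:

    from __future__ import annotations
    import copy
    from dataclasses import dataclass

    @dataclass
    class Prop:  # stand-in for ElementPropagation
        task_name: str
        nesting_order: dict[str, float] | None = None

    def prepare(propagate_to) -> dict[str, Prop]:
        if not propagate_to:
            return {}
        propagate_to = copy.deepcopy(propagate_to)
        if isinstance(propagate_to, list):
            # key objects by their target task name
            return {p.task_name: p for p in propagate_to}
        # mapping form: pass dict values through the constructor
        return {
            name: (v if isinstance(v, Prop) else Prop(task_name=name, **v))
            for name, v in propagate_to.items()
        }

    prepare({"analyse": {"nesting_order": {"inputs.p1": 1.0}}})
    # -> {'analyse': Prop(task_name='analyse', nesting_order={'inputs.p1': 1.0})}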