hpcflow-new2 0.2.0a189__py3-none-any.whl → 0.2.0a190__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. hpcflow/__pyinstaller/hook-hpcflow.py +8 -6
  2. hpcflow/_version.py +1 -1
  3. hpcflow/app.py +1 -0
  4. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
  5. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
  6. hpcflow/sdk/__init__.py +21 -15
  7. hpcflow/sdk/app.py +2133 -770
  8. hpcflow/sdk/cli.py +281 -250
  9. hpcflow/sdk/cli_common.py +6 -2
  10. hpcflow/sdk/config/__init__.py +1 -1
  11. hpcflow/sdk/config/callbacks.py +77 -42
  12. hpcflow/sdk/config/cli.py +126 -103
  13. hpcflow/sdk/config/config.py +578 -311
  14. hpcflow/sdk/config/config_file.py +131 -95
  15. hpcflow/sdk/config/errors.py +112 -85
  16. hpcflow/sdk/config/types.py +145 -0
  17. hpcflow/sdk/core/actions.py +1054 -994
  18. hpcflow/sdk/core/app_aware.py +24 -0
  19. hpcflow/sdk/core/cache.py +81 -63
  20. hpcflow/sdk/core/command_files.py +275 -185
  21. hpcflow/sdk/core/commands.py +111 -107
  22. hpcflow/sdk/core/element.py +724 -503
  23. hpcflow/sdk/core/enums.py +192 -0
  24. hpcflow/sdk/core/environment.py +74 -93
  25. hpcflow/sdk/core/errors.py +398 -51
  26. hpcflow/sdk/core/json_like.py +540 -272
  27. hpcflow/sdk/core/loop.py +380 -334
  28. hpcflow/sdk/core/loop_cache.py +160 -43
  29. hpcflow/sdk/core/object_list.py +370 -207
  30. hpcflow/sdk/core/parameters.py +728 -600
  31. hpcflow/sdk/core/rule.py +59 -41
  32. hpcflow/sdk/core/run_dir_files.py +33 -22
  33. hpcflow/sdk/core/task.py +1546 -1325
  34. hpcflow/sdk/core/task_schema.py +240 -196
  35. hpcflow/sdk/core/test_utils.py +126 -88
  36. hpcflow/sdk/core/types.py +387 -0
  37. hpcflow/sdk/core/utils.py +410 -305
  38. hpcflow/sdk/core/validation.py +82 -9
  39. hpcflow/sdk/core/workflow.py +1192 -1028
  40. hpcflow/sdk/core/zarr_io.py +98 -137
  41. hpcflow/sdk/demo/cli.py +46 -33
  42. hpcflow/sdk/helper/cli.py +18 -16
  43. hpcflow/sdk/helper/helper.py +75 -63
  44. hpcflow/sdk/helper/watcher.py +61 -28
  45. hpcflow/sdk/log.py +83 -59
  46. hpcflow/sdk/persistence/__init__.py +8 -31
  47. hpcflow/sdk/persistence/base.py +988 -586
  48. hpcflow/sdk/persistence/defaults.py +6 -0
  49. hpcflow/sdk/persistence/discovery.py +38 -0
  50. hpcflow/sdk/persistence/json.py +408 -153
  51. hpcflow/sdk/persistence/pending.py +158 -123
  52. hpcflow/sdk/persistence/store_resource.py +37 -22
  53. hpcflow/sdk/persistence/types.py +307 -0
  54. hpcflow/sdk/persistence/utils.py +14 -11
  55. hpcflow/sdk/persistence/zarr.py +477 -420
  56. hpcflow/sdk/runtime.py +44 -41
  57. hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
  58. hpcflow/sdk/submission/jobscript.py +444 -404
  59. hpcflow/sdk/submission/schedulers/__init__.py +133 -40
  60. hpcflow/sdk/submission/schedulers/direct.py +97 -71
  61. hpcflow/sdk/submission/schedulers/sge.py +132 -126
  62. hpcflow/sdk/submission/schedulers/slurm.py +263 -268
  63. hpcflow/sdk/submission/schedulers/utils.py +7 -2
  64. hpcflow/sdk/submission/shells/__init__.py +14 -15
  65. hpcflow/sdk/submission/shells/base.py +102 -29
  66. hpcflow/sdk/submission/shells/bash.py +72 -55
  67. hpcflow/sdk/submission/shells/os_version.py +31 -30
  68. hpcflow/sdk/submission/shells/powershell.py +37 -29
  69. hpcflow/sdk/submission/submission.py +203 -257
  70. hpcflow/sdk/submission/types.py +143 -0
  71. hpcflow/sdk/typing.py +163 -12
  72. hpcflow/tests/conftest.py +8 -6
  73. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
  74. hpcflow/tests/scripts/test_main_scripts.py +60 -30
  75. hpcflow/tests/shells/wsl/test_wsl_submission.py +6 -4
  76. hpcflow/tests/unit/test_action.py +86 -75
  77. hpcflow/tests/unit/test_action_rule.py +9 -4
  78. hpcflow/tests/unit/test_app.py +13 -6
  79. hpcflow/tests/unit/test_cli.py +1 -1
  80. hpcflow/tests/unit/test_command.py +71 -54
  81. hpcflow/tests/unit/test_config.py +20 -15
  82. hpcflow/tests/unit/test_config_file.py +21 -18
  83. hpcflow/tests/unit/test_element.py +58 -62
  84. hpcflow/tests/unit/test_element_iteration.py +3 -1
  85. hpcflow/tests/unit/test_element_set.py +29 -19
  86. hpcflow/tests/unit/test_group.py +4 -2
  87. hpcflow/tests/unit/test_input_source.py +116 -93
  88. hpcflow/tests/unit/test_input_value.py +29 -24
  89. hpcflow/tests/unit/test_json_like.py +44 -35
  90. hpcflow/tests/unit/test_loop.py +65 -58
  91. hpcflow/tests/unit/test_object_list.py +17 -12
  92. hpcflow/tests/unit/test_parameter.py +16 -7
  93. hpcflow/tests/unit/test_persistence.py +48 -35
  94. hpcflow/tests/unit/test_resources.py +20 -18
  95. hpcflow/tests/unit/test_run.py +8 -3
  96. hpcflow/tests/unit/test_runtime.py +2 -1
  97. hpcflow/tests/unit/test_schema_input.py +23 -15
  98. hpcflow/tests/unit/test_shell.py +3 -2
  99. hpcflow/tests/unit/test_slurm.py +8 -7
  100. hpcflow/tests/unit/test_submission.py +39 -19
  101. hpcflow/tests/unit/test_task.py +352 -247
  102. hpcflow/tests/unit/test_task_schema.py +33 -20
  103. hpcflow/tests/unit/test_utils.py +9 -11
  104. hpcflow/tests/unit/test_value_sequence.py +15 -12
  105. hpcflow/tests/unit/test_workflow.py +114 -83
  106. hpcflow/tests/unit/test_workflow_template.py +0 -1
  107. hpcflow/tests/workflows/test_jobscript.py +2 -1
  108. hpcflow/tests/workflows/test_workflows.py +18 -13
  109. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/METADATA +2 -1
  110. hpcflow_new2-0.2.0a190.dist-info/RECORD +165 -0
  111. hpcflow/sdk/core/parallel.py +0 -21
  112. hpcflow_new2-0.2.0a189.dist-info/RECORD +0 -158
  113. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/LICENSE +0 -0
  114. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/WHEEL +0 -0
  115. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/entry_points.txt +0 -0
hpcflow/sdk/core/loop.py CHANGED
@@ -7,20 +7,29 @@ notably looping over a set of values or until a condition holds.
7
7
  from __future__ import annotations
8
8
 
9
9
  import copy
10
- from typing import Dict, List, Optional, Tuple, Union
10
+ from collections import defaultdict
11
+ from itertools import chain
12
+ from typing import TYPE_CHECKING
13
+ from typing_extensions import override
11
14
 
12
- from hpcflow.sdk import app
15
+ from hpcflow.sdk.core.app_aware import AppAware
13
16
  from hpcflow.sdk.core.errors import LoopTaskSubsetError
14
17
  from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
15
- from hpcflow.sdk.core.loop_cache import LoopCache
16
- from hpcflow.sdk.core.parameters import InputSourceType
17
- from hpcflow.sdk.core.task import WorkflowTask
18
+ from hpcflow.sdk.core.loop_cache import LoopCache, LoopIndex
19
+ from hpcflow.sdk.core.enums import InputSourceType
18
20
  from hpcflow.sdk.core.utils import check_valid_py_identifier, nth_key, nth_value
19
21
  from hpcflow.sdk.log import TimeIt
20
22
 
21
- # from .parameters import Parameter
22
-
23
- # from valida.conditions import ConditionLike
23
+ if TYPE_CHECKING:
24
+ from collections.abc import Iterable, Iterator, Mapping, Sequence
25
+ from typing import Any, ClassVar
26
+ from typing_extensions import Self, TypeIs
27
+ from ..typing import DataIndex, ParamSource
28
+ from .parameters import SchemaInput, InputSource
29
+ from .rule import Rule
30
+ from .task import WorkflowTask
31
+ from .types import IterableParam
32
+ from .workflow import Workflow, WorkflowTemplate
24
33
 
25
34
 
26
35
  # @dataclass
@@ -55,20 +64,25 @@ class Loop(JSONLike):
55
64
  Stopping criterion, expressed as a rule.
56
65
  """
57
66
 
58
- _app_attr = "app"
59
- _child_objects = (ChildObjectSpec(name="termination", class_name="Rule"),)
67
+ _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
68
+ ChildObjectSpec(name="termination", class_name="Rule"),
69
+ )
70
+
71
+ @classmethod
72
+ def __is_WorkflowTask(cls, value) -> TypeIs[WorkflowTask]:
73
+ return isinstance(value, cls._app.WorkflowTask)
60
74
 
61
75
  def __init__(
62
76
  self,
63
- tasks: List[Union[int, app.WorkflowTask]],
77
+ tasks: Iterable[int | WorkflowTask],
64
78
  num_iterations: int,
65
- name: Optional[str] = None,
66
- non_iterable_parameters: Optional[List[str]] = None,
67
- termination: Optional[app.Rule] = None,
79
+ name: str | None = None,
80
+ non_iterable_parameters: list[str] | None = None,
81
+ termination: Rule | None = None,
68
82
  ) -> None:
69
- _task_insert_IDs = []
83
+ _task_insert_IDs: list[int] = []
70
84
  for task in tasks:
71
- if isinstance(task, WorkflowTask):
85
+ if self.__is_WorkflowTask(task):
72
86
  _task_insert_IDs.append(task.insert_ID)
73
87
  elif isinstance(task, int):
74
88
  _task_insert_IDs.append(task)
@@ -84,95 +98,103 @@ class Loop(JSONLike):
84
98
  self._non_iterable_parameters = non_iterable_parameters or []
85
99
  self._termination = termination
86
100
 
87
- self._workflow_template = None # assigned by parent WorkflowTemplate
101
+ self._workflow_template: WorkflowTemplate | None = (
102
+ None # assigned by parent WorkflowTemplate
103
+ )
88
104
 
89
- def to_dict(self):
90
- out = super().to_dict()
105
+ @override
106
+ def _postprocess_to_dict(self, d: dict[str, Any]) -> dict[str, Any]:
107
+ out = super()._postprocess_to_dict(d)
91
108
  return {k.lstrip("_"): v for k, v in out.items()}
92
109
 
93
110
  @classmethod
94
- def _json_like_constructor(cls, json_like):
111
+ def _json_like_constructor(cls, json_like: dict) -> Self:
95
112
  """Invoked by `JSONLike.from_json_like` instead of `__init__`."""
96
113
  if "task_insert_IDs" in json_like:
97
114
  insert_IDs = json_like.pop("task_insert_IDs")
98
115
  else:
99
116
  insert_IDs = json_like.pop("tasks")
100
- obj = cls(tasks=insert_IDs, **json_like)
101
- return obj
117
+ return cls(tasks=insert_IDs, **json_like)
102
118
 
103
119
  @property
104
- def task_insert_IDs(self) -> Tuple[int]:
120
+ def task_insert_IDs(self) -> tuple[int, ...]:
105
121
  """Get the list of task insert_IDs that define the extent of the loop."""
106
122
  return tuple(self._task_insert_IDs)
107
123
 
108
124
  @property
109
- def name(self):
125
+ def name(self) -> str | None:
110
126
  """
111
127
  The name of the loop, if one was provided.
112
128
  """
113
129
  return self._name
114
130
 
115
131
  @property
116
- def num_iterations(self):
132
+ def num_iterations(self) -> int:
117
133
  """
118
134
  The number of loop iterations to do.
119
135
  """
120
136
  return self._num_iterations
121
137
 
122
138
  @property
123
- def non_iterable_parameters(self):
139
+ def non_iterable_parameters(self) -> Sequence[str]:
124
140
  """
125
141
  Which parameters are not iterable.
126
142
  """
127
143
  return self._non_iterable_parameters
128
144
 
129
145
  @property
130
- def termination(self):
146
+ def termination(self) -> Rule | None:
131
147
  """
132
148
  A termination rule for the loop, if one is provided.
133
149
  """
134
150
  return self._termination
135
151
 
136
152
  @property
137
- def workflow_template(self):
153
+ def workflow_template(self) -> WorkflowTemplate | None:
138
154
  """
139
155
  The workflow template that contains this loop.
140
156
  """
141
157
  return self._workflow_template
142
158
 
143
159
  @workflow_template.setter
144
- def workflow_template(self, template: app.WorkflowTemplate):
160
+ def workflow_template(self, template: WorkflowTemplate):
145
161
  self._workflow_template = template
146
- self._validate_against_template()
162
+ self.__validate_against_template()
163
+
164
+ def __workflow(self) -> None | Workflow:
165
+ if (wt := self.workflow_template) is None:
166
+ return None
167
+ return wt.workflow
147
168
 
148
169
  @property
149
- def task_objects(self) -> Tuple[app.WorkflowTask]:
170
+ def task_objects(self) -> tuple[WorkflowTask, ...]:
150
171
  """
151
172
  The tasks in the loop.
152
173
  """
153
- if not self.workflow_template:
174
+ if not (wf := self.__workflow()):
154
175
  raise RuntimeError(
155
176
  "Workflow template must be assigned to retrieve task objects of the loop."
156
177
  )
157
- return tuple(
158
- self.workflow_template.workflow.tasks.get(insert_ID=i)
159
- for i in self.task_insert_IDs
160
- )
178
+ return tuple(wf.tasks.get(insert_ID=t_id) for t_id in self.task_insert_IDs)
161
179
 
162
- def _validate_against_template(self):
180
+ def __validate_against_template(self) -> None:
163
181
  """Validate the loop parameters against the associated workflow."""
164
182
 
165
183
  # insert IDs must exist:
184
+ if not (wf := self.__workflow()):
185
+ raise RuntimeError(
186
+ "workflow cannot be validated against as it is not assigned"
187
+ )
166
188
  for insert_ID in self.task_insert_IDs:
167
189
  try:
168
- self.workflow_template.workflow.tasks.get(insert_ID=insert_ID)
190
+ wf.tasks.get(insert_ID=insert_ID)
169
191
  except ValueError:
170
192
  raise ValueError(
171
193
  f"Loop {self.name!r} has an invalid task insert ID {insert_ID!r}. "
172
194
  f"Such as task does not exist in the associated workflow."
173
195
  )
174
196
 
175
- def __repr__(self):
197
+ def __repr__(self) -> str:
176
198
  num_iterations_str = ""
177
199
  if self.num_iterations is not None:
178
200
  num_iterations_str = f", num_iterations={self.num_iterations!r}"
@@ -187,7 +209,7 @@ class Loop(JSONLike):
187
209
  f")"
188
210
  )
189
211
 
190
- def __deepcopy__(self, memo):
212
+ def __deepcopy__(self, memo: dict[int, Any]) -> Self:
191
213
  kwargs = self.to_dict()
192
214
  kwargs["tasks"] = kwargs.pop("task_insert_IDs")
193
215
  obj = self.__class__(**copy.deepcopy(kwargs, memo))
@@ -195,7 +217,7 @@ class Loop(JSONLike):
195
217
  return obj
196
218
 
197
219
 
198
- class WorkflowLoop:
220
+ class WorkflowLoop(AppAware):
199
221
  """
200
222
  Class to represent a :py:class:`.Loop` that is bound to a
201
223
  :py:class:`~hpcflow.app.Workflow`.
@@ -216,17 +238,15 @@ class WorkflowLoop:
216
238
  The paths to the parent entities of this loop.
217
239
  """
218
240
 
219
- _app_attr = "app"
220
-
221
241
  def __init__(
222
242
  self,
223
243
  index: int,
224
- workflow: app.Workflow,
225
- template: app.Loop,
226
- num_added_iterations: Dict[Tuple[int], int],
227
- iterable_parameters: Dict[int : List[int, List[int]]],
228
- parents: List[str],
229
- ):
244
+ workflow: Workflow,
245
+ template: Loop,
246
+ num_added_iterations: dict[tuple[int, ...], int],
247
+ iterable_parameters: dict[str, IterableParam],
248
+ parents: list[str],
249
+ ) -> None:
230
250
  self._index = index
231
251
  self._workflow = workflow
232
252
  self._template = template
@@ -236,25 +256,22 @@ class WorkflowLoop:
236
256
 
237
257
  # appended to on adding a empty loop to the workflow that's a parent of this loop,
238
258
  # reset and added to `self._parents` on dump to disk:
239
- self._pending_parents = []
259
+ self._pending_parents: list[str] = []
240
260
 
241
261
  # used for `num_added_iterations` when a new loop iteration is added, or when
242
262
  # parents are append to; reset to None on dump to disk. Each key is a tuple of
243
263
  # parent loop indices and each value is the number of pending new iterations:
244
- self._pending_num_added_iterations = None
264
+ self._pending_num_added_iterations: dict[tuple[int, ...], int] | None = None
245
265
 
246
266
  self._validate()
247
267
 
248
268
  @TimeIt.decorator
249
- def _validate(self):
269
+ def _validate(self) -> None:
250
270
  # task subset must be a contiguous range of task indices:
251
271
  task_indices = self.task_indices
252
272
  task_min, task_max = task_indices[0], task_indices[-1]
253
273
  if task_indices != tuple(range(task_min, task_max + 1)):
254
- raise LoopTaskSubsetError(
255
- f"Loop {self.name!r}: task subset must be an ascending contiguous range, "
256
- f"but specified task indices were: {self.task_indices!r}."
257
- )
274
+ raise LoopTaskSubsetError(self.name, self.task_indices)
258
275
 
259
276
  for task in self.downstream_tasks:
260
277
  for param in self.iterable_parameters:
@@ -273,7 +290,7 @@ class WorkflowLoop:
273
290
  )
274
291
 
275
292
  @property
276
- def num_added_iterations(self):
293
+ def num_added_iterations(self) -> Mapping[tuple[int, ...], int]:
277
294
  """
278
295
  The number of added iterations.
279
296
  """
@@ -282,27 +299,30 @@ class WorkflowLoop:
282
299
  else:
283
300
  return self._num_added_iterations
284
301
 
285
- def _initialise_pending_added_iters(self, added_iters_key):
302
+ @property
303
+ def __pending(self) -> dict[tuple[int, ...], int]:
286
304
  if not self._pending_num_added_iterations:
287
- self._pending_num_added_iterations = copy.deepcopy(self._num_added_iterations)
288
-
289
- if added_iters_key not in self._pending_num_added_iterations:
290
- self._pending_num_added_iterations[added_iters_key] = 1
305
+ self._pending_num_added_iterations = dict(self._num_added_iterations)
306
+ return self._pending_num_added_iterations
291
307
 
292
- def _increment_pending_added_iters(self, added_iters_key):
308
+ def _initialise_pending_added_iters(self, added_iters: Iterable[int]):
293
309
  if not self._pending_num_added_iterations:
294
- self._pending_num_added_iterations = copy.deepcopy(self._num_added_iterations)
310
+ self._pending_num_added_iterations = dict(self._num_added_iterations)
311
+ if (added_iters_key := tuple(added_iters)) not in (pending := self.__pending):
312
+ pending[added_iters_key] = 1
295
313
 
296
- self._pending_num_added_iterations[added_iters_key] += 1
314
+ def _increment_pending_added_iters(self, added_iters_key: Iterable[int]):
315
+ self.__pending[tuple(added_iters_key)] += 1
297
316
 
298
- def _update_parents(self, parent: app.WorkflowLoop):
317
+ def _update_parents(self, parent: WorkflowLoop):
318
+ assert parent.name
299
319
  self._pending_parents.append(parent.name)
300
320
 
301
- if not self._pending_num_added_iterations:
302
- self._pending_num_added_iterations = copy.deepcopy(self._num_added_iterations)
303
-
304
321
  self._pending_num_added_iterations = {
305
- tuple(list(k) + [0]): v for k, v in self._pending_num_added_iterations.items()
322
+ (*k, 0): v
323
+ for k, v in (
324
+ self._pending_num_added_iterations or self._num_added_iterations
325
+ ).items()
306
326
  }
307
327
 
308
328
  self.workflow._store.update_loop_parents(
@@ -311,116 +331,118 @@ class WorkflowLoop:
311
331
  parents=self.parents,
312
332
  )
313
333
 
314
- def _reset_pending_num_added_iters(self):
334
+ def _reset_pending_num_added_iters(self) -> None:
315
335
  self._pending_num_added_iterations = None
316
336
 
317
- def _accept_pending_num_added_iters(self):
337
+ def _accept_pending_num_added_iters(self) -> None:
318
338
  if self._pending_num_added_iterations:
319
- self._num_added_iterations = copy.deepcopy(self._pending_num_added_iterations)
339
+ self._num_added_iterations = dict(self._pending_num_added_iterations)
320
340
  self._reset_pending_num_added_iters()
321
341
 
322
- def _reset_pending_parents(self):
342
+ def _reset_pending_parents(self) -> None:
323
343
  self._pending_parents = []
324
344
 
325
- def _accept_pending_parents(self):
345
+ def _accept_pending_parents(self) -> None:
326
346
  self._parents += self._pending_parents
327
347
  self._reset_pending_parents()
328
348
 
329
349
  @property
330
- def index(self):
350
+ def index(self) -> int:
331
351
  """
332
352
  The index of this loop within its workflow.
333
353
  """
334
354
  return self._index
335
355
 
336
356
  @property
337
- def task_insert_IDs(self):
357
+ def task_insert_IDs(self) -> tuple[int, ...]:
338
358
  """
339
359
  The insertion IDs of the tasks inside this loop.
340
360
  """
341
361
  return self.template.task_insert_IDs
342
362
 
343
363
  @property
344
- def task_objects(self):
364
+ def task_objects(self) -> tuple[WorkflowTask, ...]:
345
365
  """
346
366
  The tasks in this loop.
347
367
  """
348
368
  return self.template.task_objects
349
369
 
350
370
  @property
351
- def task_indices(self) -> Tuple[int]:
371
+ def task_indices(self) -> tuple[int, ...]:
352
372
  """
353
373
  The list of task indices that define the extent of the loop.
354
374
  """
355
- return tuple(i.index for i in self.task_objects)
375
+ return tuple(task.index for task in self.task_objects)
356
376
 
357
377
  @property
358
- def workflow(self):
378
+ def workflow(self) -> Workflow:
359
379
  """
360
380
  The workflow containing this loop.
361
381
  """
362
382
  return self._workflow
363
383
 
364
384
  @property
365
- def template(self):
385
+ def template(self) -> Loop:
366
386
  """
367
387
  The loop template for this loop.
368
388
  """
369
389
  return self._template
370
390
 
371
391
  @property
372
- def parents(self) -> List[str]:
392
+ def parents(self) -> Sequence[str]:
373
393
  """
374
394
  The parents of this loop.
375
395
  """
376
396
  return self._parents + self._pending_parents
377
397
 
378
398
  @property
379
- def name(self):
399
+ def name(self) -> str:
380
400
  """
381
401
  The name of this loop, if one is defined.
382
402
  """
403
+ assert self.template.name
383
404
  return self.template.name
384
405
 
385
406
  @property
386
- def iterable_parameters(self):
407
+ def iterable_parameters(self) -> Mapping[str, IterableParam]:
387
408
  """
388
409
  The parameters that are being iterated over.
389
410
  """
390
411
  return self._iterable_parameters
391
412
 
392
413
  @property
393
- def num_iterations(self):
414
+ def num_iterations(self) -> int:
394
415
  """
395
416
  The number of iterations.
396
417
  """
397
418
  return self.template.num_iterations
398
419
 
399
420
  @property
400
- def downstream_tasks(self) -> List[app.WorkflowLoop]:
421
+ def downstream_tasks(self) -> Iterator[WorkflowTask]:
401
422
  """Tasks that are not part of the loop, and downstream from this loop."""
402
- return self.workflow.tasks[self.task_objects[-1].index + 1 :]
423
+ tasks = self.workflow.tasks
424
+ for idx in range(self.task_objects[-1].index + 1, len(tasks)):
425
+ yield tasks[idx]
403
426
 
404
427
  @property
405
- def upstream_tasks(self) -> List[app.WorkflowLoop]:
428
+ def upstream_tasks(self) -> Iterator[WorkflowTask]:
406
429
  """Tasks that are not part of the loop, and upstream from this loop."""
407
- return self.workflow.tasks[: self.task_objects[0].index]
430
+ tasks = self.workflow.tasks
431
+ for idx in range(0, self.task_objects[0].index):
432
+ yield tasks[idx]
408
433
 
409
434
  @staticmethod
410
435
  @TimeIt.decorator
411
- def _find_iterable_parameters(loop_template: app.Loop):
412
- all_inputs_first_idx = {}
413
- all_outputs_idx = {}
436
+ def _find_iterable_parameters(loop_template: Loop) -> dict[str, IterableParam]:
437
+ all_inputs_first_idx: dict[str, int] = {}
438
+ all_outputs_idx: dict[str, list[int]] = defaultdict(list)
414
439
  for task in loop_template.task_objects:
415
440
  for typ in task.template.all_schema_input_types:
416
- if typ not in all_inputs_first_idx:
417
- all_inputs_first_idx[typ] = task.insert_ID
441
+ all_inputs_first_idx.setdefault(typ, task.insert_ID)
418
442
  for typ in task.template.all_schema_output_types:
419
- if typ not in all_outputs_idx:
420
- all_outputs_idx[typ] = []
421
443
  all_outputs_idx[typ].append(task.insert_ID)
422
444
 
423
- iterable_params = {}
445
+ iterable_params: dict[str, IterableParam] = {}
424
446
  for typ, first_idx in all_inputs_first_idx.items():
425
447
  if typ in all_outputs_idx and first_idx <= all_outputs_idx[typ][0]:
426
448
  iterable_params[typ] = {
@@ -429,8 +451,7 @@ class WorkflowLoop:
429
451
  }
430
452
 
431
453
  for non_iter in loop_template.non_iterable_parameters:
432
- if non_iter in iterable_params:
433
- del iterable_params[non_iter]
454
+ iterable_params.pop(non_iter, None)
434
455
 
435
456
  return iterable_params
436
457
 
@@ -439,10 +460,10 @@ class WorkflowLoop:
439
460
  def new_empty_loop(
440
461
  cls,
441
462
  index: int,
442
- workflow: app.Workflow,
443
- template: app.Loop,
444
- iter_loop_idx: List[Dict],
445
- ) -> Tuple[app.WorkflowLoop, List[Dict[str, int]]]:
463
+ workflow: Workflow,
464
+ template: Loop,
465
+ iter_loop_idx: Sequence[Mapping[str, int]],
466
+ ) -> WorkflowLoop:
446
467
  """
447
468
  Make a new empty loop.
448
469
 
@@ -457,13 +478,16 @@ class WorkflowLoop:
457
478
  iter_loop_idx: list[dict]
458
479
  Iteration information from parent loops.
459
480
  """
460
- parent_loops = cls._get_parent_loops(index, workflow, template)
461
- parent_names = [i.name for i in parent_loops]
462
- num_added_iters = {}
463
- for i in iter_loop_idx:
464
- num_added_iters[tuple([i[j] for j in parent_names])] = 1
481
+ parent_names = [
482
+ loop.name
483
+ for loop in cls._get_parent_loops(index, workflow, template)
484
+ if loop.name
485
+ ]
486
+ num_added_iters = {
487
+ tuple(l_idx[nm] for nm in parent_names): 1 for l_idx in iter_loop_idx
488
+ }
465
489
 
466
- obj = cls(
490
+ return cls(
467
491
  index=index,
468
492
  workflow=workflow,
469
493
  template=template,
@@ -471,17 +495,16 @@ class WorkflowLoop:
471
495
  iterable_parameters=cls._find_iterable_parameters(template),
472
496
  parents=parent_names,
473
497
  )
474
- return obj
475
498
 
476
499
  @classmethod
477
500
  @TimeIt.decorator
478
501
  def _get_parent_loops(
479
502
  cls,
480
503
  index: int,
481
- workflow: app.Workflow,
482
- template: app.Loop,
483
- ) -> List[app.WorkflowLoop]:
484
- parents = []
504
+ workflow: Workflow,
505
+ template: Loop,
506
+ ) -> list[WorkflowLoop]:
507
+ parents: list[WorkflowLoop] = []
485
508
  passed_self = False
486
509
  self_tasks = set(template.task_insert_IDs)
487
510
  for loop_i in workflow.loops:
@@ -496,18 +519,18 @@ class WorkflowLoop:
496
519
  return parents
497
520
 
498
521
  @TimeIt.decorator
499
- def get_parent_loops(self) -> List[app.WorkflowLoop]:
522
+ def get_parent_loops(self) -> list[WorkflowLoop]:
500
523
  """Get loops whose task subset is a superset of this loop's task subset. If two
501
524
  loops have identical task subsets, the first loop in the workflow loop list is
502
525
  considered the child."""
503
526
  return self._get_parent_loops(self.index, self.workflow, self.template)
504
527
 
505
528
  @TimeIt.decorator
506
- def get_child_loops(self) -> List[app.WorkflowLoop]:
529
+ def get_child_loops(self) -> list[WorkflowLoop]:
507
530
  """Get loops whose task subset is a subset of this loop's task subset. If two
508
531
  loops have identical task subsets, the first loop in the workflow loop list is
509
532
  considered the child."""
510
- children = []
533
+ children: list[WorkflowLoop] = []
511
534
  passed_self = False
512
535
  self_tasks = set(self.task_insert_IDs)
513
536
  for loop_i in self.workflow.loops:
@@ -521,11 +544,14 @@ class WorkflowLoop:
521
544
  children.append(loop_i)
522
545
 
523
546
  # order by depth, so direct child is first:
524
- children = sorted(children, key=lambda x: len(next(iter(x.num_added_iterations))))
525
- return children
547
+ return sorted(children, key=lambda x: len(next(iter(x.num_added_iterations))))
526
548
 
527
549
  @TimeIt.decorator
528
- def add_iteration(self, parent_loop_indices=None, cache: Optional[LoopCache] = None):
550
+ def add_iteration(
551
+ self,
552
+ parent_loop_indices: Mapping[str, int] | None = None,
553
+ cache: LoopCache | None = None,
554
+ ) -> None:
529
555
  """
530
556
  Add an iteration to this loop.
531
557
 
@@ -539,42 +565,40 @@ class WorkflowLoop:
539
565
  """
540
566
  if not cache:
541
567
  cache = LoopCache.build(self.workflow)
568
+ assert cache is not None
542
569
  parent_loops = self.get_parent_loops()
543
570
  child_loops = self.get_child_loops()
544
- parent_loop_indices = parent_loop_indices or {}
545
- if parent_loops and not parent_loop_indices:
546
- parent_loop_indices = {i.name: 0 for i in parent_loops}
571
+ parent_loop_indices_ = parent_loop_indices or {
572
+ loop.name: 0 for loop in parent_loops
573
+ }
547
574
 
548
- iters_key = tuple([parent_loop_indices[k] for k in self.parents])
575
+ iters_key = tuple(parent_loop_indices_[p_nm] for p_nm in self.parents)
549
576
  cur_loop_idx = self.num_added_iterations[iters_key] - 1
550
- all_new_data_idx = {} # keys are (task.insert_ID and element.index)
577
+
578
+ # keys are (task.insert_ID and element.index)
579
+ all_new_data_idx: dict[tuple[int, int], DataIndex] = {}
551
580
 
552
581
  # initialise a new `num_added_iterations` key on each child loop:
582
+ iters_key_dct = {
583
+ **parent_loop_indices_,
584
+ self.name: cur_loop_idx + 1,
585
+ }
553
586
  for child in child_loops:
554
- iters_key_dct = {
555
- **parent_loop_indices,
556
- self.name: cur_loop_idx + 1,
557
- }
558
- added_iters_key_chd = tuple([iters_key_dct.get(j, 0) for j in child.parents])
559
- child._initialise_pending_added_iters(added_iters_key_chd)
587
+ child._initialise_pending_added_iters(
588
+ iters_key_dct.get(j, 0) for j in child.parents
589
+ )
560
590
 
561
591
  for task in self.task_objects:
562
-
563
- new_loop_idx = {
564
- **parent_loop_indices,
565
- self.name: cur_loop_idx + 1,
566
- **{
567
- child.name: 0
568
- for child in child_loops
569
- if task.insert_ID in child.task_insert_IDs
570
- },
592
+ new_loop_idx = LoopIndex(iters_key_dct) + {
593
+ child.name: 0
594
+ for child in child_loops
595
+ if task.insert_ID in child.task_insert_IDs
571
596
  }
572
- added_iter_IDs = []
597
+ added_iter_IDs: list[int] = []
573
598
  for elem_idx in range(task.num_elements):
574
-
575
599
  elem_ID = task.element_IDs[elem_idx]
576
600
 
577
- new_data_idx = {}
601
+ new_data_idx: DataIndex = {}
578
602
 
579
603
  # copy resources from zeroth iteration:
580
604
  zeroth_iter_ID, zi_iter_data_idx = cache.zeroth_iters[elem_ID]
@@ -587,109 +611,26 @@ class WorkflowLoop:
587
611
 
588
612
  for inp in task.template.all_schema_inputs:
589
613
  is_inp_task = False
590
- iter_dat = self.iterable_parameters.get(inp.typ)
591
- if iter_dat:
614
+ if iter_dat := self.iterable_parameters.get(inp.typ):
592
615
  is_inp_task = task.insert_ID == iter_dat["input_task"]
593
616
 
594
- if is_inp_task:
595
- # source from final output task of previous iteration, with all parent
596
- # loop indices the same as previous iteration, and all child loop indices
597
- # maximised:
598
-
599
- # identify element(s) from which this iterable input should be
600
- # parametrised:
601
- if task.insert_ID == iter_dat["output_tasks"][-1]:
602
- src_elem_ID = elem_ID
603
- grouped_elems = None
604
- else:
605
- src_elem_IDs_all = cache.element_dependents[elem_ID]
606
- src_elem_IDs = {
607
- k: v
608
- for k, v in src_elem_IDs_all.items()
609
- if cache.elements[k]["task_insert_ID"]
610
- == iter_dat["output_tasks"][-1]
611
- }
612
- # consider groups
613
- inp_group_name = inp.single_labelled_data.get("group")
614
- grouped_elems = []
615
- for src_elem_j_ID, src_elem_j_dat in src_elem_IDs.items():
616
- i_in_group = any(
617
- k == inp_group_name
618
- for k in src_elem_j_dat["group_names"]
619
- )
620
- if i_in_group:
621
- grouped_elems.append(src_elem_j_ID)
622
-
623
- if not grouped_elems and len(src_elem_IDs) > 1:
624
- raise NotImplementedError(
625
- f"Multiple elements found in the iterable parameter "
626
- f"{inp!r}'s latest output task (insert ID: "
627
- f"{iter_dat['output_tasks'][-1]}) that can be used "
628
- f"to parametrise the next iteration: "
629
- f"{list(src_elem_IDs.keys())!r}."
630
- )
631
-
632
- elif not src_elem_IDs:
633
- # TODO: maybe OK?
634
- raise NotImplementedError(
635
- f"No elements found in the iterable parameter "
636
- f"{inp!r}'s latest output task (insert ID: "
637
- f"{iter_dat['output_tasks'][-1]}) that can be used "
638
- f"to parametrise the next iteration."
639
- )
640
-
641
- else:
642
- src_elem_ID = nth_key(src_elem_IDs, 0)
643
-
644
- child_loop_max_iters = {}
645
- parent_loop_same_iters = {
646
- i.name: parent_loop_indices[i.name] for i in parent_loops
647
- }
648
- child_iter_parents = {
649
- **parent_loop_same_iters,
650
- self.name: cur_loop_idx,
651
- }
652
- for i in child_loops:
653
- i_num_iters = i.num_added_iterations[
654
- tuple(child_iter_parents[j] for j in i.parents)
655
- ]
656
- i_max = i_num_iters - 1
657
- child_iter_parents[i.name] = i_max
658
- child_loop_max_iters[i.name] = i_max
659
-
660
- source_iter_loop_idx = {
661
- **child_loop_max_iters,
662
- **parent_loop_same_iters,
663
- self.name: cur_loop_idx,
664
- }
665
-
666
- # identify the ElementIteration from which this input should be
667
- # parametrised:
668
- loop_idx_key = tuple(sorted(source_iter_loop_idx.items()))
669
- if grouped_elems:
670
- src_data_idx = []
671
- for src_elem_ID in grouped_elems:
672
- src_data_idx.append(
673
- cache.data_idx[src_elem_ID][loop_idx_key]
674
- )
675
- else:
676
- src_data_idx = cache.data_idx[src_elem_ID][loop_idx_key]
677
-
678
- if not src_data_idx:
679
- raise RuntimeError(
680
- f"Could not find a source iteration with loop_idx: "
681
- f"{source_iter_loop_idx!r}."
682
- )
683
-
684
- if grouped_elems:
685
- inp_dat_idx = [i[f"outputs.{inp.typ}"] for i in src_data_idx]
686
- else:
687
- inp_dat_idx = src_data_idx[f"outputs.{inp.typ}"]
688
- new_data_idx[f"inputs.{inp.typ}"] = inp_dat_idx
617
+ inp_key = f"inputs.{inp.typ}"
689
618
 
619
+ if is_inp_task:
620
+ assert iter_dat is not None
621
+ inp_dat_idx = self.__get_looped_index(
622
+ task,
623
+ elem_ID,
624
+ cache,
625
+ iter_dat,
626
+ inp,
627
+ parent_loops,
628
+ parent_loop_indices_,
629
+ child_loops,
630
+ cur_loop_idx,
631
+ )
632
+ new_data_idx[inp_key] = inp_dat_idx
690
633
  else:
691
- inp_key = f"inputs.{inp.typ}"
692
-
693
634
  orig_inp_src = cache.elements[elem_ID]["input_sources"][inp_key]
694
635
  inp_dat_idx = None
695
636
 
@@ -709,77 +650,16 @@ class WorkflowLoop:
709
650
  inp_dat_idx = zi_iter_data_idx[inp_key]
710
651
 
711
652
  elif orig_inp_src.source_type is InputSourceType.TASK:
712
- if orig_inp_src.task_ref not in self.task_insert_IDs:
713
- # source the data_idx from the iteration with same parent
714
- # loop indices as the new iteration to add:
715
- # src_iters = []
716
- src_data_idx = []
717
- for li_k, di_k in cache.data_idx[elem_ID].items():
718
- skip_iter = False
719
- li_k_dct = dict(li_k)
720
- for p_k, p_v in parent_loop_indices.items():
721
- if li_k_dct.get(p_k) != p_v:
722
- skip_iter = True
723
- break
724
- if not skip_iter:
725
- src_data_idx.append(di_k)
726
-
727
- # could be multiple, but they should all have the same
728
- # data index for this parameter:
729
- src_data_idx = src_data_idx[0]
730
- inp_dat_idx = src_data_idx[inp_key]
731
- else:
732
- is_group = False
733
- if (
734
- not inp.multiple
735
- and "group" in inp.single_labelled_data
736
- ):
737
- # this input is a group, assume for now all elements:
738
- is_group = True
739
-
740
- # same task/element, but update iteration to the just-added
741
- # iteration:
742
- key_prefix = orig_inp_src.task_source_type.name.lower()
743
- prev_dat_idx_key = f"{key_prefix}s.{inp.typ}"
744
- new_sources = []
745
- for (
746
- tiID,
747
- e_idx,
748
- ), prev_dat_idx in all_new_data_idx.items():
749
- if tiID == orig_inp_src.task_ref:
750
- # find which element in that task `element`
751
- # depends on:
752
- task_i = self.workflow.tasks.get(insert_ID=tiID)
753
- elem_i_ID = task_i.element_IDs[e_idx]
754
- src_elem_IDs_all = cache.element_dependents[
755
- elem_i_ID
756
- ]
757
- src_elem_IDs_i = {
758
- k: v
759
- for k, v in src_elem_IDs_all.items()
760
- if cache.elements[k]["task_insert_ID"]
761
- == task.insert_ID
762
- }
763
-
764
- # filter src_elem_IDs_i for matching element IDs:
765
- src_elem_IDs_i = [
766
- i for i in src_elem_IDs_i if i == elem_ID
767
- ]
768
- if (
769
- len(src_elem_IDs_i) == 1
770
- and src_elem_IDs_i[0] == elem_ID
771
- ):
772
- new_sources.append((tiID, e_idx))
773
-
774
- if is_group:
775
- inp_dat_idx = [
776
- all_new_data_idx[i][prev_dat_idx_key]
777
- for i in new_sources
778
- ]
779
- else:
780
- assert len(new_sources) == 1
781
- prev_dat_idx = all_new_data_idx[new_sources[0]]
782
- inp_dat_idx = prev_dat_idx[prev_dat_idx_key]
653
+ inp_dat_idx = self.__get_task_index(
654
+ task,
655
+ orig_inp_src,
656
+ cache,
657
+ elem_ID,
658
+ inp,
659
+ inp_key,
660
+ parent_loop_indices_,
661
+ all_new_data_idx,
662
+ )
783
663
 
784
664
  if inp_dat_idx is None:
785
665
  raise RuntimeError(
@@ -791,9 +671,8 @@ class WorkflowLoop:
791
671
 
792
672
  # add any locally defined sub-parameters:
793
673
  inp_statuses = cache.elements[elem_ID]["input_statuses"]
794
- inp_status_inps = set([f"inputs.{i}" for i in inp_statuses])
795
- sub_params = inp_status_inps - set(new_data_idx.keys())
796
- for sub_param_i in sub_params:
674
+ inp_status_inps = set(f"inputs.{inp}" for inp in inp_statuses)
675
+ for sub_param_i in inp_status_inps.difference(new_data_idx):
797
676
  sub_param_data_idx_iter_0 = zi_data_idx
798
677
  try:
799
678
  sub_param_data_idx = sub_param_data_idx_iter_0[sub_param_i]
@@ -808,13 +687,11 @@ class WorkflowLoop:
808
687
 
809
688
  for out in task.template.all_schema_outputs:
810
689
  path_i = f"outputs.{out.typ}"
811
- p_src = {"type": "EAR_output"}
690
+ p_src: ParamSource = {"type": "EAR_output"}
812
691
  new_data_idx[path_i] = self.workflow._add_unset_parameter_data(p_src)
813
692
 
814
- schema_params = set(
815
- i for i in new_data_idx.keys() if len(i.split(".")) == 2
816
- )
817
- all_new_data_idx[(task.insert_ID, elem_idx)] = new_data_idx
693
+ schema_params = set(i for i in new_data_idx if len(i.split(".")) == 2)
694
+ all_new_data_idx[task.insert_ID, elem_idx] = new_data_idx
818
695
 
819
696
  iter_ID_i = self.workflow._store.add_element_iteration(
820
697
  element_ID=elem_ID,
@@ -835,8 +712,9 @@ class WorkflowLoop:
835
712
 
836
713
  task.initialise_EARs(iter_IDs=added_iter_IDs)
837
714
 
838
- added_iters_key = tuple(parent_loop_indices[k] for k in self.parents)
839
- self._increment_pending_added_iters(added_iters_key)
715
+ self._increment_pending_added_iters(
716
+ parent_loop_indices_[p_nm] for p_nm in self.parents
717
+ )
840
718
  self.workflow._store.update_loop_num_iters(
841
719
  index=self.index,
842
720
  num_added_iters=self.num_added_iterations,
@@ -846,17 +724,185 @@ class WorkflowLoop:
846
724
  for child in child_loops[::-1]:
847
725
  if child.num_iterations is not None:
848
726
  for _ in range(child.num_iterations - 1):
849
- par_idx = {k: 0 for k in child.parents}
850
- child.add_iteration(
851
- parent_loop_indices={
852
- **par_idx,
853
- **parent_loop_indices,
854
- self.name: cur_loop_idx + 1,
855
- },
856
- cache=cache,
857
- )
727
+ par_idx = {parent_name: 0 for parent_name in child.parents}
728
+ if parent_loop_indices:
729
+ par_idx.update(parent_loop_indices)
730
+ par_idx[self.name] = cur_loop_idx + 1
731
+ child.add_iteration(parent_loop_indices=par_idx, cache=cache)
732
+
733
+ def __get_src_ID_and_groups(
734
+ self, elem_ID: int, iter_dat: IterableParam, inp: SchemaInput, cache: LoopCache
735
+ ) -> tuple[int, Sequence[int]]:
736
+ src_elem_IDs = {
737
+ k: v
738
+ for k, v in cache.element_dependents[elem_ID].items()
739
+ if cache.elements[k]["task_insert_ID"] == iter_dat["output_tasks"][-1]
740
+ }
741
+ # consider groups
742
+ single_data = inp.single_labelled_data
743
+ assert single_data is not None
744
+ inp_group_name = single_data.get("group")
745
+ grouped_elems = [
746
+ src_elem_j_ID
747
+ for src_elem_j_ID, src_elem_j_dat in src_elem_IDs.items()
748
+ if any(nm == inp_group_name for nm in src_elem_j_dat["group_names"])
749
+ ]
750
+
751
+ if not grouped_elems and len(src_elem_IDs) > 1:
752
+ raise NotImplementedError(
753
+ f"Multiple elements found in the iterable parameter "
754
+ f"{inp!r}'s latest output task (insert ID: "
755
+ f"{iter_dat['output_tasks'][-1]}) that can be used "
756
+ f"to parametrise the next iteration: "
757
+ f"{list(src_elem_IDs)!r}."
758
+ )
759
+
760
+ elif not src_elem_IDs:
761
+ # TODO: maybe OK?
762
+ raise NotImplementedError(
763
+ f"No elements found in the iterable parameter "
764
+ f"{inp!r}'s latest output task (insert ID: "
765
+ f"{iter_dat['output_tasks'][-1]}) that can be used "
766
+ f"to parametrise the next iteration."
767
+ )
768
+
769
+ return nth_key(src_elem_IDs, 0), grouped_elems
770
+
771
+ def __get_looped_index(
772
+ self,
773
+ task: WorkflowTask,
774
+ elem_ID: int,
775
+ cache: LoopCache,
776
+ iter_dat: IterableParam,
777
+ inp: SchemaInput,
778
+ parent_loops: list[WorkflowLoop],
779
+ parent_loop_indices: Mapping[str, int],
780
+ child_loops: list[WorkflowLoop],
781
+ cur_loop_idx: int,
782
+ ):
783
+ # source from final output task of previous iteration, with all parent
784
+ # loop indices the same as previous iteration, and all child loop indices
785
+ # maximised:
786
+
787
+ # identify element(s) from which this iterable input should be
788
+ # parametrised:
789
+ if task.insert_ID == iter_dat["output_tasks"][-1]:
790
+ src_elem_ID = elem_ID
791
+ grouped_elems: Sequence[int] = []
792
+ else:
793
+ src_elem_ID, grouped_elems = self.__get_src_ID_and_groups(
794
+ elem_ID, iter_dat, inp, cache
795
+ )
796
+
797
+ child_loop_max_iters: dict[str, int] = {}
798
+ parent_loop_same_iters = {
799
+ loop.name: parent_loop_indices[loop.name] for loop in parent_loops
800
+ }
801
+ child_iter_parents = {
802
+ **parent_loop_same_iters,
803
+ self.name: cur_loop_idx,
804
+ }
805
+ for loop in child_loops:
806
+ i_num_iters = loop.num_added_iterations[
807
+ tuple(child_iter_parents[j] for j in loop.parents)
808
+ ]
809
+ i_max = i_num_iters - 1
810
+ child_iter_parents[loop.name] = i_max
811
+ child_loop_max_iters[loop.name] = i_max
812
+
813
+ loop_idx_key = LoopIndex(child_loop_max_iters)
814
+ loop_idx_key.update(parent_loop_same_iters)
815
+ loop_idx_key[self.name] = cur_loop_idx
816
+
817
+ # identify the ElementIteration from which this input should be
818
+ # parametrised:
819
+ if grouped_elems:
820
+ src_data_idx = [
821
+ cache.data_idx[src_elem_ID][loop_idx_key] for src_elem_ID in grouped_elems
822
+ ]
823
+ if not src_data_idx:
824
+ raise RuntimeError(
825
+ f"Could not find a source iteration with loop_idx: "
826
+ f"{loop_idx_key!r}."
827
+ )
828
+ return [i[f"outputs.{inp.typ}"] for i in src_data_idx]
829
+ else:
830
+ return cache.data_idx[src_elem_ID][loop_idx_key][f"outputs.{inp.typ}"]
831
+
832
+ def __get_task_index(
833
+ self,
834
+ task: WorkflowTask,
835
+ orig_inp_src: InputSource,
836
+ cache: LoopCache,
837
+ elem_ID: int,
838
+ inp: SchemaInput,
839
+ inp_key: str,
840
+ parent_loop_indices: Mapping[str, int],
841
+ all_new_data_idx: Mapping[tuple[int, int], DataIndex],
842
+ ) -> int | list[int]:
843
+ if orig_inp_src.task_ref not in self.task_insert_IDs:
844
+ # source the data_idx from the iteration with same parent
845
+ # loop indices as the new iteration to add:
846
+ src_data_idx = next(
847
+ di_k
848
+ for li_k, di_k in cache.data_idx[elem_ID].items()
849
+ if all(li_k.get(p_k) == p_v for p_k, p_v in parent_loop_indices.items())
850
+ )
851
+
852
+ # could be multiple, but they should all have the same
853
+ # data index for this parameter:
854
+ return src_data_idx[inp_key]
855
+
856
+ is_group = (
857
+ inp.single_labelled_data is not None
858
+ and "group" in inp.single_labelled_data
859
+ # this input is a group, assume for now all elements
860
+ )
861
+
862
+ # same task/element, but update iteration to the just-added
863
+ # iteration:
864
+ assert orig_inp_src.task_source_type is not None
865
+ key_prefix = orig_inp_src.task_source_type.name.lower()
866
+ prev_dat_idx_key = f"{key_prefix}s.{inp.typ}"
867
+ new_sources: list[tuple[int, int]] = []
868
+ for (tiID, e_idx), _ in all_new_data_idx.items():
869
+ if tiID == orig_inp_src.task_ref:
870
+ # find which element in that task `element`
871
+ # depends on:
872
+ src_elem_IDs = cache.element_dependents[
873
+ self.workflow.tasks.get(insert_ID=tiID).element_IDs[e_idx]
874
+ ]
875
+ # filter src_elem_IDs_i for matching element IDs:
876
+ src_elem_IDs_i = [
877
+ k
878
+ for k, _v in src_elem_IDs.items()
879
+ if cache.elements[k]["task_insert_ID"] == task.insert_ID
880
+ and k == elem_ID
881
+ ]
882
+
883
+ if len(src_elem_IDs_i) == 1:
884
+ new_sources.append((tiID, e_idx))
885
+
886
+ if is_group:
887
+ # Convert into simple list of indices
888
+ return list(
889
+ chain.from_iterable(
890
+ self.__as_sequence(all_new_data_idx[src][prev_dat_idx_key])
891
+ for src in new_sources
892
+ )
893
+ )
894
+ else:
895
+ assert len(new_sources) == 1
896
+ return all_new_data_idx[new_sources[0]][prev_dat_idx_key]
897
+
898
+ @staticmethod
899
+ def __as_sequence(seq: int | Iterable[int]) -> Iterable[int]:
900
+ if isinstance(seq, int):
901
+ yield seq
902
+ else:
903
+ yield from seq
858
904
 
859
- def test_termination(self, element_iter):
905
+ def test_termination(self, element_iter) -> bool:
860
906
  """Check if a loop should terminate, given the specified completed element
861
907
  iteration."""
862
908
  if self.template.termination: