hpcflow-new2 0.2.0a190__py3-none-any.whl → 0.2.0a200__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. hpcflow/__pyinstaller/hook-hpcflow.py +1 -0
  2. hpcflow/_version.py +1 -1
  3. hpcflow/data/scripts/bad_script.py +2 -0
  4. hpcflow/data/scripts/do_nothing.py +2 -0
  5. hpcflow/data/scripts/env_specifier_test/input_file_generator_pass_env_spec.py +4 -0
  6. hpcflow/data/scripts/env_specifier_test/main_script_test_pass_env_spec.py +8 -0
  7. hpcflow/data/scripts/env_specifier_test/output_file_parser_pass_env_spec.py +4 -0
  8. hpcflow/data/scripts/env_specifier_test/v1/input_file_generator_basic.py +4 -0
  9. hpcflow/data/scripts/env_specifier_test/v1/main_script_test_direct_in_direct_out.py +7 -0
  10. hpcflow/data/scripts/env_specifier_test/v1/output_file_parser_basic.py +4 -0
  11. hpcflow/data/scripts/env_specifier_test/v2/main_script_test_direct_in_direct_out.py +7 -0
  12. hpcflow/data/scripts/input_file_generator_basic.py +3 -0
  13. hpcflow/data/scripts/input_file_generator_basic_FAIL.py +3 -0
  14. hpcflow/data/scripts/input_file_generator_test_stdout_stderr.py +8 -0
  15. hpcflow/data/scripts/main_script_test_direct_in.py +3 -0
  16. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2.py +6 -0
  17. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed.py +6 -0
  18. hpcflow/data/scripts/main_script_test_direct_in_direct_out_2_fail_allowed_group.py +7 -0
  19. hpcflow/data/scripts/main_script_test_direct_in_direct_out_3.py +6 -0
  20. hpcflow/data/scripts/main_script_test_direct_in_group_direct_out_3.py +6 -0
  21. hpcflow/data/scripts/main_script_test_direct_in_group_one_fail_direct_out_3.py +6 -0
  22. hpcflow/data/scripts/main_script_test_hdf5_in_obj_2.py +12 -0
  23. hpcflow/data/scripts/main_script_test_json_out_FAIL.py +3 -0
  24. hpcflow/data/scripts/main_script_test_shell_env_vars.py +12 -0
  25. hpcflow/data/scripts/main_script_test_std_out_std_err.py +6 -0
  26. hpcflow/data/scripts/output_file_parser_basic.py +3 -0
  27. hpcflow/data/scripts/output_file_parser_basic_FAIL.py +7 -0
  28. hpcflow/data/scripts/output_file_parser_test_stdout_stderr.py +8 -0
  29. hpcflow/data/scripts/script_exit_test.py +5 -0
  30. hpcflow/data/template_components/environments.yaml +1 -1
  31. hpcflow/sdk/__init__.py +5 -0
  32. hpcflow/sdk/app.py +166 -92
  33. hpcflow/sdk/cli.py +263 -84
  34. hpcflow/sdk/cli_common.py +99 -5
  35. hpcflow/sdk/config/callbacks.py +38 -1
  36. hpcflow/sdk/config/config.py +102 -13
  37. hpcflow/sdk/config/errors.py +19 -5
  38. hpcflow/sdk/config/types.py +3 -0
  39. hpcflow/sdk/core/__init__.py +25 -1
  40. hpcflow/sdk/core/actions.py +914 -262
  41. hpcflow/sdk/core/cache.py +76 -34
  42. hpcflow/sdk/core/command_files.py +14 -128
  43. hpcflow/sdk/core/commands.py +35 -6
  44. hpcflow/sdk/core/element.py +122 -50
  45. hpcflow/sdk/core/errors.py +58 -2
  46. hpcflow/sdk/core/execute.py +207 -0
  47. hpcflow/sdk/core/loop.py +408 -50
  48. hpcflow/sdk/core/loop_cache.py +4 -4
  49. hpcflow/sdk/core/parameters.py +382 -37
  50. hpcflow/sdk/core/run_dir_files.py +13 -40
  51. hpcflow/sdk/core/skip_reason.py +7 -0
  52. hpcflow/sdk/core/task.py +119 -30
  53. hpcflow/sdk/core/task_schema.py +68 -0
  54. hpcflow/sdk/core/test_utils.py +66 -27
  55. hpcflow/sdk/core/types.py +54 -1
  56. hpcflow/sdk/core/utils.py +136 -19
  57. hpcflow/sdk/core/workflow.py +1587 -356
  58. hpcflow/sdk/data/workflow_spec_schema.yaml +2 -0
  59. hpcflow/sdk/demo/cli.py +7 -0
  60. hpcflow/sdk/helper/cli.py +1 -0
  61. hpcflow/sdk/log.py +42 -15
  62. hpcflow/sdk/persistence/base.py +405 -53
  63. hpcflow/sdk/persistence/json.py +177 -52
  64. hpcflow/sdk/persistence/pending.py +237 -69
  65. hpcflow/sdk/persistence/store_resource.py +3 -2
  66. hpcflow/sdk/persistence/types.py +15 -4
  67. hpcflow/sdk/persistence/zarr.py +928 -81
  68. hpcflow/sdk/submission/jobscript.py +1408 -489
  69. hpcflow/sdk/submission/schedulers/__init__.py +40 -5
  70. hpcflow/sdk/submission/schedulers/direct.py +33 -19
  71. hpcflow/sdk/submission/schedulers/sge.py +51 -16
  72. hpcflow/sdk/submission/schedulers/slurm.py +44 -16
  73. hpcflow/sdk/submission/schedulers/utils.py +7 -2
  74. hpcflow/sdk/submission/shells/base.py +68 -20
  75. hpcflow/sdk/submission/shells/bash.py +222 -129
  76. hpcflow/sdk/submission/shells/powershell.py +200 -150
  77. hpcflow/sdk/submission/submission.py +852 -119
  78. hpcflow/sdk/submission/types.py +18 -21
  79. hpcflow/sdk/typing.py +24 -5
  80. hpcflow/sdk/utils/arrays.py +71 -0
  81. hpcflow/sdk/utils/deferred_file.py +55 -0
  82. hpcflow/sdk/utils/hashing.py +16 -0
  83. hpcflow/sdk/utils/patches.py +12 -0
  84. hpcflow/sdk/utils/strings.py +33 -0
  85. hpcflow/tests/api/test_api.py +32 -0
  86. hpcflow/tests/conftest.py +19 -0
  87. hpcflow/tests/data/benchmark_script_runner.yaml +26 -0
  88. hpcflow/tests/data/multi_path_sequences.yaml +29 -0
  89. hpcflow/tests/data/workflow_test_run_abort.yaml +34 -35
  90. hpcflow/tests/schedulers/sge/test_sge_submission.py +36 -0
  91. hpcflow/tests/scripts/test_input_file_generators.py +282 -0
  92. hpcflow/tests/scripts/test_main_scripts.py +821 -70
  93. hpcflow/tests/scripts/test_non_snippet_script.py +46 -0
  94. hpcflow/tests/scripts/test_ouput_file_parsers.py +353 -0
  95. hpcflow/tests/shells/wsl/test_wsl_submission.py +6 -0
  96. hpcflow/tests/unit/test_action.py +176 -0
  97. hpcflow/tests/unit/test_app.py +20 -0
  98. hpcflow/tests/unit/test_cache.py +46 -0
  99. hpcflow/tests/unit/test_cli.py +133 -0
  100. hpcflow/tests/unit/test_config.py +122 -1
  101. hpcflow/tests/unit/test_element_iteration.py +47 -0
  102. hpcflow/tests/unit/test_jobscript_unit.py +757 -0
  103. hpcflow/tests/unit/test_loop.py +1332 -27
  104. hpcflow/tests/unit/test_meta_task.py +325 -0
  105. hpcflow/tests/unit/test_multi_path_sequences.py +229 -0
  106. hpcflow/tests/unit/test_parameter.py +13 -0
  107. hpcflow/tests/unit/test_persistence.py +190 -8
  108. hpcflow/tests/unit/test_run.py +109 -3
  109. hpcflow/tests/unit/test_run_directories.py +29 -0
  110. hpcflow/tests/unit/test_shell.py +20 -0
  111. hpcflow/tests/unit/test_submission.py +5 -76
  112. hpcflow/tests/unit/test_workflow_template.py +31 -0
  113. hpcflow/tests/unit/utils/test_arrays.py +40 -0
  114. hpcflow/tests/unit/utils/test_deferred_file_writer.py +34 -0
  115. hpcflow/tests/unit/utils/test_hashing.py +65 -0
  116. hpcflow/tests/unit/utils/test_patches.py +5 -0
  117. hpcflow/tests/unit/utils/test_redirect_std.py +50 -0
  118. hpcflow/tests/workflows/__init__.py +0 -0
  119. hpcflow/tests/workflows/test_directory_structure.py +31 -0
  120. hpcflow/tests/workflows/test_jobscript.py +332 -0
  121. hpcflow/tests/workflows/test_run_status.py +198 -0
  122. hpcflow/tests/workflows/test_skip_downstream.py +696 -0
  123. hpcflow/tests/workflows/test_submission.py +140 -0
  124. hpcflow/tests/workflows/test_workflows.py +142 -2
  125. hpcflow/tests/workflows/test_zip.py +18 -0
  126. hpcflow/viz_demo.ipynb +6587 -3
  127. {hpcflow_new2-0.2.0a190.dist-info → hpcflow_new2-0.2.0a200.dist-info}/METADATA +7 -4
  128. hpcflow_new2-0.2.0a200.dist-info/RECORD +222 -0
  129. hpcflow_new2-0.2.0a190.dist-info/RECORD +0 -165
  130. {hpcflow_new2-0.2.0a190.dist-info → hpcflow_new2-0.2.0a200.dist-info}/LICENSE +0 -0
  131. {hpcflow_new2-0.2.0a190.dist-info → hpcflow_new2-0.2.0a200.dist-info}/WHEEL +0 -0
  132. {hpcflow_new2-0.2.0a190.dist-info → hpcflow_new2-0.2.0a200.dist-info}/entry_points.txt +0 -0
hpcflow/sdk/core/loop.py CHANGED
@@ -6,24 +6,33 @@ notably looping over a set of values or until a condition holds.
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
+ from collections import defaultdict
9
10
  import copy
11
+ from pprint import pp
12
+ import pprint
13
+ from typing import Dict, List, Optional, Tuple, Union, Any
14
+ from warnings import warn
10
15
  from collections import defaultdict
11
16
  from itertools import chain
12
- from typing import TYPE_CHECKING
17
+ from typing import cast, TYPE_CHECKING
13
18
  from typing_extensions import override
14
19
 
15
20
  from hpcflow.sdk.core.app_aware import AppAware
21
+ from hpcflow.sdk.core.actions import EARStatus
22
+ from hpcflow.sdk.core.skip_reason import SkipReason
16
23
  from hpcflow.sdk.core.errors import LoopTaskSubsetError
17
24
  from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
18
25
  from hpcflow.sdk.core.loop_cache import LoopCache, LoopIndex
19
- from hpcflow.sdk.core.enums import InputSourceType
26
+ from hpcflow.sdk.core.enums import InputSourceType, TaskSourceType
20
27
  from hpcflow.sdk.core.utils import check_valid_py_identifier, nth_key, nth_value
28
+ from hpcflow.sdk.utils.strings import shorten_list_str
21
29
  from hpcflow.sdk.log import TimeIt
22
30
 
23
31
  if TYPE_CHECKING:
24
32
  from collections.abc import Iterable, Iterator, Mapping, Sequence
25
33
  from typing import Any, ClassVar
26
34
  from typing_extensions import Self, TypeIs
35
+ from rich.status import Status
27
36
  from ..typing import DataIndex, ParamSource
28
37
  from .parameters import SchemaInput, InputSource
29
38
  from .rule import Rule
@@ -62,6 +71,8 @@ class Loop(JSONLike):
62
71
  Specify input parameters that should not iterate.
63
72
  termination: v~hpcflow.app.Rule
64
73
  Stopping criterion, expressed as a rule.
74
+ termination_task: int | ~hpcflow.app.WorkflowTask
75
+ Task at which to evaluate the termination condition.
65
76
  """
66
77
 
67
78
  _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
@@ -79,6 +90,7 @@ class Loop(JSONLike):
79
90
  name: str | None = None,
80
91
  non_iterable_parameters: list[str] | None = None,
81
92
  termination: Rule | None = None,
93
+ termination_task: int | WorkflowTask | None = None,
82
94
  ) -> None:
83
95
  _task_insert_IDs: list[int] = []
84
96
  for task in tasks:
@@ -89,14 +101,34 @@ class Loop(JSONLike):
89
101
  else:
90
102
  raise TypeError(
91
103
  f"`tasks` must be a list whose elements are either task insert IDs "
92
- f"or WorkflowTask objects, but received the following: {tasks!r}."
104
+ f"or `WorkflowTask` objects, but received the following: {tasks!r}."
93
105
  )
94
106
 
107
+ if termination_task is None:
108
+ _term_task_iID = _task_insert_IDs[-1] # terminate on final task by default
109
+ elif self.__is_WorkflowTask(termination_task):
110
+ _term_task_iID = termination_task.insert_ID
111
+ elif isinstance(termination_task, int):
112
+ _term_task_iID = termination_task
113
+ else:
114
+ raise TypeError(
115
+ f"`termination_task` must be a task insert ID or a `WorkflowTask` "
116
+ f"object, but received the following: {termination_task!r}."
117
+ )
118
+
119
+ if _term_task_iID not in _task_insert_IDs:
120
+ raise ValueError(
121
+ f"If specified, `termination_task` (provided: {termination_task!r}) must "
122
+ f"refer to a task that is part of the loop. Available task insert IDs "
123
+ f"are: {_task_insert_IDs!r}."
124
+ )
125
+
95
126
  self._task_insert_IDs = _task_insert_IDs
96
127
  self._num_iterations = num_iterations
97
128
  self._name = check_valid_py_identifier(name) if name else name
98
129
  self._non_iterable_parameters = non_iterable_parameters or []
99
130
  self._termination = termination
131
+ self._termination_task_insert_ID = _term_task_iID
100
132
 
101
133
  self._workflow_template: WorkflowTemplate | None = (
102
134
  None # assigned by parent WorkflowTemplate
@@ -114,7 +146,15 @@ class Loop(JSONLike):
114
146
  insert_IDs = json_like.pop("task_insert_IDs")
115
147
  else:
116
148
  insert_IDs = json_like.pop("tasks")
117
- return cls(tasks=insert_IDs, **json_like)
149
+
150
+ if "termination_task_insert_ID" in json_like:
151
+ tt_iID = json_like.pop("termination_task_insert_ID")
152
+ elif "termination_task" in json_like:
153
+ tt_iID = json_like.pop("termination_task")
154
+ else:
155
+ tt_iID = None
156
+
157
+ return cls(tasks=insert_IDs, termination_task=tt_iID, **json_like)
118
158
 
119
159
  @property
120
160
  def task_insert_IDs(self) -> tuple[int, ...]:
@@ -149,6 +189,25 @@ class Loop(JSONLike):
149
189
  """
150
190
  return self._termination
151
191
 
192
+ @property
193
+ def termination_task_insert_ID(self) -> int:
194
+ """
195
+ The insert ID of the task at which the loop will terminate.
196
+ """
197
+ return self._termination_task_insert_ID
198
+
199
+ @property
200
+ def termination_task(self) -> WorkflowTask:
201
+ """
202
+ The task at which the loop will terminate.
203
+ """
204
+ if (wt := self.workflow_template) is None:
205
+ raise RuntimeError(
206
+ "Workflow template must be assigned to retrieve task objects of the loop."
207
+ )
208
+ assert wt.workflow
209
+ return wt.workflow.tasks.get(insert_ID=self.termination_task_insert_ID)
210
+
152
211
  @property
153
212
  def workflow_template(self) -> WorkflowTemplate | None:
154
213
  """
@@ -212,6 +271,7 @@ class Loop(JSONLike):
212
271
  def __deepcopy__(self, memo: dict[int, Any]) -> Self:
213
272
  kwargs = self.to_dict()
214
273
  kwargs["tasks"] = kwargs.pop("task_insert_IDs")
274
+ kwargs["termination_task"] = kwargs.pop("termination_task_insert_ID")
215
275
  obj = self.__class__(**copy.deepcopy(kwargs, memo))
216
276
  obj._workflow_template = self._workflow_template
217
277
  return obj
@@ -234,6 +294,9 @@ class WorkflowLoop(AppAware):
234
294
  Description of what iterations have been added.
235
295
  iterable_parameters:
236
296
  Description of what parameters are being iterated over.
297
+ output_parameters:
298
+ Description of what parameters are output from this loop, and the final task insert
299
+ ID from which they are output.
237
300
  parents: list[str]
238
301
  The paths to the parent entities of this loop.
239
302
  """
@@ -245,6 +308,7 @@ class WorkflowLoop(AppAware):
245
308
  template: Loop,
246
309
  num_added_iterations: dict[tuple[int, ...], int],
247
310
  iterable_parameters: dict[str, IterableParam],
311
+ output_parameters: dict[str, int],
248
312
  parents: list[str],
249
313
  ) -> None:
250
314
  self._index = index
@@ -252,10 +316,11 @@ class WorkflowLoop(AppAware):
252
316
  self._template = template
253
317
  self._num_added_iterations = num_added_iterations
254
318
  self._iterable_parameters = iterable_parameters
319
+ self._output_parameters = output_parameters
255
320
  self._parents = parents
256
321
 
257
- # appended to on adding a empty loop to the workflow that's a parent of this loop,
258
- # reset and added to `self._parents` on dump to disk:
322
+ # appended to when adding an empty loop to the workflow that is a parent of this
323
+ # loop; reset and added to `self._parents` on dump to disk:
259
324
  self._pending_parents: list[str] = []
260
325
 
261
326
  # used for `num_added_iterations` when a new loop iteration is added, or when
@@ -273,16 +338,6 @@ class WorkflowLoop(AppAware):
273
338
  if task_indices != tuple(range(task_min, task_max + 1)):
274
339
  raise LoopTaskSubsetError(self.name, self.task_indices)
275
340
 
276
- for task in self.downstream_tasks:
277
- for param in self.iterable_parameters:
278
- if param in task.template.all_schema_input_types:
279
- raise NotImplementedError(
280
- f"Downstream task {task.unique_name!r} of loop {self.name!r} "
281
- f"has as one of its input parameters this loop's iterable "
282
- f"parameter {param!r}. This parameter cannot be sourced "
283
- f"correctly."
284
- )
285
-
286
341
  def __repr__(self) -> str:
287
342
  return (
288
343
  f"{self.__class__.__name__}(template={self.template!r}, "
@@ -404,12 +459,20 @@ class WorkflowLoop(AppAware):
404
459
  return self.template.name
405
460
 
406
461
  @property
407
- def iterable_parameters(self) -> Mapping[str, IterableParam]:
462
+ def iterable_parameters(self) -> dict[str, IterableParam]:
408
463
  """
409
464
  The parameters that are being iterated over.
410
465
  """
411
466
  return self._iterable_parameters
412
467
 
468
+ @property
469
+ def output_parameters(self) -> dict[str, int]:
470
+ """
471
+ The parameters that are outputs of this loop, and the final task insert ID from
472
+ which each parameter is output.
473
+ """
474
+ return self._output_parameters
475
+
413
476
  @property
414
477
  def num_iterations(self) -> int:
415
478
  """
@@ -433,7 +496,9 @@ class WorkflowLoop(AppAware):
433
496
 
434
497
  @staticmethod
435
498
  @TimeIt.decorator
436
- def _find_iterable_parameters(loop_template: Loop) -> dict[str, IterableParam]:
499
+ def _find_iterable_and_output_parameters(
500
+ loop_template: Loop,
501
+ ) -> tuple[dict[str, IterableParam], dict[str, int]]:
437
502
  all_inputs_first_idx: dict[str, int] = {}
438
503
  all_outputs_idx: dict[str, list[int]] = defaultdict(list)
439
504
  for task in loop_template.task_objects:
@@ -442,6 +507,7 @@ class WorkflowLoop(AppAware):
442
507
  for typ in task.template.all_schema_output_types:
443
508
  all_outputs_idx[typ].append(task.insert_ID)
444
509
 
510
+ # find input parameters that are also output parameters at a later/same task:
445
511
  iterable_params: dict[str, IterableParam] = {}
446
512
  for typ, first_idx in all_inputs_first_idx.items():
447
513
  if typ in all_outputs_idx and first_idx <= all_outputs_idx[typ][0]:
@@ -453,7 +519,9 @@ class WorkflowLoop(AppAware):
453
519
  for non_iter in loop_template.non_iterable_parameters:
454
520
  iterable_params.pop(non_iter, None)
455
521
 
456
- return iterable_params
522
+ final_out_tasks = {k: v[-1] for k, v in all_outputs_idx.items()}
523
+
524
+ return iterable_params, final_out_tasks
457
525
 
458
526
  @classmethod
459
527
  @TimeIt.decorator
@@ -478,21 +546,20 @@ class WorkflowLoop(AppAware):
478
546
  iter_loop_idx: list[dict]
479
547
  Iteration information from parent loops.
480
548
  """
481
- parent_names = [
482
- loop.name
483
- for loop in cls._get_parent_loops(index, workflow, template)
484
- if loop.name
485
- ]
486
- num_added_iters = {
487
- tuple(l_idx[nm] for nm in parent_names): 1 for l_idx in iter_loop_idx
488
- }
549
+ parent_loops = cls._get_parent_loops(index, workflow, template)
550
+ parent_names = [i.name for i in parent_loops]
551
+ num_added_iters: dict[tuple[int, ...], int] = {}
552
+ for i in iter_loop_idx:
553
+ num_added_iters[tuple([i[j] for j in parent_names])] = 1
489
554
 
555
+ iter_params, out_params = cls._find_iterable_and_output_parameters(template)
490
556
  return cls(
491
557
  index=index,
492
558
  workflow=workflow,
493
559
  template=template,
494
560
  num_added_iterations=num_added_iters,
495
- iterable_parameters=cls._find_iterable_parameters(template),
561
+ iterable_parameters=iter_params,
562
+ output_parameters=out_params,
496
563
  parents=parent_names,
497
564
  )
498
565
 
@@ -551,6 +618,7 @@ class WorkflowLoop(AppAware):
551
618
  self,
552
619
  parent_loop_indices: Mapping[str, int] | None = None,
553
620
  cache: LoopCache | None = None,
621
+ status: Status | None = None,
554
622
  ) -> None:
555
623
  """
556
624
  Add an iteration to this loop.
@@ -588,6 +656,13 @@ class WorkflowLoop(AppAware):
588
656
  iters_key_dct.get(j, 0) for j in child.parents
589
657
  )
590
658
 
659
+ # needed for the case where an inner loop has only one iteration, meaning
660
+ # `add_iteration` will not be called recursively on it:
661
+ self.workflow._store.update_loop_num_iters(
662
+ index=child.index,
663
+ num_added_iters=child.num_added_iterations,
664
+ )
665
+
591
666
  for task in self.task_objects:
592
667
  new_loop_idx = LoopIndex(iters_key_dct) + {
593
668
  child.name: 0
@@ -723,21 +798,40 @@ class WorkflowLoop(AppAware):
723
798
  # add iterations to fixed-number-iteration children only:
724
799
  for child in child_loops[::-1]:
725
800
  if child.num_iterations is not None:
726
- for _ in range(child.num_iterations - 1):
801
+ if status:
802
+ status_prev = str(status.status).rstrip(".")
803
+ for iter_idx in range(child.num_iterations - 1):
804
+ if status:
805
+ status.update(
806
+ f"{status_prev} --> ({child.name!r}): iteration "
807
+ f"{iter_idx + 2}/{child.num_iterations}."
808
+ )
727
809
  par_idx = {parent_name: 0 for parent_name in child.parents}
728
810
  if parent_loop_indices:
729
811
  par_idx.update(parent_loop_indices)
730
812
  par_idx[self.name] = cur_loop_idx + 1
731
813
  child.add_iteration(parent_loop_indices=par_idx, cache=cache)
732
814
 
815
+ self.__update_loop_downstream_data_idx(parent_loop_indices_)
816
+
733
817
  def __get_src_ID_and_groups(
734
- self, elem_ID: int, iter_dat: IterableParam, inp: SchemaInput, cache: LoopCache
818
+ self,
819
+ elem_ID: int,
820
+ iter_dat: IterableParam,
821
+ inp: SchemaInput,
822
+ cache: LoopCache,
823
+ task: WorkflowTask,
735
824
  ) -> tuple[int, Sequence[int]]:
736
- src_elem_IDs = {
737
- k: v
738
- for k, v in cache.element_dependents[elem_ID].items()
739
- if cache.elements[k]["task_insert_ID"] == iter_dat["output_tasks"][-1]
740
- }
825
+ # `cache.elements` contains only elements that are part of the
826
+ # loop, so indexing a dependent element may raise:
827
+ src_elem_IDs = {}
828
+ for k, v in cache.element_dependents[elem_ID].items():
829
+ try:
830
+ if cache.elements[k]["task_insert_ID"] == iter_dat["output_tasks"][-1]:
831
+ src_elem_IDs[k] = v
832
+ except KeyError:
833
+ continue
834
+
741
835
  # consider groups
742
836
  single_data = inp.single_labelled_data
743
837
  assert single_data is not None
@@ -753,7 +847,8 @@ class WorkflowLoop(AppAware):
753
847
  f"Multiple elements found in the iterable parameter "
754
848
  f"{inp!r}'s latest output task (insert ID: "
755
849
  f"{iter_dat['output_tasks'][-1]}) that can be used "
756
- f"to parametrise the next iteration: "
850
+ f"to parametrise the next iteration of task "
851
+ f"{task.unique_name!r}: "
757
852
  f"{list(src_elem_IDs)!r}."
758
853
  )
759
854
 
@@ -787,11 +882,13 @@ class WorkflowLoop(AppAware):
787
882
  # identify element(s) from which this iterable input should be
788
883
  # parametrised:
789
884
  if task.insert_ID == iter_dat["output_tasks"][-1]:
885
+ # single-task loop
790
886
  src_elem_ID = elem_ID
791
887
  grouped_elems: Sequence[int] = []
792
888
  else:
889
+ # multi-task loop
793
890
  src_elem_ID, grouped_elems = self.__get_src_ID_and_groups(
794
- elem_ID, iter_dat, inp, cache
891
+ elem_ID, iter_dat, inp, cache, task
795
892
  )
796
893
 
797
894
  child_loop_max_iters: dict[str, int] = {}
@@ -803,12 +900,13 @@ class WorkflowLoop(AppAware):
803
900
  self.name: cur_loop_idx,
804
901
  }
805
902
  for loop in child_loops:
806
- i_num_iters = loop.num_added_iterations[
807
- tuple(child_iter_parents[j] for j in loop.parents)
808
- ]
809
- i_max = i_num_iters - 1
810
- child_iter_parents[loop.name] = i_max
811
- child_loop_max_iters[loop.name] = i_max
903
+ if iter_dat["output_tasks"][-1] in loop.task_insert_IDs:
904
+ i_num_iters = loop.num_added_iterations[
905
+ tuple(child_iter_parents[j] for j in loop.parents)
906
+ ]
907
+ i_max = i_num_iters - 1
908
+ child_iter_parents[loop.name] = i_max
909
+ child_loop_max_iters[loop.name] = i_max
812
910
 
813
911
  loop_idx_key = LoopIndex(child_loop_max_iters)
814
912
  loop_idx_key.update(parent_loop_same_iters)
@@ -872,13 +970,20 @@ class WorkflowLoop(AppAware):
872
970
  src_elem_IDs = cache.element_dependents[
873
971
  self.workflow.tasks.get(insert_ID=tiID).element_IDs[e_idx]
874
972
  ]
875
- # filter src_elem_IDs_i for matching element IDs:
876
- src_elem_IDs_i = [
877
- k
878
- for k, _v in src_elem_IDs.items()
879
- if cache.elements[k]["task_insert_ID"] == task.insert_ID
880
- and k == elem_ID
881
- ]
973
+ # `cache.elements` contains only elements that are part of the loop, so
974
+ # indexing a dependent element may raise:
975
+ src_elem_IDs_i = []
976
+ for k, _v in src_elem_IDs.items():
977
+ try:
978
+ if (
979
+ cache.elements[k]["task_insert_ID"] == task.insert_ID
980
+ and k == elem_ID
981
+ # filter src_elem_IDs_i for matching element IDs
982
+ ):
983
+
984
+ src_elem_IDs_i.append(k)
985
+ except KeyError:
986
+ continue
882
987
 
883
988
  if len(src_elem_IDs_i) == 1:
884
989
  new_sources.append((tiID, e_idx))
@@ -902,9 +1007,262 @@ class WorkflowLoop(AppAware):
902
1007
  else:
903
1008
  yield from seq
904
1009
 
1010
+ def __update_loop_downstream_data_idx(
1011
+ self,
1012
+ parent_loop_indices: Mapping[str, int],
1013
+ ):
1014
+ # update data indices of loop-downstream tasks that depend on task outputs from
1015
+ # this loop:
1016
+
1017
+ # keys: iter or run ID, values: dict of param type and new parameter index
1018
+ iter_new_data_idx: dict[int, DataIndex] = defaultdict(dict)
1019
+ run_new_data_idx: dict[int, DataIndex] = defaultdict(dict)
1020
+
1021
+ param_sources = self.workflow.get_all_parameter_sources()
1022
+
1023
+ # keys are parameter type, then task insert ID, then data index keys mapping to
1024
+ # their updated values:
1025
+ all_updates: dict[str, dict[int, dict[int, int]]] = defaultdict(
1026
+ lambda: defaultdict(dict)
1027
+ )
1028
+
1029
+ for task in self.downstream_tasks:
1030
+ for elem in task.elements:
1031
+ for param_typ, param_out_task_iID in self.output_parameters.items():
1032
+ if param_typ in task.template.all_schema_input_types:
1033
+ # this element's input *might* need updating, only if it has a
1034
+ # task input source type that is this loop's output task for this
1035
+ # parameter:
1036
+ elem_src = elem.input_sources[f"inputs.{param_typ}"]
1037
+ if (
1038
+ elem_src.source_type is InputSourceType.TASK
1039
+ and elem_src.task_source_type is TaskSourceType.OUTPUT
1040
+ and elem_src.task_ref == param_out_task_iID
1041
+ ):
1042
+ for iter_i in elem.iterations:
1043
+
1044
+ # do not modify element-iterations of previous iterations
1045
+ # of the current loop:
1046
+ skip_iter = False
1047
+ for k, v in parent_loop_indices.items():
1048
+ if iter_i.loop_idx.get(k) != v:
1049
+ skip_iter = True
1050
+ break
1051
+
1052
+ if skip_iter:
1053
+ continue
1054
+
1055
+ # update the iteration data index and any pending runs:
1056
+ iter_old_di = iter_i.data_idx[f"inputs.{param_typ}"]
1057
+
1058
+ is_group = True
1059
+ if not isinstance(iter_old_di, list):
1060
+ is_group = False
1061
+ iter_old_di = [iter_old_di]
1062
+
1063
+ iter_old_run_source = [
1064
+ param_sources[i]["EAR_ID"] for i in iter_old_di
1065
+ ]
1066
+ iter_old_run_objs = self.workflow.get_EARs_from_IDs(
1067
+ iter_old_run_source
1068
+ ) # TODO: use cache
1069
+
1070
+ # need to check the run source is actually from the loop
1071
+ # output task (it could be from a previous iteration of a
1072
+ # separate loop in this task):
1073
+ if any(
1074
+ i.task.insert_ID != param_out_task_iID
1075
+ for i in iter_old_run_objs
1076
+ ):
1077
+ continue
1078
+
1079
+ iter_new_iters = [
1080
+ i.element.iterations[-1] for i in iter_old_run_objs
1081
+ ]
1082
+
1083
+ # note: we can cast to int, because output keys never
1084
+ # have multiple data indices (unlike input keys):
1085
+ iter_new_dis = [
1086
+ cast("int", i.get_data_idx()[f"outputs.{param_typ}"])
1087
+ for i in iter_new_iters
1088
+ ]
1089
+
1090
+ # keep track of updates so we can also update task-input
1091
+ # type sources:
1092
+ all_updates[param_typ][task.insert_ID].update(
1093
+ dict(zip(iter_old_di, iter_new_dis))
1094
+ )
1095
+
1096
+ iter_new_data_idx[iter_i.id_][f"inputs.{param_typ}"] = (
1097
+ iter_new_dis if is_group else iter_new_dis[0]
1098
+ )
1099
+
1100
+ for run_j in iter_i.action_runs:
1101
+ if run_j.status is EARStatus.pending:
1102
+ try:
1103
+ old_di = run_j.data_idx[f"inputs.{param_typ}"]
1104
+ except KeyError:
1105
+ # not all actions will include this input
1106
+ continue
1107
+
1108
+ is_group = True
1109
+ if not isinstance(old_di, list):
1110
+ is_group = False
1111
+ old_di = [old_di]
1112
+
1113
+ old_run_source = [
1114
+ param_sources[i]["EAR_ID"] for i in old_di
1115
+ ]
1116
+ old_run_objs = self.workflow.get_EARs_from_IDs(
1117
+ old_run_source
1118
+ ) # TODO: use cache
1119
+
1120
+ # need to check the run source is actually from the loop
1121
+ # output task (it could be from a previous action in this
1122
+ # element-iteration):
1123
+ if any(
1124
+ i.task.insert_ID != param_out_task_iID
1125
+ for i in old_run_objs
1126
+ ):
1127
+ continue
1128
+
1129
+ new_iters = [
1130
+ i.element.iterations[-1] for i in old_run_objs
1131
+ ]
1132
+
1133
+ # note: we can cast to int, because output keys
1134
+ # never have multiple data indices (unlike input
1135
+ # keys):
1136
+ new_dis = [
1137
+ cast(
1138
+ "int",
1139
+ i.get_data_idx()[f"outputs.{param_typ}"],
1140
+ )
1141
+ for i in new_iters
1142
+ ]
1143
+
1144
+ run_new_data_idx[run_j.id_][
1145
+ f"inputs.{param_typ}"
1146
+ ] = (new_dis if is_group else new_dis[0])
1147
+
1148
+ elif (
1149
+ elem_src.source_type is InputSourceType.TASK
1150
+ and elem_src.task_source_type is TaskSourceType.INPUT
1151
+ ):
1152
+ # parameters that are sourced from inputs of other tasks,
1153
+ # might need to be updated if those other tasks have
1154
+ # themselves had their data indices updated:
1155
+ assert elem_src.task_ref
1156
+ ups_i = all_updates.get(param_typ, {}).get(elem_src.task_ref)
1157
+
1158
+ if ups_i:
1159
+ # if a further-downstream task has a task-input source
1160
+ # that points to this task, this will also need updating:
1161
+ all_updates[param_typ][task.insert_ID].update(ups_i)
1162
+
1163
+ else:
1164
+ continue
1165
+
1166
+ for iter_i in elem.iterations:
1167
+
1168
+ # update the iteration data index and any pending runs:
1169
+ iter_old_di = iter_i.data_idx[f"inputs.{param_typ}"]
1170
+
1171
+ is_group = True
1172
+ if not isinstance(iter_old_di, list):
1173
+ is_group = False
1174
+ iter_old_di = [iter_old_di]
1175
+
1176
+ iter_new_dis = [ups_i.get(i, i) for i in iter_old_di]
1177
+
1178
+ if iter_new_dis != iter_old_di:
1179
+ iter_new_data_idx[iter_i.id_][
1180
+ f"inputs.{param_typ}"
1181
+ ] = (iter_new_dis if is_group else iter_new_dis[0])
1182
+
1183
+ for run_j in iter_i.action_runs:
1184
+ if run_j.status is EARStatus.pending:
1185
+ try:
1186
+ old_di = run_j.data_idx[f"inputs.{param_typ}"]
1187
+ except KeyError:
1188
+ # not all actions will include this input
1189
+ continue
1190
+
1191
+ is_group = True
1192
+ if not isinstance(old_di, list):
1193
+ is_group = False
1194
+ old_di = [old_di]
1195
+
1196
+ new_dis = [ups_i.get(i, i) for i in old_di]
1197
+
1198
+ if new_dis != old_di:
1199
+ run_new_data_idx[run_j.id_][
1200
+ f"inputs.{param_typ}"
1201
+ ] = (new_dis if is_group else new_dis[0])
1202
+
1203
+ # now update data indices (TODO: including in cache!)
1204
+ if iter_new_data_idx:
1205
+ self.workflow._store.update_iter_data_indices(iter_new_data_idx)
1206
+
1207
+ if run_new_data_idx:
1208
+ self.workflow._store.update_run_data_indices(run_new_data_idx)
1209
+
905
1210
  def test_termination(self, element_iter) -> bool:
906
1211
  """Check if a loop should terminate, given the specified completed element
907
1212
  iteration."""
908
1213
  if self.template.termination:
909
1214
  return self.template.termination.test(element_iter)
910
1215
  return False
1216
+
1217
+ @TimeIt.decorator
1218
+ def get_element_IDs(self):
1219
+ elem_IDs = [
1220
+ j
1221
+ for i in self.task_insert_IDs
1222
+ for j in self.workflow.tasks.get(insert_ID=i).element_IDs
1223
+ ]
1224
+ return elem_IDs
1225
+
1226
+ @TimeIt.decorator
1227
+ def get_elements(self):
1228
+ return self.workflow.get_elements_from_IDs(self.get_element_IDs())
1229
+
1230
+ @TimeIt.decorator
1231
+ def skip_downstream_iterations(self, elem_iter) -> list[int]:
1232
+ """
1233
+ Parameters
1234
+ ----------
1235
+ elem_iter
1236
+ The element iteration whose subsequent iterations should be skipped.
1237
+ dep_element_IDs
1238
+ List of elements that are dependent (recursively) on the element
1239
+ of `elem_iter`.
1240
+ """
1241
+ current_iter_idx = elem_iter.loop_idx[self.name]
1242
+ current_task_iID = elem_iter.task.insert_ID
1243
+ self._app.logger.info(
1244
+ f"setting loop {self.name!r} iterations downstream of current iteration "
1245
+ f"index {current_iter_idx} to skip"
1246
+ )
1247
+ elements = self.get_elements()
1248
+
1249
+ # TODO: fix for multiple loop cycles
1250
+ warn(
1251
+ "skip downstream iterations does not work correctly for multiple loop cycles!"
1252
+ )
1253
+
1254
+ to_skip = []
1255
+ for elem in elements:
1256
+ for iter_i in elem.iterations:
1257
+ if iter_i.loop_idx[self.name] > current_iter_idx or (
1258
+ iter_i.loop_idx[self.name] == current_iter_idx
1259
+ and iter_i.task.insert_ID > current_task_iID
1260
+ ):
1261
+ to_skip.extend(iter_i.EAR_IDs_flat)
1262
+
1263
+ self._app.logger.info(
1264
+ f"{len(to_skip)} runs will be set to skip: {shorten_list_str(to_skip)}"
1265
+ )
1266
+ self.workflow.set_EAR_skip({k: SkipReason.LOOP_TERMINATION for k in to_skip})
1267
+
1268
+ return to_skip