hpcflow_new2-0.2.0a189-py3-none-any.whl → hpcflow_new2-0.2.0a190-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
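Wheels are ordinary zip archives, so a diff like the one below can be reproduced with the standard library alone. A minimal sketch, assuming both wheels have already been downloaded into the working directory (the `diff_wheels` helper is ours, not part of any tool):

# Sketch: compare one file across two wheels using only the standard library.
import difflib
import zipfile

def diff_wheels(old_whl: str, new_whl: str, member: str) -> str:
    """Unified diff of a file that appears in both wheel archives."""
    with zipfile.ZipFile(old_whl) as old, zipfile.ZipFile(new_whl) as new:
        old_lines = old.read(member).decode().splitlines(keepends=True)
        new_lines = new.read(member).decode().splitlines(keepends=True)
    return "".join(
        difflib.unified_diff(old_lines, new_lines, fromfile=member, tofile=member)
    )

print(
    diff_wheels(
        "hpcflow_new2-0.2.0a189-py3-none-any.whl",
        "hpcflow_new2-0.2.0a190-py3-none-any.whl",
        "hpcflow/sdk/core/workflow.py",
    )
)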
Files changed (115)
  1. hpcflow/__pyinstaller/hook-hpcflow.py +8 -6
  2. hpcflow/_version.py +1 -1
  3. hpcflow/app.py +1 -0
  4. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
  5. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
  6. hpcflow/sdk/__init__.py +21 -15
  7. hpcflow/sdk/app.py +2133 -770
  8. hpcflow/sdk/cli.py +281 -250
  9. hpcflow/sdk/cli_common.py +6 -2
  10. hpcflow/sdk/config/__init__.py +1 -1
  11. hpcflow/sdk/config/callbacks.py +77 -42
  12. hpcflow/sdk/config/cli.py +126 -103
  13. hpcflow/sdk/config/config.py +578 -311
  14. hpcflow/sdk/config/config_file.py +131 -95
  15. hpcflow/sdk/config/errors.py +112 -85
  16. hpcflow/sdk/config/types.py +145 -0
  17. hpcflow/sdk/core/actions.py +1054 -994
  18. hpcflow/sdk/core/app_aware.py +24 -0
  19. hpcflow/sdk/core/cache.py +81 -63
  20. hpcflow/sdk/core/command_files.py +275 -185
  21. hpcflow/sdk/core/commands.py +111 -107
  22. hpcflow/sdk/core/element.py +724 -503
  23. hpcflow/sdk/core/enums.py +192 -0
  24. hpcflow/sdk/core/environment.py +74 -93
  25. hpcflow/sdk/core/errors.py +398 -51
  26. hpcflow/sdk/core/json_like.py +540 -272
  27. hpcflow/sdk/core/loop.py +380 -334
  28. hpcflow/sdk/core/loop_cache.py +160 -43
  29. hpcflow/sdk/core/object_list.py +370 -207
  30. hpcflow/sdk/core/parameters.py +728 -600
  31. hpcflow/sdk/core/rule.py +59 -41
  32. hpcflow/sdk/core/run_dir_files.py +33 -22
  33. hpcflow/sdk/core/task.py +1546 -1325
  34. hpcflow/sdk/core/task_schema.py +240 -196
  35. hpcflow/sdk/core/test_utils.py +126 -88
  36. hpcflow/sdk/core/types.py +387 -0
  37. hpcflow/sdk/core/utils.py +410 -305
  38. hpcflow/sdk/core/validation.py +82 -9
  39. hpcflow/sdk/core/workflow.py +1192 -1028
  40. hpcflow/sdk/core/zarr_io.py +98 -137
  41. hpcflow/sdk/demo/cli.py +46 -33
  42. hpcflow/sdk/helper/cli.py +18 -16
  43. hpcflow/sdk/helper/helper.py +75 -63
  44. hpcflow/sdk/helper/watcher.py +61 -28
  45. hpcflow/sdk/log.py +83 -59
  46. hpcflow/sdk/persistence/__init__.py +8 -31
  47. hpcflow/sdk/persistence/base.py +988 -586
  48. hpcflow/sdk/persistence/defaults.py +6 -0
  49. hpcflow/sdk/persistence/discovery.py +38 -0
  50. hpcflow/sdk/persistence/json.py +408 -153
  51. hpcflow/sdk/persistence/pending.py +158 -123
  52. hpcflow/sdk/persistence/store_resource.py +37 -22
  53. hpcflow/sdk/persistence/types.py +307 -0
  54. hpcflow/sdk/persistence/utils.py +14 -11
  55. hpcflow/sdk/persistence/zarr.py +477 -420
  56. hpcflow/sdk/runtime.py +44 -41
  57. hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
  58. hpcflow/sdk/submission/jobscript.py +444 -404
  59. hpcflow/sdk/submission/schedulers/__init__.py +133 -40
  60. hpcflow/sdk/submission/schedulers/direct.py +97 -71
  61. hpcflow/sdk/submission/schedulers/sge.py +132 -126
  62. hpcflow/sdk/submission/schedulers/slurm.py +263 -268
  63. hpcflow/sdk/submission/schedulers/utils.py +7 -2
  64. hpcflow/sdk/submission/shells/__init__.py +14 -15
  65. hpcflow/sdk/submission/shells/base.py +102 -29
  66. hpcflow/sdk/submission/shells/bash.py +72 -55
  67. hpcflow/sdk/submission/shells/os_version.py +31 -30
  68. hpcflow/sdk/submission/shells/powershell.py +37 -29
  69. hpcflow/sdk/submission/submission.py +203 -257
  70. hpcflow/sdk/submission/types.py +143 -0
  71. hpcflow/sdk/typing.py +163 -12
  72. hpcflow/tests/conftest.py +8 -6
  73. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
  74. hpcflow/tests/scripts/test_main_scripts.py +60 -30
  75. hpcflow/tests/shells/wsl/test_wsl_submission.py +6 -4
  76. hpcflow/tests/unit/test_action.py +86 -75
  77. hpcflow/tests/unit/test_action_rule.py +9 -4
  78. hpcflow/tests/unit/test_app.py +13 -6
  79. hpcflow/tests/unit/test_cli.py +1 -1
  80. hpcflow/tests/unit/test_command.py +71 -54
  81. hpcflow/tests/unit/test_config.py +20 -15
  82. hpcflow/tests/unit/test_config_file.py +21 -18
  83. hpcflow/tests/unit/test_element.py +58 -62
  84. hpcflow/tests/unit/test_element_iteration.py +3 -1
  85. hpcflow/tests/unit/test_element_set.py +29 -19
  86. hpcflow/tests/unit/test_group.py +4 -2
  87. hpcflow/tests/unit/test_input_source.py +116 -93
  88. hpcflow/tests/unit/test_input_value.py +29 -24
  89. hpcflow/tests/unit/test_json_like.py +44 -35
  90. hpcflow/tests/unit/test_loop.py +65 -58
  91. hpcflow/tests/unit/test_object_list.py +17 -12
  92. hpcflow/tests/unit/test_parameter.py +16 -7
  93. hpcflow/tests/unit/test_persistence.py +48 -35
  94. hpcflow/tests/unit/test_resources.py +20 -18
  95. hpcflow/tests/unit/test_run.py +8 -3
  96. hpcflow/tests/unit/test_runtime.py +2 -1
  97. hpcflow/tests/unit/test_schema_input.py +23 -15
  98. hpcflow/tests/unit/test_shell.py +3 -2
  99. hpcflow/tests/unit/test_slurm.py +8 -7
  100. hpcflow/tests/unit/test_submission.py +39 -19
  101. hpcflow/tests/unit/test_task.py +352 -247
  102. hpcflow/tests/unit/test_task_schema.py +33 -20
  103. hpcflow/tests/unit/test_utils.py +9 -11
  104. hpcflow/tests/unit/test_value_sequence.py +15 -12
  105. hpcflow/tests/unit/test_workflow.py +114 -83
  106. hpcflow/tests/unit/test_workflow_template.py +0 -1
  107. hpcflow/tests/workflows/test_jobscript.py +2 -1
  108. hpcflow/tests/workflows/test_workflows.py +18 -13
  109. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/METADATA +2 -1
  110. hpcflow_new2-0.2.0a190.dist-info/RECORD +165 -0
  111. hpcflow/sdk/core/parallel.py +0 -21
  112. hpcflow_new2-0.2.0a189.dist-info/RECORD +0 -158
  113. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/LICENSE +0 -0
  114. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/WHEEL +0 -0
  115. {hpcflow_new2-0.2.0a189.dist-info → hpcflow_new2-0.2.0a190.dist-info}/entry_points.txt +0 -0
@@ -4,35 +4,33 @@ Main workflow model.
 
 from __future__ import annotations
 from collections import defaultdict
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 import copy
 from dataclasses import dataclass, field
-from datetime import datetime, timezone
 
 from pathlib import Path
 import random
 import string
 from threading import Thread
 import time
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import overload, cast, TYPE_CHECKING
 from uuid import uuid4
 from warnings import warn
-from fsspec.implementations.local import LocalFileSystem
-from fsspec.implementations.zip import ZipFileSystem
+from fsspec.implementations.local import LocalFileSystem  # type: ignore
+from fsspec.implementations.zip import ZipFileSystem  # type: ignore
 import numpy as np
-from fsspec.core import url_to_fs
+from fsspec.core import url_to_fs  # type: ignore
 import rich.console
 
-from hpcflow.sdk import app
-from hpcflow.sdk.core import (
-    ALL_TEMPLATE_FORMATS,
-    ABORT_EXIT_CODE,
-)
-from hpcflow.sdk.core.actions import EARStatus
-from hpcflow.sdk.core.loop_cache import LoopCache
+from hpcflow.sdk.typing import hydrate
+from hpcflow.sdk.core import ALL_TEMPLATE_FORMATS, ABORT_EXIT_CODE
+from hpcflow.sdk.core.app_aware import AppAware
+from hpcflow.sdk.core.enums import EARStatus
+from hpcflow.sdk.core.loop_cache import LoopCache, LoopIndex
 from hpcflow.sdk.log import TimeIt
-from hpcflow.sdk.persistence import store_cls_from_str, DEFAULT_STORE_FORMAT
-from hpcflow.sdk.persistence.base import TEMPLATE_COMP_TYPES, AnySEAR
+from hpcflow.sdk.persistence import store_cls_from_str
+from hpcflow.sdk.persistence.defaults import DEFAULT_STORE_FORMAT
+from hpcflow.sdk.persistence.base import TEMPLATE_COMP_TYPES
 from hpcflow.sdk.persistence.utils import ask_pw_on_auth_exc, infer_store
 from hpcflow.sdk.submission.jobscript import (
     generate_EAR_resource_map,
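The blanket `typing` import shrinks to the three names needed at runtime; everything used only in annotations moves into the `if TYPE_CHECKING:` block added a couple of hunks below. A minimal standalone sketch of that pattern (the module and function names here are illustrative, not from hpcflow):

from __future__ import annotations  # annotations become strings, evaluated lazily

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static type checkers execute this block; there is no import
    # cost (and no risk of an import cycle) at runtime.
    from collections.abc import Iterable

def total(values: Iterable[int]) -> int:
    # The annotation can name Iterable even though it was never imported
    # at runtime, because the annotation is never evaluated at runtime.
    return sum(values)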
@@ -41,17 +39,18 @@ from hpcflow.sdk.submission.jobscript import (
     merge_jobscripts_across_tasks,
     resolve_jobscript_dependencies,
 )
-from hpcflow.sdk.submission.jobscript_info import JobscriptElementState
+from hpcflow.sdk.submission.enums import JobscriptElementState
 from hpcflow.sdk.submission.schedulers.direct import DirectScheduler
-from hpcflow.sdk.typing import PathLike
 from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
-from .utils import (
-    nth_key,
+from hpcflow.sdk.core.utils import (
     read_JSON_file,
     read_JSON_string,
     read_YAML_str,
     read_YAML_file,
     replace_items,
+    current_timestamp,
+    normalise_timestamp,
+    parse_timestamp,
 )
 from hpcflow.sdk.core.errors import (
     InvalidInputSourceTaskReference,
@@ -62,32 +61,79 @@ from hpcflow.sdk.core.errors import (
     WorkflowSubmissionFailure,
 )
 
+if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator, Mapping, Sequence
+    from contextlib import AbstractContextManager
+    from typing import Any, ClassVar, Literal
+    from typing_extensions import Self, TypeAlias
+    from numpy.typing import NDArray
+    import psutil
+    from rich.status import Status
+    from ..typing import DataIndex, ParamSource, PathLike, TemplateComponents
+    from .actions import ElementActionRun
+    from .element import Element, ElementIteration
+    from .loop import Loop, WorkflowLoop
+    from .object_list import ObjectList, ResourceList, WorkflowLoopList, WorkflowTaskList
+    from .parameters import InputSource, ResourceSpec
+    from .task import Task, WorkflowTask
+    from .types import (
+        AbstractFileSystem,
+        CreationInfo,
+        Pending,
+        Resources,
+        WorkflowTemplateTaskData,
+    )
+    from ..submission.submission import Submission
+    from ..submission.jobscript import (
+        Jobscript,
+        JobScriptDescriptor,
+        JobScriptCreationArguments,
+    )
+    from ..persistence.base import (
+        StoreElement,
+        StoreElementIter,
+        StoreTask,
+        StoreParameter,
+        StoreEAR,
+    )
+    from ..persistence.types import TemplateMeta
 
-class _DummyPersistentWorkflow:
-    """An object to pass to ResourceSpec.make_persistent that pretends to be a
-    Workflow object, so we can pretend to make template-level inputs/resources
-    persistent before the workflow exists."""
-
-    def __init__(self):
-        self._parameters = []
-        self._sources = []
-        self._data_ref = []
-
-    def _add_parameter_data(self, data, source: Dict) -> int:
-        self._parameters.append(data)
-        self._sources.append(source)
-        self._data_ref.append(len(self._data_ref))
-        return self._data_ref[-1]
+#: Convenience alias
+_TemplateComponents: TypeAlias = "dict[str, ObjectList[JSONLike]]"
 
-    def get_parameter_data(self, data_idx):
-        return self._parameters[self._data_ref.index(data_idx)]
 
-    def make_persistent(self, workflow: app.Workflow):
-        for dat_i, source_i in zip(self._parameters, self._sources):
-            workflow._add_parameter_data(dat_i, source_i)
+@dataclass
+class _Pathway:
+    id_: int
+    names: LoopIndex[str, int] = field(default_factory=LoopIndex)
+    iter_ids: list[int] = field(default_factory=list)
+    data_idx: list[DataIndex] = field(default_factory=list)
+
+    def as_tuple(
+        self, *, ret_iter_IDs: bool = False, ret_data_idx: bool = False
+    ) -> tuple:
+        if ret_iter_IDs:
+            if ret_data_idx:
+                return (self.id_, self.names, tuple(self.iter_ids), tuple(self.data_idx))
+            else:
+                return (self.id_, self.names, tuple(self.iter_ids))
+        else:
+            if ret_data_idx:
+                return (self.id_, self.names, tuple(self.data_idx))
+            else:
+                return (self.id_, self.names)
+
+    def __deepcopy__(self, memo) -> Self:
+        return self.__class__(
+            self.id_,
+            self.names,
+            copy.deepcopy(self.iter_ids, memo),
            copy.deepcopy(self.data_idx, memo),
+        )
 
 
 @dataclass
+@hydrate
 class WorkflowTemplate(JSONLike):
     """Class to represent initial parametrisation of a {app_name} workflow, with limited
     validation logic.
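As an illustration of the four tuple shapes `_Pathway.as_tuple` can return (values invented; `names` is whatever `LoopIndex` mapping the pathway carries; assumes the `_Pathway` class from the hunk above is importable):

# Illustration only, not hpcflow test code.
pw = _Pathway(id_=3)        # names, iter_ids and data_idx default to empty
pw.iter_ids.extend([10, 11])

pw.as_tuple()                                        # (3, pw.names)
pw.as_tuple(ret_iter_IDs=True)                       # (3, pw.names, (10, 11))
pw.as_tuple(ret_data_idx=True)                       # (3, pw.names, ())
pw.as_tuple(ret_iter_IDs=True, ret_data_idx=True)    # (3, pw.names, (10, 11), ())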
@@ -122,10 +168,9 @@ class WorkflowTemplate(JSONLike):
         Whether to merge the environemtns into task resources.
     """
 
-    _app_attr = "app"
-    _validation_schema = "workflow_spec_schema.yaml"
+    _validation_schema: ClassVar[str] = "workflow_spec_schema.yaml"
 
-    _child_objects = (
+    _child_objects: ClassVar[tuple[ChildObjectSpec, ...]] = (
         ChildObjectSpec(
             name="tasks",
             class_name="Task",
@@ -148,30 +193,31 @@ class WorkflowTemplate(JSONLike):
     #: A string name for the workflow.
     name: str
     #: Documentation information.
-    doc: Optional[Union[List[str], str]] = field(repr=False, default=None)
+    doc: list[str] | str | None = field(repr=False, default=None)
     #: A list of Task objects to include in the workflow.
-    tasks: Optional[List[app.Task]] = field(default_factory=lambda: [])
+    tasks: list[Task] = field(default_factory=list)
     #: A list of Loop objects to include in the workflow.
-    loops: Optional[List[app.Loop]] = field(default_factory=lambda: [])
+    loops: list[Loop] = field(default_factory=list)
     #: The associated concrete workflow.
-    workflow: Optional[app.Workflow] = None
+    workflow: Workflow | None = None
     #: Template-level resources to apply to all tasks as default values.
-    resources: Optional[Dict[str, Dict]] = None
+    resources: Resources = None
     #: The execution environments to use.
-    environments: Optional[Dict[str, Dict[str, Any]]] = None
+    environments: Mapping[str, Mapping[str, Any]] | None = None
     #: The environment presets to use.
-    env_presets: Optional[Union[str, List[str]]] = None
+    env_presets: str | list[str] | None = None
     #: The file this was derived from.
-    source_file: Optional[str] = field(default=None, compare=False)
+    source_file: str | None = field(default=None, compare=False)
     #: Additional arguments to pass to the persistent data store constructor.
-    store_kwargs: Optional[Dict] = field(default_factory=lambda: {})
+    store_kwargs: dict[str, Any] = field(default_factory=dict)
     #: Whether to merge template-level `resources` into element set resources.
-    merge_resources: Optional[bool] = True
+    merge_resources: bool = True
     #: Whether to merge the environemtns into task resources.
-    merge_envs: Optional[bool] = True
+    merge_envs: bool = True
 
-    def __post_init__(self):
-        self.resources = self.app.ResourceList.normalise(self.resources)
+    def __post_init__(self) -> None:
+        resources = self._app.ResourceList.normalise(self.resources)
+        self.resources = resources
         self._set_parent_refs()
 
         # merge template-level `resources` into task element set resources (this mutates
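Several of the fields above move from `default_factory=lambda: []` to the equivalent, more idiomatic `default_factory=list`. The factory indirection is what matters for any mutable default; a standalone reminder of why (not hpcflow code):

from dataclasses import dataclass, field

@dataclass
class Template:
    # A bare `tasks: list = []` is rejected by @dataclass precisely because
    # one shared list would leak state between instances; the factory gives
    # every instance its own fresh list.
    tasks: list = field(default_factory=list)

a, b = Template(), Template()
a.tasks.append("t1")
assert b.tasks == []  # b is unaffected: each instance got its own list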
@@ -180,7 +226,7 @@ class WorkflowTemplate(JSONLike):
         if self.merge_resources:
             for task in self.tasks:
                 for element_set in task.element_sets:
-                    element_set.resources.merge_other(self.resources)
+                    element_set.resources.merge_other(resources)
             self.merge_resources = False
 
         if self.merge_envs:
@@ -189,8 +235,21 @@ class WorkflowTemplate(JSONLike):
         if self.doc and not isinstance(self.doc, list):
             self.doc = [self.doc]
 
-    def _merge_envs_into_task_resources(self):
+    @property
+    def _resources(self) -> ResourceList:
+        res = self.resources
+        assert isinstance(res, self._app.ResourceList)
+        return res
+
+    def _get_resources_copy(self) -> Iterator[ResourceSpec]:
+        """
+        Get a deep copy of the list of resources.
+        """
+        memo: dict[int, Any] = {}
+        for spec in self._resources:
+            yield copy.deepcopy(spec, memo)
 
+    def _merge_envs_into_task_resources(self) -> None:
         self.merge_envs = False
 
         # disallow both `env_presets` and `environments` specifications:
@@ -204,7 +263,6 @@ class WorkflowTemplate(JSONLike):
         self.env_presets = [self.env_presets] if self.env_presets else []
 
         for task in self.tasks:
-
             # get applicable environments and environment preset names:
             try:
                 schema = task.schema
@@ -217,7 +275,7 @@ class WorkflowTemplate(JSONLike):
             schema_presets = schema.environment_presets
             app_envs = {act.get_environment_name() for act in schema.actions}
             for es in task.element_sets:
-                app_env_specs_i = None
+                app_env_specs_i: Mapping[str, Mapping[str, Any]] | None = None
                 if not es.environments and not es.env_preset:
                     # no task level envs/presets specified, so merge template-level:
                     if self.environments:
@@ -225,33 +283,31 @@ class WorkflowTemplate(JSONLike):
                             k: v for k, v in self.environments.items() if k in app_envs
                         }
                         if app_env_specs_i:
-                            self.app.logger.info(
+                            self._app.logger.info(
                                 f"(task {task.name!r}, element set {es.index}): using "
                                 f"template-level requested `environment` specifiers: "
                                 f"{app_env_specs_i!r}."
                             )
                             es.environments = app_env_specs_i
 
-                    elif self.env_presets:
+                    elif self.env_presets and schema_presets:
                         # take only the first applicable preset:
-                        app_presets_i = [
-                            k for k in self.env_presets if k in schema_presets
-                        ]
-                        if app_presets_i:
-                            app_env_specs_i = schema_presets[app_presets_i[0]]
-                            self.app.logger.info(
-                                f"(task {task.name!r}, element set {es.index}): using "
-                                f"template-level requested {app_presets_i[0]!r} "
-                                f"`env_preset`: {app_env_specs_i!r}."
-                            )
-                            es.env_preset = app_presets_i[0]
+                        for app_preset in self.env_presets:
+                            if app_preset in schema_presets:
+                                es.env_preset = app_preset
+                                app_env_specs_i = schema_presets[app_preset]
+                                self._app.logger.info(
+                                    f"(task {task.name!r}, element set {es.index}): using "
+                                    f"template-level requested {app_preset!r} "
+                                    f"`env_preset`: {app_env_specs_i!r}."
+                                )
+                                break
 
                     else:
                         # no env/preset applicable here (and no env/preset at task level),
                         # so apply a default preset if available:
-                        app_env_specs_i = (schema_presets or {}).get("", None)
-                        if app_env_specs_i:
-                            self.app.logger.info(
+                        if app_env_specs_i := (schema_presets or {}).get("", None):
+                            self._app.logger.info(
                                 f"(task {task.name!r}, element set {es.index}): setting "
                                 f"to default (empty-string named) `env_preset`: "
                                 f"{app_env_specs_i}."
@@ -259,74 +315,68 @@ class WorkflowTemplate(JSONLike):
                             es.env_preset = ""
 
                 if app_env_specs_i:
-                    es.resources.merge_other(
-                        self.app.ResourceList(
-                            [
-                                self.app.ResourceSpec(
-                                    scope="any", environments=app_env_specs_i
-                                )
-                            ]
+                    es.resources.merge_one(
+                        self._app.ResourceSpec(
+                            scope="any", environments=app_env_specs_i
                         )
                     )
 
     @classmethod
     @TimeIt.decorator
-    def _from_data(cls, data: Dict) -> app.WorkflowTemplate:
+    def _from_data(cls, data: dict[str, Any]) -> WorkflowTemplate:
+        task_dat: WorkflowTemplateTaskData
         # use element_sets if not already:
         for task_idx, task_dat in enumerate(data["tasks"]):
             schema = task_dat.pop("schema")
-            schema = schema if isinstance(schema, list) else [schema]
+            schema_list: list = schema if isinstance(schema, list) else [schema]
             if "element_sets" in task_dat:
                 # just update the schema to a list:
-                data["tasks"][task_idx]["schema"] = schema
+                data["tasks"][task_idx]["schema"] = schema_list
             else:
                 # add a single element set, and update the schema to a list:
                 out_labels = task_dat.pop("output_labels", [])
                 data["tasks"][task_idx] = {
-                    "schema": schema,
+                    "schema": schema_list,
                     "element_sets": [task_dat],
                     "output_labels": out_labels,
                 }
 
         # extract out any template components:
-        tcs = data.pop("template_components", {})
-        params_dat = tcs.pop("parameters", [])
-        if params_dat:
-            parameters = cls.app.ParametersList.from_json_like(
-                params_dat, shared_data=cls.app.template_components
+        # TODO: TypedDict for data
+        tcs: dict[str, list] = data.pop("template_components", {})
+        if params_dat := tcs.pop("parameters", []):
+            parameters = cls._app.ParametersList.from_json_like(
+                params_dat, shared_data=cls._app._shared_data
             )
-            cls.app.parameters.add_objects(parameters, skip_duplicates=True)
+            cls._app.parameters.add_objects(parameters, skip_duplicates=True)
 
-        cmd_files_dat = tcs.pop("command_files", [])
-        if cmd_files_dat:
-            cmd_files = cls.app.CommandFilesList.from_json_like(
-                cmd_files_dat, shared_data=cls.app.template_components
+        if cmd_files_dat := tcs.pop("command_files", []):
+            cmd_files = cls._app.CommandFilesList.from_json_like(
+                cmd_files_dat, shared_data=cls._app._shared_data
            )
-            cls.app.command_files.add_objects(cmd_files, skip_duplicates=True)
+            cls._app.command_files.add_objects(cmd_files, skip_duplicates=True)
 
-        envs_dat = tcs.pop("environments", [])
-        if envs_dat:
-            envs = cls.app.EnvironmentsList.from_json_like(
-                envs_dat, shared_data=cls.app.template_components
+        if envs_dat := tcs.pop("environments", []):
+            envs = cls._app.EnvironmentsList.from_json_like(
+                envs_dat, shared_data=cls._app._shared_data
            )
-            cls.app.envs.add_objects(envs, skip_duplicates=True)
+            cls._app.envs.add_objects(envs, skip_duplicates=True)
 
-        ts_dat = tcs.pop("task_schemas", [])
-        if ts_dat:
-            task_schemas = cls.app.TaskSchemasList.from_json_like(
-                ts_dat, shared_data=cls.app.template_components
+        if ts_dat := tcs.pop("task_schemas", []):
+            task_schemas = cls._app.TaskSchemasList.from_json_like(
+                ts_dat, shared_data=cls._app._shared_data
            )
-            cls.app.task_schemas.add_objects(task_schemas, skip_duplicates=True)
+            cls._app.task_schemas.add_objects(task_schemas, skip_duplicates=True)
 
-        return cls.from_json_like(data, shared_data=cls.app.template_components)
+        return cls.from_json_like(data, shared_data=cls._app._shared_data)
 
     @classmethod
     @TimeIt.decorator
     def from_YAML_string(
         cls,
         string: str,
-        variables: Optional[Dict[str, str]] = None,
-    ) -> app.WorkflowTemplate:
+        variables: dict[str, str] | None = None,
+    ) -> WorkflowTemplate:
         """Load from a YAML string.
 
         Parameters
@@ -339,16 +389,16 @@ class WorkflowTemplate(JSONLike):
         return cls._from_data(read_YAML_str(string, variables=variables))
 
     @classmethod
-    def _check_name(cls, data: Dict, path: PathLike) -> str:
+    def _check_name(cls, data: dict[str, Any], path: PathLike) -> None:
         """Check the workflow template data has a "name" key. If not, add a "name" key,
         using the file path stem.
 
        Note: this method mutates `data`.

        """
-        if "name" not in data:
+        if "name" not in data and path is not None:
             name = Path(path).stem
-            cls.app.logger.info(
+            cls._app.logger.info(
                 f"using file name stem ({name!r}) as the workflow template name."
             )
             data["name"] = name
@@ -358,8 +408,8 @@ class WorkflowTemplate(JSONLike):
     def from_YAML_file(
         cls,
         path: PathLike,
-        variables: Optional[Dict[str, str]] = None,
-    ) -> app.WorkflowTemplate:
+        variables: dict[str, str] | None = None,
+    ) -> WorkflowTemplate:
         """Load from a YAML file.
 
         Parameters
@@ -370,7 +420,7 @@ class WorkflowTemplate(JSONLike):
             String variables to substitute in the file given by `path`.
 
         """
-        cls.app.logger.debug("parsing workflow template from a YAML file")
+        cls._app.logger.debug("parsing workflow template from a YAML file")
         data = read_YAML_file(path, variables=variables)
         cls._check_name(data, path)
         data["source_file"] = str(path)
@@ -381,8 +431,8 @@ class WorkflowTemplate(JSONLike):
     def from_JSON_string(
         cls,
         string: str,
-        variables: Optional[Dict[str, str]] = None,
-    ) -> app.WorkflowTemplate:
+        variables: dict[str, str] | None = None,
+    ) -> WorkflowTemplate:
         """Load from a JSON string.
 
         Parameters
@@ -399,8 +449,8 @@ class WorkflowTemplate(JSONLike):
     def from_JSON_file(
         cls,
         path: PathLike,
-        variables: Optional[Dict[str, str]] = None,
-    ) -> app.WorkflowTemplate:
+        variables: dict[str, str] | None = None,
+    ) -> WorkflowTemplate:
         """Load from a JSON file.
 
         Parameters
@@ -410,7 +460,7 @@ class WorkflowTemplate(JSONLike):
         variables
             String variables to substitute in the file given by `path`.
         """
-        cls.app.logger.debug("parsing workflow template from a JSON file")
+        cls._app.logger.debug("parsing workflow template from a JSON file")
         data = read_JSON_file(path, variables=variables)
         cls._check_name(data, path)
         data["source_file"] = str(path)
@@ -421,9 +471,9 @@ class WorkflowTemplate(JSONLike):
     def from_file(
         cls,
         path: PathLike,
-        template_format: Optional[str] = None,
-        variables: Optional[Dict[str, str]] = None,
-    ) -> app.WorkflowTemplate:
+        template_format: Literal["yaml", "json"] | None = None,
+        variables: dict[str, str] | None = None,
+    ) -> WorkflowTemplate:
         """Load from either a YAML or JSON file, depending on the file extension.
 
         Parameters
@@ -437,20 +487,21 @@ class WorkflowTemplate(JSONLike):
             String variables to substitute in the file given by `path`.
 
         """
-        path = Path(path)
+        path_ = Path(path or ".")
         fmt = template_format.lower() if template_format else None
-        if fmt == "yaml" or path.suffix in (".yaml", ".yml"):
-            return cls.from_YAML_file(path, variables=variables)
-        elif fmt == "json" or path.suffix in (".json", ".jsonc"):
-            return cls.from_JSON_file(path, variables=variables)
+        if fmt == "yaml" or path_.suffix in (".yaml", ".yml"):
+            return cls.from_YAML_file(path_, variables=variables)
+        elif fmt == "json" or path_.suffix in (".json", ".jsonc"):
+            return cls.from_JSON_file(path_, variables=variables)
         else:
             raise ValueError(
-                f"Unknown workflow template file extension {path.suffix!r}. Supported "
+                f"Unknown workflow template file extension {path_.suffix!r}. Supported "
                 f"template formats are {ALL_TEMPLATE_FORMATS!r}."
             )
 
-    def _add_empty_task(self, task: app.Task, new_index: int, insert_ID: int) -> None:
+    def _add_empty_task(self, task: Task, new_index: int, insert_ID: int) -> None:
         """Called by `Workflow._add_empty_task`."""
+        assert self.workflow
         new_task_name = self.workflow._get_new_task_unique_name(task, new_index)
 
         task._insert_ID = insert_ID
@@ -460,28 +511,26 @@ class WorkflowTemplate(JSONLike):
         task.workflow_template = self
         self.tasks.insert(new_index, task)
 
-    def _add_empty_loop(self, loop: app.Loop) -> None:
+    def _add_empty_loop(self, loop: Loop) -> None:
         """Called by `Workflow._add_empty_loop`."""
 
+        assert self.workflow
         if not loop.name:
-            existing = [i.name for i in self.loops]
+            existing = {loop.name for loop in self.loops}
             new_idx = len(self.loops)
-            name = f"loop_{new_idx}"
-            while name in existing:
+            while (name := f"loop_{new_idx}") in existing:
                 new_idx += 1
-                name = f"loop_{new_idx}"
             loop._name = name
         elif loop.name in self.workflow.loops.list_attrs():
-            raise LoopAlreadyExistsError(
-                f"A loop with the name {loop.name!r} already exists in the workflow: "
-                f"{getattr(self.workflow.loops, loop.name)!r}."
-            )
+            raise LoopAlreadyExistsError(loop.name, self.workflow.loops)
 
         loop._workflow_template = self
         self.loops.append(loop)
 
 
-def resolve_fsspec(path: PathLike, **kwargs) -> Tuple[Any, str, str]:
+def resolve_fsspec(
+    path: PathLike, **kwargs
+) -> tuple[AbstractFileSystem, str, str | None]:
     """
     Decide how to handle a particular virtual path.
 
@@ -492,30 +541,37 @@ def resolve_fsspec(path: PathLike, **kwargs) -> Tuple[Any, str, str]:
 
     """
 
-    path = str(path)
-    if path.endswith(".zip"):
+    path_s = str(path)
+    fs: AbstractFileSystem
+    if path_s.endswith(".zip"):
         # `url_to_fs` does not seem to work for zip combos e.g. `zip::ssh://`, so we
         # construct a `ZipFileSystem` ourselves and assume it is signified only by the
         # file extension:
         fs, pw = ask_pw_on_auth_exc(
             ZipFileSystem,
-            fo=path,
+            fo=path_s,
             mode="r",
             target_options=kwargs or {},
             add_pw_to="target_options",
         )
-        path = ""
+        path_s = ""
 
     else:
-        (fs, path), pw = ask_pw_on_auth_exc(url_to_fs, str(path), **kwargs)
-        path = str(Path(path).as_posix())
+        (fs, path_s), pw = ask_pw_on_auth_exc(url_to_fs, path_s, **kwargs)
+        path_s = str(Path(path_s).as_posix())
         if isinstance(fs, LocalFileSystem):
-            path = str(Path(path).resolve())
+            path_s = str(Path(path_s).resolve())
+
+    return fs, path_s, pw
 
-    return fs, path, pw
 
+@dataclass(frozen=True)
+class _IterationData:
+    id_: int
+    idx: int
 
-class Workflow:
+
+class Workflow(AppAware):
     """
     A concrete workflow.
 
@@ -533,55 +589,56 @@ class Workflow:
533
589
  For compatibility during pre-stable development phase.
534
590
  """
535
591
 
536
- _app_attr = "app"
537
- _default_ts_fmt = r"%Y-%m-%d %H:%M:%S.%f"
538
- _default_ts_name_fmt = r"%Y-%m-%d_%H%M%S"
539
- _input_files_dir_name = "input_files"
540
- _exec_dir_name = "execute"
592
+ _default_ts_fmt: ClassVar[str] = r"%Y-%m-%d %H:%M:%S.%f"
593
+ _default_ts_name_fmt: ClassVar[str] = r"%Y-%m-%d_%H%M%S"
594
+ _input_files_dir_name: ClassVar[str] = "input_files"
595
+ _exec_dir_name: ClassVar[str] = "execute"
541
596
 
542
597
  def __init__(
543
598
  self,
544
- workflow_ref: Union[str, Path, int],
545
- store_fmt: Optional[str] = None,
546
- fs_kwargs: Optional[Dict] = None,
599
+ workflow_ref: str | Path | int,
600
+ store_fmt: str | None = None,
601
+ fs_kwargs: dict[str, Any] | None = None,
547
602
  **kwargs,
548
603
  ):
549
604
  if isinstance(workflow_ref, int):
550
- path = self.app._get_workflow_path_from_local_ID(workflow_ref)
605
+ path = self._app._get_workflow_path_from_local_ID(workflow_ref)
606
+ elif isinstance(workflow_ref, str):
607
+ path = Path(workflow_ref)
551
608
  else:
552
609
  path = workflow_ref
553
610
 
554
- self.app.logger.info(f"loading workflow from path: {path}")
611
+ self._app.logger.info(f"loading workflow from path: {path}")
555
612
  fs_path = str(path)
556
- fs, path, _ = resolve_fsspec(fs_path or "", **(fs_kwargs or {}))
613
+ fs, path_s, _ = resolve_fsspec(path, **(fs_kwargs or {}))
557
614
  store_fmt = store_fmt or infer_store(fs_path, fs)
558
615
  store_cls = store_cls_from_str(store_fmt)
559
616
 
560
- self.path = path
617
+ self.path = path_s
561
618
 
562
619
  # assigned on first access:
563
- self._ts_fmt = None
564
- self._ts_name_fmt = None
565
- self._creation_info = None
566
- self._name = None
567
- self._template = None
568
- self._template_components = None
569
- self._tasks = None
570
- self._loops = None
571
- self._submissions = None
572
-
573
- self._store = store_cls(self.app, self, self.path, fs)
620
+ self._ts_fmt: str | None = None
621
+ self._ts_name_fmt: str | None = None
622
+ self._creation_info: CreationInfo | None = None
623
+ self._name: str | None = None
624
+ self._template: WorkflowTemplate | None = None
625
+ self._template_components: TemplateComponents | None = None
626
+ self._tasks: WorkflowTaskList | None = None
627
+ self._loops: WorkflowLoopList | None = None
628
+ self._submissions: list[Submission] | None = None
629
+
630
+ self._store = store_cls(self._app, self, self.path, fs)
574
631
  self._in_batch_mode = False # flag to track when processing batch updates
575
632
 
576
633
  # store indices of updates during batch update, so we can revert on failure:
577
634
  self._pending = self._get_empty_pending()
578
635
 
579
- def reload(self):
636
+ def reload(self) -> Self:
580
637
  """Reload the workflow from disk."""
581
638
  return self.__class__(self.url)
582
639
 
583
640
  @property
584
- def name(self):
641
+ def name(self) -> str:
585
642
  """
586
643
  The name of the workflow.
587
644
 
@@ -593,43 +650,36 @@ class Workflow:
593
650
  return self._name
594
651
 
595
652
  @property
596
- def url(self):
653
+ def url(self) -> str:
597
654
  """An fsspec URL for this workflow."""
598
- if self._store.fs.protocol == "zip":
599
- return self._store.fs.of.path
600
- elif self._store.fs.protocol == "file":
601
- return self.path
602
- else:
603
- raise NotImplementedError("Only (local) zip and local URLs provided for now.")
655
+ if self._store.fs:
656
+ if self._store.fs.protocol == "zip":
657
+ return self._store.fs.of.path
658
+ elif self._store.fs.protocol == "file":
659
+ return self.path
660
+ raise NotImplementedError("Only (local) zip and local URLs provided for now.")
604
661
 
605
662
  @property
606
- def store_format(self):
663
+ def store_format(self) -> str:
607
664
  """
608
665
  The format of the workflow's persistent store.
609
666
  """
610
667
  return self._store._name
611
668
 
612
- @property
613
- def num_tasks(self) -> int:
614
- """
615
- The number of tasks in the workflow.
616
- """
617
- return len(self.tasks)
618
-
619
669
  @classmethod
620
670
  @TimeIt.decorator
621
671
  def from_template(
622
672
  cls,
623
673
  template: WorkflowTemplate,
624
- path: Optional[PathLike] = None,
625
- name: Optional[str] = None,
626
- overwrite: Optional[bool] = False,
627
- store: Optional[str] = DEFAULT_STORE_FORMAT,
628
- ts_fmt: Optional[str] = None,
629
- ts_name_fmt: Optional[str] = None,
630
- store_kwargs: Optional[Dict] = None,
631
- status: Optional[Any] = None,
632
- ) -> app.Workflow:
674
+ path: PathLike = None,
675
+ name: str | None = None,
676
+ overwrite: bool = False,
677
+ store: str = DEFAULT_STORE_FORMAT,
678
+ ts_fmt: str | None = None,
679
+ ts_name_fmt: str | None = None,
680
+ store_kwargs: dict[str, Any] | None = None,
681
+ status: Status | None = None,
682
+ ) -> Workflow:
633
683
  """Generate from a `WorkflowTemplate` object.
634
684
 
635
685
  Parameters
@@ -671,30 +721,28 @@ class Workflow:
671
721
  ts_name_fmt=ts_name_fmt,
672
722
  store_kwargs=store_kwargs,
673
723
  )
674
- with wk._store.cached_load():
675
- with wk.batch_update(is_workflow_creation=True):
676
- with wk._store.cache_ctx():
677
- for idx, task in enumerate(template.tasks):
678
- if status:
679
- status.update(
680
- f"Adding task {idx + 1}/{len(template.tasks)} "
681
- f"({task.name!r})..."
682
- )
683
- wk._add_task(task)
724
+ with wk._store.cached_load(), wk.batch_update(
725
+ is_workflow_creation=True
726
+ ), wk._store.cache_ctx():
727
+ for idx, task in enumerate(template.tasks):
728
+ if status:
729
+ status.update(
730
+ f"Adding task {idx + 1}/{len(template.tasks)} "
731
+ f"({task.name!r})..."
732
+ )
733
+ wk._add_task(task)
734
+ if status:
735
+ status.update(f"Preparing to add {len(template.loops)} loops...")
736
+ if template.loops:
737
+ # TODO: if loop with non-initialisable actions, will fail
738
+ cache = LoopCache.build(workflow=wk, loops=template.loops)
739
+ for idx, loop in enumerate(template.loops):
684
740
  if status:
685
741
  status.update(
686
- f"Preparing to add {len(template.loops)} loops..."
742
+ f"Adding loop {idx + 1}/"
743
+ f"{len(template.loops)} ({loop.name!r})"
687
744
  )
688
- if template.loops:
689
- # TODO: if loop with non-initialisable actions, will fail
690
- cache = LoopCache.build(workflow=wk, loops=template.loops)
691
- for idx, loop in enumerate(template.loops):
692
- if status:
693
- status.update(
694
- f"Adding loop {idx + 1}/"
695
- f"{len(template.loops)} ({loop.name!r})"
696
- )
697
- wk._add_loop(loop, cache=cache, status=status)
745
+ wk._add_loop(loop, cache=cache, status=status)
698
746
  except Exception:
699
747
  if status:
700
748
  status.stop()
@@ -706,15 +754,15 @@ class Workflow:
706
754
  def from_YAML_file(
707
755
  cls,
708
756
  YAML_path: PathLike,
709
- path: Optional[str] = None,
710
- name: Optional[str] = None,
711
- overwrite: Optional[bool] = False,
712
- store: Optional[str] = DEFAULT_STORE_FORMAT,
713
- ts_fmt: Optional[str] = None,
714
- ts_name_fmt: Optional[str] = None,
715
- store_kwargs: Optional[Dict] = None,
716
- variables: Optional[Dict[str, str]] = None,
717
- ) -> app.Workflow:
757
+ path: PathLike = None,
758
+ name: str | None = None,
759
+ overwrite: bool = False,
760
+ store: str = DEFAULT_STORE_FORMAT,
761
+ ts_fmt: str | None = None,
762
+ ts_name_fmt: str | None = None,
763
+ store_kwargs: dict[str, Any] | None = None,
764
+ variables: dict[str, str] | None = None,
765
+ ) -> Workflow:
718
766
  """Generate from a YAML file.
719
767
 
720
768
  Parameters
@@ -745,7 +793,7 @@ class Workflow:
745
793
  variables:
746
794
  String variables to substitute in the file given by `YAML_path`.
747
795
  """
748
- template = cls.app.WorkflowTemplate.from_YAML_file(
796
+ template = cls._app.WorkflowTemplate.from_YAML_file(
749
797
  path=YAML_path,
750
798
  variables=variables,
751
799
  )
@@ -763,16 +811,16 @@ class Workflow:
763
811
  @classmethod
764
812
  def from_YAML_string(
765
813
  cls,
766
- YAML_str: PathLike,
767
- path: Optional[str] = None,
768
- name: Optional[str] = None,
769
- overwrite: Optional[bool] = False,
770
- store: Optional[str] = DEFAULT_STORE_FORMAT,
771
- ts_fmt: Optional[str] = None,
772
- ts_name_fmt: Optional[str] = None,
773
- store_kwargs: Optional[Dict] = None,
774
- variables: Optional[Dict[str, str]] = None,
775
- ) -> app.Workflow:
814
+ YAML_str: str,
815
+ path: PathLike = None,
816
+ name: str | None = None,
817
+ overwrite: bool = False,
818
+ store: str = DEFAULT_STORE_FORMAT,
819
+ ts_fmt: str | None = None,
820
+ ts_name_fmt: str | None = None,
821
+ store_kwargs: dict[str, Any] | None = None,
822
+ variables: dict[str, str] | None = None,
823
+ ) -> Workflow:
776
824
  """Generate from a YAML string.
777
825
 
778
826
  Parameters
@@ -803,7 +851,7 @@ class Workflow:
803
851
  variables:
804
852
  String variables to substitute in the string `YAML_str`.
805
853
  """
806
- template = cls.app.WorkflowTemplate.from_YAML_string(
854
+ template = cls._app.WorkflowTemplate.from_YAML_string(
807
855
  string=YAML_str,
808
856
  variables=variables,
809
857
  )
@@ -822,16 +870,16 @@ class Workflow:
822
870
  def from_JSON_file(
823
871
  cls,
824
872
  JSON_path: PathLike,
825
- path: Optional[str] = None,
826
- name: Optional[str] = None,
827
- overwrite: Optional[bool] = False,
828
- store: Optional[str] = DEFAULT_STORE_FORMAT,
829
- ts_fmt: Optional[str] = None,
830
- ts_name_fmt: Optional[str] = None,
831
- store_kwargs: Optional[Dict] = None,
832
- variables: Optional[Dict[str, str]] = None,
833
- status: Optional[Any] = None,
834
- ) -> app.Workflow:
873
+ path: PathLike = None,
874
+ name: str | None = None,
875
+ overwrite: bool = False,
876
+ store: str = DEFAULT_STORE_FORMAT,
877
+ ts_fmt: str | None = None,
878
+ ts_name_fmt: str | None = None,
879
+ store_kwargs: dict[str, Any] | None = None,
880
+ variables: dict[str, str] | None = None,
881
+ status: Status | None = None,
882
+ ) -> Workflow:
835
883
  """Generate from a JSON file.
836
884
 
837
885
  Parameters
@@ -862,7 +910,7 @@ class Workflow:
862
910
  variables:
863
911
  String variables to substitute in the file given by `JSON_path`.
864
912
  """
865
- template = cls.app.WorkflowTemplate.from_JSON_file(
913
+ template = cls._app.WorkflowTemplate.from_JSON_file(
866
914
  path=JSON_path,
867
915
  variables=variables,
868
916
  )
@@ -881,17 +929,17 @@ class Workflow:
881
929
  @classmethod
882
930
  def from_JSON_string(
883
931
  cls,
884
- JSON_str: PathLike,
885
- path: Optional[str] = None,
886
- name: Optional[str] = None,
887
- overwrite: Optional[bool] = False,
888
- store: Optional[str] = DEFAULT_STORE_FORMAT,
889
- ts_fmt: Optional[str] = None,
890
- ts_name_fmt: Optional[str] = None,
891
- store_kwargs: Optional[Dict] = None,
892
- variables: Optional[Dict[str, str]] = None,
893
- status: Optional[Any] = None,
894
- ) -> app.Workflow:
932
+ JSON_str: str,
933
+ path: PathLike = None,
934
+ name: str | None = None,
935
+ overwrite: bool = False,
936
+ store: str = DEFAULT_STORE_FORMAT,
937
+ ts_fmt: str | None = None,
938
+ ts_name_fmt: str | None = None,
939
+ store_kwargs: dict[str, Any] | None = None,
940
+ variables: dict[str, str] | None = None,
941
+ status: Status | None = None,
942
+ ) -> Workflow:
895
943
  """Generate from a JSON string.
896
944
 
897
945
  Parameters
@@ -922,7 +970,7 @@ class Workflow:
922
970
  variables:
923
971
  String variables to substitute in the string `JSON_str`.
924
972
  """
925
- template = cls.app.WorkflowTemplate.from_JSON_string(
973
+ template = cls._app.WorkflowTemplate.from_JSON_string(
926
974
  string=JSON_str,
927
975
  variables=variables,
928
976
  )
@@ -943,17 +991,17 @@ class Workflow:
943
991
  def from_file(
944
992
  cls,
945
993
  template_path: PathLike,
946
- template_format: Optional[str] = None,
947
- path: Optional[str] = None,
948
- name: Optional[str] = None,
949
- overwrite: Optional[bool] = False,
950
- store: Optional[str] = DEFAULT_STORE_FORMAT,
951
- ts_fmt: Optional[str] = None,
952
- ts_name_fmt: Optional[str] = None,
953
- store_kwargs: Optional[Dict] = None,
954
- variables: Optional[Dict[str, str]] = None,
955
- status: Optional[Any] = None,
956
- ) -> app.Workflow:
994
+ template_format: Literal["json", "yaml"] | None = None,
995
+ path: str | None = None,
996
+ name: str | None = None,
997
+ overwrite: bool = False,
998
+ store: str = DEFAULT_STORE_FORMAT,
999
+ ts_fmt: str | None = None,
1000
+ ts_name_fmt: str | None = None,
1001
+ store_kwargs: dict[str, Any] | None = None,
1002
+ variables: dict[str, str] | None = None,
1003
+ status: Status | None = None,
1004
+ ) -> Workflow:
957
1005
  """Generate from either a YAML or JSON file, depending on the file extension.
958
1006
 
959
1007
  Parameters
@@ -989,7 +1037,7 @@ class Workflow:
989
1037
  String variables to substitute in the file given by `template_path`.
990
1038
  """
991
1039
  try:
992
- template = cls.app.WorkflowTemplate.from_file(
1040
+ template = cls._app.WorkflowTemplate.from_file(
993
1041
  template_path,
994
1042
  template_format,
995
1043
  variables=variables,
@@ -1015,17 +1063,17 @@ class Workflow:
1015
1063
  def from_template_data(
1016
1064
  cls,
1017
1065
  template_name: str,
1018
- tasks: Optional[List[app.Task]] = None,
1019
- loops: Optional[List[app.Loop]] = None,
1020
- resources: Optional[Dict[str, Dict]] = None,
1021
- path: Optional[PathLike] = None,
1022
- workflow_name: Optional[str] = None,
1023
- overwrite: Optional[bool] = False,
1024
- store: Optional[str] = DEFAULT_STORE_FORMAT,
1025
- ts_fmt: Optional[str] = None,
1026
- ts_name_fmt: Optional[str] = None,
1027
- store_kwargs: Optional[Dict] = None,
1028
- ) -> app.Workflow:
1066
+ tasks: list[Task] | None = None,
1067
+ loops: list[Loop] | None = None,
1068
+ resources: Resources = None,
1069
+ path: PathLike | None = None,
1070
+ workflow_name: str | None = None,
1071
+ overwrite: bool = False,
1072
+ store: str = DEFAULT_STORE_FORMAT,
1073
+ ts_fmt: str | None = None,
1074
+ ts_name_fmt: str | None = None,
1075
+ store_kwargs: dict[str, Any] | None = None,
1076
+ ) -> Workflow:
1029
1077
  """Generate from the data associated with a WorkflowTemplate object.
1030
1078
 
1031
1079
  Parameters
@@ -1063,7 +1111,7 @@ class Workflow:
1063
1111
  store_kwargs:
1064
1112
  Keyword arguments to pass to the store's `write_empty_workflow` method.
1065
1113
  """
1066
- template = cls.app.WorkflowTemplate(
1114
+ template = cls._app.WorkflowTemplate(
1067
1115
  template_name,
1068
1116
  tasks=tasks or [],
1069
1117
  loops=loops or [],
@@ -1083,9 +1131,9 @@ class Workflow:
1083
1131
  @TimeIt.decorator
1084
1132
  def _add_empty_task(
1085
1133
  self,
1086
- task: app.Task,
1087
- new_index: Optional[int] = None,
1088
- ) -> app.WorkflowTask:
1134
+ task: Task,
1135
+ new_index: int | None = None,
1136
+ ) -> WorkflowTask:
1089
1137
  if new_index is None:
1090
1138
  new_index = self.num_tasks
1091
1139
 
@@ -1099,70 +1147,73 @@ class Workflow:
1099
1147
 
1100
1148
  # create and insert a new WorkflowTask:
1101
1149
  self.tasks.add_object(
1102
- self.app.WorkflowTask.new_empty_task(self, task_c, new_index),
1150
+ self._app.WorkflowTask.new_empty_task(self, task_c, new_index),
1103
1151
  index=new_index,
1104
1152
  )
1105
1153
 
1106
1154
  # update persistent store:
1107
1155
  task_js, temp_comps_js = task_c.to_json_like()
1156
+ assert temp_comps_js is not None
1108
1157
  self._store.add_template_components(temp_comps_js)
1109
- self._store.add_task(new_index, task_js)
1158
+ self._store.add_task(new_index, cast("Mapping", task_js))
1110
1159
 
1111
1160
  # update in-memory workflow template components:
1112
- temp_comps = self.app.template_components_from_json_like(temp_comps_js)
1161
+ temp_comps = cast(
1162
+ "_TemplateComponents",
1163
+ self._app.template_components_from_json_like(temp_comps_js),
1164
+ )
1113
1165
  for comp_type, comps in temp_comps.items():
1166
+ ol = self.__template_components[comp_type]
1114
1167
  for comp in comps:
1115
1168
  comp._set_hash()
1116
- if comp not in self.template_components[comp_type]:
1117
- idx = self.template_components[comp_type].add_object(comp)
1118
- self._pending["template_components"][comp_type].append(idx)
1169
+ if comp not in ol:
1170
+ self._pending["template_components"][comp_type].append(
1171
+ ol.add_object(comp, skip_duplicates=False)
1172
+ )
1119
1173
 
1120
1174
  self._pending["tasks"].append(new_index)
1121
1175
  return self.tasks[new_index]
1122
1176
 
1123
1177
  @TimeIt.decorator
1124
- def _add_task(self, task: app.Task, new_index: Optional[int] = None) -> None:
1178
+ def _add_task(self, task: Task, new_index: int | None = None) -> None:
1125
1179
  new_wk_task = self._add_empty_task(task=task, new_index=new_index)
1126
- new_wk_task._add_elements(element_sets=task.element_sets)
1180
+ new_wk_task._add_elements(element_sets=task.element_sets, propagate_to={})
1127
1181
 
1128
- def add_task(self, task: app.Task, new_index: Optional[int] = None) -> None:
1182
+ def add_task(self, task: Task, new_index: int | None = None) -> None:
1129
1183
  """
1130
1184
  Add a task to this workflow.
1131
1185
  """
1132
- with self._store.cached_load():
1133
- with self.batch_update():
1134
- self._add_task(task, new_index=new_index)
1186
+ with self._store.cached_load(), self.batch_update():
1187
+ self._add_task(task, new_index=new_index)
1135
1188
 
1136
- def add_task_after(self, new_task: app.Task, task_ref: app.Task = None) -> None:
1189
+ def add_task_after(self, new_task: Task, task_ref: Task | None = None) -> None:
1137
1190
  """Add a new task after the specified task.
1138
1191
 
1139
1192
  Parameters
1140
1193
  ----------
1141
1194
  task_ref
1142
1195
  If not given, the new task will be added at the end of the workflow.
1143
-
1144
1196
  """
1145
- new_index = task_ref.index + 1 if task_ref else None
1197
+ new_index = (
1198
+ task_ref.index + 1 if task_ref and task_ref.index is not None else None
1199
+ )
1146
1200
  self.add_task(new_task, new_index)
1147
1201
  # TODO: add new downstream elements?
1148
1202
 
1149
- def add_task_before(self, new_task: app.Task, task_ref: app.Task = None) -> None:
1203
+ def add_task_before(self, new_task: Task, task_ref: Task | None = None) -> None:
1150
1204
  """Add a new task before the specified task.
1151
1205
 
1152
1206
  Parameters
1153
1207
  ----------
1154
1208
  task_ref
1155
1209
  If not given, the new task will be added at the beginning of the workflow.
1156
-
1157
1210
  """
1158
1211
  new_index = task_ref.index if task_ref else 0
1159
1212
  self.add_task(new_task, new_index)
1160
1213
  # TODO: add new downstream elements?
1161
1214
 
1162
1215
  @TimeIt.decorator
1163
- def _add_empty_loop(
1164
- self, loop: app.Loop, cache: LoopCache
1165
- ) -> Tuple[app.WorkflowLoop, List[app.ElementIteration]]:
1216
+ def _add_empty_loop(self, loop: Loop, cache: LoopCache) -> WorkflowLoop:
1166
1217
  """Add a new loop (zeroth iterations only) to the workflow."""
1167
1218
 
1168
1219
  new_index = self.num_loops
@@ -1178,7 +1229,7 @@ class Workflow:
1178
1229
  iter_loop_idx = cache.get_iter_loop_indices(iter_IDs)
1179
1230
 
1180
1231
  # create and insert a new WorkflowLoop:
1181
- new_loop = self.app.WorkflowLoop.new_empty_loop(
1232
+ new_loop = self._app.WorkflowLoop.new_empty_loop(
1182
1233
  index=new_index,
1183
1234
  workflow=self,
1184
1235
  template=loop_c,
@@ -1195,7 +1246,7 @@ class Workflow:
1195
1246
 
1196
1247
  # update persistent store:
1197
1248
  self._store.add_loop(
1198
- loop_template=loop_js,
1249
+ loop_template=cast("Mapping", loop_js),
1199
1250
  iterable_parameters=wk_loop.iterable_parameters,
1200
1251
  parents=wk_loop.parents,
1201
1252
  num_added_iterations=wk_loop.num_added_iterations,
@@ -1205,17 +1256,16 @@ class Workflow:
1205
1256
  self._pending["loops"].append(new_index)
1206
1257
 
1207
1258
  # update cache loop indices:
1208
- cache.update_loop_indices(new_loop_name=loop_c.name, iter_IDs=iter_IDs)
1259
+ cache.update_loop_indices(new_loop_name=loop_c.name or "", iter_IDs=iter_IDs)
1209
1260
 
1210
1261
  return wk_loop
1211
1262
 
1212
1263
  @TimeIt.decorator
1213
1264
  def _add_loop(
1214
- self, loop: app.Loop, cache: Optional[Dict] = None, status: Optional[Any] = None
1265
+ self, loop: Loop, cache: LoopCache | None = None, status: Status | None = None
1215
1266
  ) -> None:
1216
- if not cache:
1217
- cache = LoopCache.build(workflow=self, loops=[loop])
1218
- new_wk_loop = self._add_empty_loop(loop, cache)
1267
+ cache_ = cache or LoopCache.build(workflow=self, loops=[loop])
1268
+ new_wk_loop = self._add_empty_loop(loop, cache_)
1219
1269
  if loop.num_iterations is not None:
1220
1270
  # fixed number of iterations, so add remaining N > 0 iterations:
1221
1271
  if status:
@@ -1225,38 +1275,36 @@ class Workflow:
1225
1275
  status.update(
1226
1276
  f"{status_prev}: iteration {iter_idx + 2}/{loop.num_iterations}."
1227
1277
  )
1228
- new_wk_loop.add_iteration(cache=cache)
1278
+ new_wk_loop.add_iteration(cache=cache_)
1229
1279
 
1230
- def add_loop(self, loop: app.Loop) -> None:
1280
+ def add_loop(self, loop: Loop) -> None:
1231
1281
  """Add a loop to a subset of workflow tasks."""
1232
- with self._store.cached_load():
1233
- with self.batch_update():
1234
- self._add_loop(loop)
1282
+ with self._store.cached_load(), self.batch_update():
1283
+ self._add_loop(loop)
1235
1284
 
1236
1285
  @property
1237
- def creation_info(self):
1286
+ def creation_info(self) -> CreationInfo:
1238
1287
  """
1239
1288
  The creation descriptor for the workflow.
1240
1289
  """
1241
1290
  if not self._creation_info:
1242
1291
  info = self._store.get_creation_info()
1243
- info["create_time"] = (
1244
- datetime.strptime(info["create_time"], self.ts_fmt)
1245
- .replace(tzinfo=timezone.utc)
1246
- .astimezone()
1247
- )
1248
- self._creation_info = info
1292
+ self._creation_info = {
1293
+ "app_info": info["app_info"],
1294
+ "create_time": parse_timestamp(info["create_time"], self.ts_fmt),
1295
+ "id": info["id"],
1296
+ }
1249
1297
  return self._creation_info
1250
1298
 
1251
1299
  @property
1252
- def id_(self):
1300
+ def id_(self) -> str:
1253
1301
  """
1254
1302
         The ID of this workflow.
         """
         return self.creation_info["id"]

     @property
-    def ts_fmt(self):
+    def ts_fmt(self) -> str:
         """
         The timestamp format.
         """
@@ -1265,7 +1313,7 @@ class Workflow:
         return self._ts_fmt

     @property
-    def ts_name_fmt(self):
+    def ts_name_fmt(self) -> str:
         """
         The timestamp format for names.
         """
@@ -1274,18 +1322,24 @@ class Workflow:
         return self._ts_name_fmt

     @property
-    def template_components(self) -> Dict:
+    def template_components(self) -> TemplateComponents:
         """
         The template components used for this workflow.
         """
         if self._template_components is None:
             with self._store.cached_load():
                 tc_js = self._store.get_template_components()
-                self._template_components = self.app.template_components_from_json_like(tc_js)
+                self._template_components = self._app.template_components_from_json_like(
+                    tc_js
+                )
         return self._template_components

     @property
-    def template(self) -> app.WorkflowTemplate:
+    def __template_components(self) -> _TemplateComponents:
+        return cast("_TemplateComponents", self.template_components)
+
+    @property
+    def template(self) -> WorkflowTemplate:
         """
         The template that this workflow was made from.
         """
@@ -1294,11 +1348,11 @@ class Workflow:
             temp_js = self._store.get_template()

             # TODO: insert_ID and id_ are the same thing:
-            for task in temp_js["tasks"]:
+            for task in cast("list[dict]", temp_js["tasks"]):
                 task.pop("id_", None)

-            template = self.app.WorkflowTemplate.from_json_like(
-                temp_js, self.template_components
+            template = self._app.WorkflowTemplate.from_json_like(
+                temp_js, cast("dict", self.template_components)
             )
             template.workflow = self
             self._template = template
@@ -1306,62 +1360,74 @@ class Workflow:
         return self._template

     @property
-    def tasks(self) -> app.WorkflowTaskList:
+    def tasks(self) -> WorkflowTaskList:
         """
         The tasks in this workflow.
         """
         if self._tasks is None:
             with self._store.cached_load():
-                all_tasks = self._store.get_tasks()
-                wk_tasks = []
-                for i in all_tasks:
-                    wk_task = self.app.WorkflowTask(
+                all_tasks: Iterable[StoreTask] = self._store.get_tasks()
+                self._tasks = self._app.WorkflowTaskList(
+                    self._app.WorkflowTask(
                         workflow=self,
-                        template=self.template.tasks[i.index],
-                        index=i.index,
-                        element_IDs=i.element_IDs,
+                        template=self.template.tasks[task.index],
+                        index=task.index,
+                        element_IDs=task.element_IDs,
                     )
-                    wk_tasks.append(wk_task)
-                self._tasks = self.app.WorkflowTaskList(wk_tasks)
+                    for task in all_tasks
+                )

         return self._tasks

     @property
-    def loops(self) -> app.WorkflowLoopList:
+    def loops(self) -> WorkflowLoopList:
         """
         The loops in this workflow.
         """
+
+        def repack_iteration_tuples(
+            num_added_iterations: list[list[list[int] | int]],
+        ) -> Iterator[tuple[tuple[int, ...], int]]:
+            """
+            Unpacks a very ugly type from the persistence layer, turning it into
+            something we can process into a dict more easily. This in turn is caused
+            by JSON and Zarr not really understanding tuples as such.
+            """
+            for item in num_added_iterations:
+                # Convert the outside to a tuple and narrow the inner types
+                key_vec, count = item
+                yield tuple(cast("list[int]", key_vec)), cast("int", count)
+
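+        # A minimal sketch of the repacking, assuming a persisted value such
+        # as [[[0, 1], 5], [[1, 0], 3]]:
+        #     dict(repack_iteration_tuples(...)) == {(0, 1): 5, (1, 0): 3}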
         if self._loops is None:
             with self._store.cached_load():
-                wk_loops = []
-                for idx, loop_dat in self._store.get_loops().items():
-                    num_add_iters = {
-                        tuple(i[0]): i[1] for i in loop_dat["num_added_iterations"]
-                    }
-                    wk_loop = self.app.WorkflowLoop(
+                self._loops = self._app.WorkflowLoopList(
+                    self._app.WorkflowLoop(
                         index=idx,
                         workflow=self,
                         template=self.template.loops[idx],
                         parents=loop_dat["parents"],
-                        num_added_iterations=num_add_iters,
+                        num_added_iterations=dict(
+                            repack_iteration_tuples(loop_dat["num_added_iterations"])
+                        ),
                         iterable_parameters=loop_dat["iterable_parameters"],
                     )
-                    wk_loops.append(wk_loop)
-                self._loops = self.app.WorkflowLoopList(wk_loops)
+                    for idx, loop_dat in self._store.get_loops().items()
+                )
         return self._loops

     @property
-    def submissions(self) -> List[app.Submission]:
+    def submissions(self) -> list[Submission]:
         """
         The job submissions done by this workflow.
         """
         if self._submissions is None:
-            self.app.persistence_logger.debug("loading workflow submissions")
+            self._app.persistence_logger.debug("loading workflow submissions")
             with self._store.cached_load():
-                subs = []
+                subs: list[Submission] = []
                 for idx, sub_dat in self._store.get_submissions().items():
-                    sub_js = {"index": idx, **sub_dat}
-                    sub = self.app.Submission.from_json_like(sub_js)
+                    sub = self._app.Submission.from_json_like(
+                        {"index": idx, **cast("dict", sub_dat)}
+                    )
                     sub.workflow = self
                     subs.append(sub)
                 self._submissions = subs
@@ -1375,7 +1441,7 @@ class Workflow:
         return self._store._get_num_total_added_tasks()

     @TimeIt.decorator
-    def get_store_EARs(self, id_lst: Iterable[int]) -> List[AnySEAR]:
+    def get_store_EARs(self, id_lst: Iterable[int]) -> Sequence[StoreEAR]:
         """
         Get the persistent element action runs.
         """
@@ -1384,226 +1450,210 @@ class Workflow:
     @TimeIt.decorator
     def get_store_element_iterations(
         self, id_lst: Iterable[int]
-    ) -> List[AnySElementIter]:
+    ) -> Sequence[StoreElementIter]:
         """
         Get the persistent element iterations.
         """
         return self._store.get_element_iterations(id_lst)

     @TimeIt.decorator
-    def get_store_elements(self, id_lst: Iterable[int]) -> List[AnySElement]:
+    def get_store_elements(self, id_lst: Iterable[int]) -> Sequence[StoreElement]:
         """
         Get the persistent elements.
         """
         return self._store.get_elements(id_lst)

     @TimeIt.decorator
-    def get_store_tasks(self, id_lst: Iterable[int]) -> List[AnySTask]:
+    def get_store_tasks(self, id_lst: Iterable[int]) -> Sequence[StoreTask]:
         """
         Get the persistent tasks.
         """
         return self._store.get_tasks_by_IDs(id_lst)

-    def get_element_iteration_IDs_from_EAR_IDs(self, id_lst: Iterable[int]) -> List[int]:
+    def get_element_iteration_IDs_from_EAR_IDs(self, id_lst: Iterable[int]) -> list[int]:
         """
         Get the element iteration IDs of EARs.
         """
-        return [i.elem_iter_ID for i in self.get_store_EARs(id_lst)]
+        return [ear.elem_iter_ID for ear in self.get_store_EARs(id_lst)]

-    def get_element_IDs_from_EAR_IDs(self, id_lst: Iterable[int]) -> List[int]:
+    def get_element_IDs_from_EAR_IDs(self, id_lst: Iterable[int]) -> list[int]:
         """
         Get the element IDs of EARs.
         """
         iter_IDs = self.get_element_iteration_IDs_from_EAR_IDs(id_lst)
-        return [i.element_ID for i in self.get_store_element_iterations(iter_IDs)]
+        return [itr.element_ID for itr in self.get_store_element_iterations(iter_IDs)]

-    def get_task_IDs_from_element_IDs(self, id_lst: Iterable[int]) -> List[int]:
+    def get_task_IDs_from_element_IDs(self, id_lst: Iterable[int]) -> list[int]:
         """
         Get the task IDs of elements.
         """
-        return [i.task_ID for i in self.get_store_elements(id_lst)]
+        return [elem.task_ID for elem in self.get_store_elements(id_lst)]

-    def get_EAR_IDs_of_tasks(self, id_lst: int) -> List[int]:
+    def get_EAR_IDs_of_tasks(self, id_lst: Iterable[int]) -> list[int]:
         """Get EAR IDs belonging to multiple tasks."""
-        return [i.id_ for i in self.get_EARs_of_tasks(id_lst)]
-
-    def get_EARs_of_tasks(self, id_lst: Iterable[int]) -> List[app.ElementActionRun]:
-        """Get EARs belonging to multiple task.s"""
-        EARs = []
-        for i in id_lst:
-            task = self.tasks.get(insert_ID=i)
-            for elem in task.elements[:]:
+        return [ear.id_ for ear in self.get_EARs_of_tasks(id_lst)]
+
+    def get_EARs_of_tasks(self, id_lst: Iterable[int]) -> Iterator[ElementActionRun]:
+        """Get EARs belonging to multiple tasks."""
+        for id_ in id_lst:
+            for elem in self.tasks.get(insert_ID=id_).elements[:]:
                 for iter_ in elem.iterations:
-                    for run in iter_.action_runs:
-                        EARs.append(run)
-        return EARs
+                    yield from iter_.action_runs
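+        # Note: unlike the old list-returning form, this is a lazy generator;
+        # materialise it if needed, e.g. list(workflow.get_EARs_of_tasks(ids)).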

     def get_element_iterations_of_tasks(
         self, id_lst: Iterable[int]
-    ) -> List[app.ElementIteration]:
+    ) -> Iterator[ElementIteration]:
         """Get element iterations belonging to multiple tasks."""
-        iters = []
-        for i in id_lst:
-            task = self.tasks.get(insert_ID=i)
-            for elem in task.elements[:]:
-                for iter_i in elem.iterations:
-                    iters.append(iter_i)
-        return iters
+        for id_ in id_lst:
+            for elem in self.tasks.get(insert_ID=id_).elements[:]:
+                yield from elem.iterations
+
+    @dataclass
+    class _IndexPath1:
+        elem: int
+        task: int

     @TimeIt.decorator
-    def get_elements_from_IDs(self, id_lst: Iterable[int]) -> List[app.Element]:
+    def get_elements_from_IDs(self, id_lst: Iterable[int]) -> list[Element]:
         """Return element objects from a list of IDs."""

-        store_elems = self._store.get_elements(id_lst)
-
-        task_IDs = [i.task_ID for i in store_elems]
-        store_tasks = self._store.get_tasks_by_IDs(task_IDs)
+        store_elems = self.get_store_elements(id_lst)
+        store_tasks = self.get_store_tasks(el.task_ID for el in store_elems)

-        element_idx_by_task = defaultdict(set)
-        index_paths = []
-        for el, tk in zip(store_elems, store_tasks):
-            elem_idx = tk.element_IDs.index(el.id_)
-            index_paths.append(
-                {
-                    "elem_idx": elem_idx,
-                    "task_idx": tk.index,
-                }
-            )
-            element_idx_by_task[tk.index].add(elem_idx)
+        element_idx_by_task: dict[int, set[int]] = defaultdict(set)
+        index_paths: list[Workflow._IndexPath1] = []
+        for elem, task in zip(store_elems, store_tasks):
+            elem_idx = task.element_IDs.index(elem.id_)
+            index_paths.append(Workflow._IndexPath1(elem_idx, task.index))
+            element_idx_by_task[task.index].add(elem_idx)

-        elements_by_task = {}
-        for task_idx, elem_idx in element_idx_by_task.items():
-            task = self.tasks[task_idx]
-            elements_by_task[task_idx] = dict(
-                zip(elem_idx, task.elements[list(elem_idx)])
-            )
+        elements_by_task = {
+            task_idx: {idx: self.tasks[task_idx].elements[idx] for idx in elem_idxes}
+            for task_idx, elem_idxes in element_idx_by_task.items()
+        }
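+        # index_paths records where each requested element lives (task index,
+        # element index), so the result below can be rebuilt in the original
+        # requested-ID order from the per-task lookup tables.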

-        objs = []
-        for idx_dat in index_paths:
-            elem = elements_by_task[idx_dat["task_idx"]][idx_dat["elem_idx"]]
-            objs.append(elem)
+        return [elements_by_task[path.task][path.elem] for path in index_paths]

-        return objs
+    @dataclass
+    class _IndexPath2:
+        iter: int
+        elem: int
+        task: int

     @TimeIt.decorator
     def get_element_iterations_from_IDs(
         self, id_lst: Iterable[int]
-    ) -> List[app.ElementIteration]:
+    ) -> list[ElementIteration]:
         """Return element iteration objects from a list of IDs."""

-        store_iters = self._store.get_element_iterations(id_lst)
+        store_iters = self.get_store_element_iterations(id_lst)
+        store_elems = self.get_store_elements(it.element_ID for it in store_iters)
+        store_tasks = self.get_store_tasks(el.task_ID for el in store_elems)

-        elem_IDs = [i.element_ID for i in store_iters]
-        store_elems = self._store.get_elements(elem_IDs)
+        element_idx_by_task: dict[int, set[int]] = defaultdict(set)

-        task_IDs = [i.task_ID for i in store_elems]
-        store_tasks = self._store.get_tasks_by_IDs(task_IDs)
+        index_paths: list[Workflow._IndexPath2] = []
+        for itr, elem, task in zip(store_iters, store_elems, store_tasks):
+            iter_idx = elem.iteration_IDs.index(itr.id_)
+            elem_idx = task.element_IDs.index(elem.id_)
+            index_paths.append(Workflow._IndexPath2(iter_idx, elem_idx, task.index))
+            element_idx_by_task[task.index].add(elem_idx)

-        element_idx_by_task = defaultdict(set)
+        elements_by_task = {
+            task_idx: {idx: self.tasks[task_idx].elements[idx] for idx in elem_idx}
+            for task_idx, elem_idx in element_idx_by_task.items()
+        }

-        index_paths = []
-        for it, el, tk in zip(store_iters, store_elems, store_tasks):
-            iter_idx = el.iteration_IDs.index(it.id_)
-            elem_idx = tk.element_IDs.index(el.id_)
-            index_paths.append(
-                {
-                    "iter_idx": iter_idx,
-                    "elem_idx": elem_idx,
-                    "task_idx": tk.index,
-                }
-            )
-            element_idx_by_task[tk.index].add(elem_idx)
+        return [
+            elements_by_task[path.task][path.elem].iterations[path.iter]
+            for path in index_paths
+        ]

-        elements_by_task = {}
-        for task_idx, elem_idx in element_idx_by_task.items():
-            task = self.tasks[task_idx]
-            elements_by_task[task_idx] = dict(
-                zip(elem_idx, task.elements[list(elem_idx)])
-            )
+    @dataclass
+    class _IndexPath3:
+        run: int
+        act: int
+        iter: int
+        elem: int
+        task: int

-        objs = []
-        for idx_dat in index_paths:
-            elem = elements_by_task[idx_dat["task_idx"]][idx_dat["elem_idx"]]
-            iter_ = elem.iterations[idx_dat["iter_idx"]]
-            objs.append(iter_)
+    @overload
+    def get_EARs_from_IDs(self, ids: Iterable[int]) -> list[ElementActionRun]:
+        ...

-        return objs
+    @overload
+    def get_EARs_from_IDs(self, ids: int) -> ElementActionRun:
+        ...
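+    # With these overloads, a single int ID yields a single run while an
+    # iterable of IDs yields a list, e.g. (illustrative):
+    #     run = workflow.get_EARs_from_IDs(3)
+    #     runs = workflow.get_EARs_from_IDs([3, 4])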

     @TimeIt.decorator
-    def get_EARs_from_IDs(self, id_lst: Iterable[int]) -> List[app.ElementActionRun]:
+    def get_EARs_from_IDs(
+        self, ids: Iterable[int] | int
+    ) -> list[ElementActionRun] | ElementActionRun:
         """Get element action run objects from a list of IDs."""
-        self.app.persistence_logger.debug(f"get_EARs_from_IDs: id_lst={id_lst!r}")
-
-        store_EARs = self._store.get_EARs(id_lst)
-
-        elem_iter_IDs = [i.elem_iter_ID for i in store_EARs]
-        store_iters = self._store.get_element_iterations(elem_iter_IDs)
-
-        elem_IDs = [i.element_ID for i in store_iters]
-        store_elems = self._store.get_elements(elem_IDs)
+        id_lst = [ids] if isinstance(ids, int) else list(ids)
+        self._app.persistence_logger.debug(f"get_EARs_from_IDs: id_lst={id_lst!r}")

-        task_IDs = [i.task_ID for i in store_elems]
-        store_tasks = self._store.get_tasks_by_IDs(task_IDs)
+        store_EARs = self.get_store_EARs(id_lst)
+        store_iters = self.get_store_element_iterations(
+            ear.elem_iter_ID for ear in store_EARs
+        )
+        store_elems = self.get_store_elements(it.element_ID for it in store_iters)
+        store_tasks = self.get_store_tasks(el.task_ID for el in store_elems)

         # to allow for bulk retrieval of elements/iterations
-        element_idx_by_task = defaultdict(set)
-        iter_idx_by_task_elem = defaultdict(lambda: defaultdict(set))
+        element_idx_by_task: dict[int, set[int]] = defaultdict(set)
+        iter_idx_by_task_elem: dict[int, dict[int, set[int]]] = defaultdict(
+            lambda: defaultdict(set)
+        )

-        index_paths = []
+        index_paths: list[Workflow._IndexPath3] = []
         for rn, it, el, tk in zip(store_EARs, store_iters, store_elems, store_tasks):
             act_idx = rn.action_idx
-            run_idx = it.EAR_IDs[act_idx].index(rn.id_)
+            run_idx = it.EAR_IDs[act_idx].index(rn.id_) if it.EAR_IDs is not None else -1
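+            # -1 is a sentinel for iterations whose EAR IDs have not yet been
+            # initialised in the store.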
             iter_idx = el.iteration_IDs.index(it.id_)
             elem_idx = tk.element_IDs.index(el.id_)
             index_paths.append(
-                {
-                    "run_idx": run_idx,
-                    "action_idx": act_idx,
-                    "iter_idx": iter_idx,
-                    "elem_idx": elem_idx,
-                    "task_idx": tk.index,
-                }
+                Workflow._IndexPath3(run_idx, act_idx, iter_idx, elem_idx, tk.index)
             )
             element_idx_by_task[tk.index].add(elem_idx)
             iter_idx_by_task_elem[tk.index][elem_idx].add(iter_idx)

         # retrieve elements/iterations:
-        iters_by_task_elem = defaultdict(lambda: defaultdict(dict))
-        for task_idx, elem_idx in element_idx_by_task.items():
-            elements = self.tasks[task_idx].elements[list(elem_idx)]
-            for elem_i in elements:
-                elem_i_iters_idx = iter_idx_by_task_elem[task_idx][elem_i.index]
-                elem_iters = [elem_i.iterations[j] for j in elem_i_iters_idx]
-                iters_by_task_elem[task_idx][elem_i.index].update(
-                    dict(zip(elem_i_iters_idx, elem_iters))
-                )
-
-        objs = []
-        for idx_dat in index_paths:
-            iter_ = iters_by_task_elem[idx_dat["task_idx"]][idx_dat["elem_idx"]][
-                idx_dat["iter_idx"]
-            ]
-            run = iter_.actions[idx_dat["action_idx"]].runs[idx_dat["run_idx"]]
-            objs.append(run)
+        iters = {
+            task_idx: {
+                elem_i.index: {
+                    iter_idx: elem_i.iterations[iter_idx]
+                    for iter_idx in iter_idx_by_task_elem[task_idx][elem_i.index]
+                }
+                for elem_i in self.tasks[task_idx].elements[list(elem_idxes)]
+            }
+            for task_idx, elem_idxes in element_idx_by_task.items()
+        }

-        return objs
+        result = [
+            iters[path.task][path.elem][path.iter].actions[path.act].runs[path.run]
+            for path in index_paths
+        ]
+        if isinstance(ids, int):
+            return result[0]
+        return result

     @TimeIt.decorator
-    def get_all_elements(self) -> List[app.Element]:
+    def get_all_elements(self) -> list[Element]:
         """
         Get all elements in the workflow.
         """
         return self.get_elements_from_IDs(range(self.num_elements))

     @TimeIt.decorator
-    def get_all_element_iterations(self) -> List[app.ElementIteration]:
+    def get_all_element_iterations(self) -> list[ElementIteration]:
         """
         Get all iterations in the workflow.
         """
         return self.get_element_iterations_from_IDs(range(self.num_element_iterations))

     @TimeIt.decorator
-    def get_all_EARs(self) -> List[app.ElementActionRun]:
+    def get_all_EARs(self) -> list[ElementActionRun]:
         """
         Get all runs in the workflow.
         """
@@ -1618,14 +1668,14 @@ class Workflow:
             yield
         else:
             try:
-                self.app.persistence_logger.info(
+                self._app.persistence_logger.info(
                     f"entering batch update (is_workflow_creation={is_workflow_creation!r})"
                 )
                 self._in_batch_mode = True
                 yield

-            except Exception as err:
-                self.app.persistence_logger.error("batch update exception!")
+            except Exception:
+                self._app.persistence_logger.error("batch update exception!")
                 self._in_batch_mode = False
                 self._store._pending.reset()

@@ -1644,7 +1694,7 @@ class Workflow:
                     self._store.delete_no_confirm()
                     self._store.reinstate_replaced_dir()

-                raise err
+                raise

             else:
                 if self._store._pending:
@@ -1669,11 +1719,11 @@ class Workflow:
                 if is_workflow_creation:
                     self._store.remove_replaced_dir()

-                self.app.persistence_logger.info("exiting batch update")
+                self._app.persistence_logger.info("exiting batch update")
                 self._in_batch_mode = False

     @classmethod
-    def temporary_rename(cls, path: str, fs) -> List[str]:
+    def temporary_rename(cls, path: str, fs: AbstractFileSystem) -> str:
         """Rename an existing same-path workflow (directory) so we can restore it if
         workflow creation fails.

@@ -1681,13 +1731,13 @@ class Workflow:
         paths may be created, where only the final path should be considered the
         successfully renamed workflow. Other paths will be deleted."""

-        all_replaced = []
+        all_replaced: list[str] = []

-        @cls.app.perm_error_retry()
-        def _temp_rename(path: str, fs) -> str:
+        @cls._app.perm_error_retry()
+        def _temp_rename(path: str, fs: AbstractFileSystem) -> str:
             temp_ext = "".join(random.choices(string.ascii_letters, k=10))
             replaced = str(Path(f"{path}.{temp_ext}").as_posix())
-            cls.app.persistence_logger.debug(
+            cls._app.persistence_logger.debug(
                 f"temporary_rename: _temp_rename: {path!r} --> {replaced!r}."
             )
             all_replaced.append(replaced)
@@ -1698,17 +1748,19 @@ class Workflow:
             fs.rename(path, replaced)
             return replaced

-        @cls.app.perm_error_retry()
-        def _remove_path(path: str, fs) -> None:
-            cls.app.persistence_logger.debug(f"temporary_rename: _remove_path: {path!r}.")
+        @cls._app.perm_error_retry()
+        def _remove_path(path: str, fs: AbstractFileSystem) -> None:
+            cls._app.persistence_logger.debug(
+                f"temporary_rename: _remove_path: {path!r}."
+            )
             while fs.exists(path):
                 fs.rm(path, recursive=True)
                 time.sleep(0.5)

         _temp_rename(path, fs)

-        for i in all_replaced[:-1]:
-            _remove_path(i, fs)
+        for path in all_replaced[:-1]:
+            _remove_path(path, fs)

         return all_replaced[-1]

@@ -1716,28 +1768,30 @@ class Workflow:
     @TimeIt.decorator
     def _write_empty_workflow(
         cls,
-        template: app.WorkflowTemplate,
-        path: Optional[PathLike] = None,
-        name: Optional[str] = None,
-        overwrite: Optional[bool] = False,
-        store: Optional[str] = DEFAULT_STORE_FORMAT,
-        ts_fmt: Optional[str] = None,
-        ts_name_fmt: Optional[str] = None,
-        fs_kwargs: Optional[Dict] = None,
-        store_kwargs: Optional[Dict] = None,
-    ) -> app.Workflow:
+        template: WorkflowTemplate,
+        *,
+        path: PathLike | None = None,
+        name: str | None = None,
+        overwrite: bool | None = False,
+        store: str = DEFAULT_STORE_FORMAT,
+        ts_fmt: str | None = None,
+        ts_name_fmt: str | None = None,
+        fs_kwargs: dict[str, Any] | None = None,
+        store_kwargs: dict[str, Any] | None = None,
+    ) -> Workflow:
         """
         Parameters
         ----------
+        template
+            The workflow description to instantiate.
         path
             The directory in which the workflow will be generated. The current directory
             if not specified.
-
         """
-        ts = datetime.now()

         # store all times in UTC, since NumPy doesn't support time zone info:
-        ts_utc = ts.astimezone(tz=timezone.utc)
+        ts_utc = current_timestamp()
+        ts = normalise_timestamp(ts_utc)

         ts_name_fmt = ts_name_fmt or cls._default_ts_name_fmt
         ts_fmt = ts_fmt or cls._default_ts_fmt
@@ -1751,42 +1805,70 @@ class Workflow:

         replaced_wk = None
         if fs.exists(wk_path):
-            cls.app.logger.debug("workflow path exists")
+            cls._app.logger.debug("workflow path exists")
             if overwrite:
-                cls.app.logger.debug("renaming existing workflow path")
+                cls._app.logger.debug("renaming existing workflow path")
                 replaced_wk = cls.temporary_rename(wk_path, fs)
             else:
                 raise ValueError(
                     f"Path already exists: {wk_path} on file system " f"{fs!r}."
                 )

+        class PersistenceGrabber:
+            """An object to pass to ResourceSpec.make_persistent that pretends to be a
+            Workflow object, so we can pretend to make template-level inputs/resources
+            persistent before the workflow exists."""
+
+            def __init__(self) -> None:
+                self.__ps: list[tuple[Any, ParamSource]] = []
+
+            def _add_parameter_data(self, data: Any, source: ParamSource) -> int:
+                ref = len(self.__ps)
+                self.__ps.append((data, source))
+                return ref
+
+            def get_parameter_data(self, data_idx: int) -> Any:
+                return self.__ps[data_idx - 1][0]
+
+            def check_parameters_exist(self, id_lst: int | list[int]) -> bool:
+                r = range(len(self.__ps))
+                if isinstance(id_lst, int):
+                    return id_lst in r
+                else:
+                    return all(id_ in r for id_ in id_lst)
+
+            def write_persistence_data_to_workflow(self, workflow: Workflow) -> None:
+                for dat_i, source_i in self.__ps:
+                    workflow._add_parameter_data(dat_i, source_i)
+
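+        # PersistenceGrabber is a two-phase stand-in: the make_persistent()
+        # calls below record (data, source) pairs in memory, and once the real
+        # Workflow exists they are replayed via
+        # write_persistence_data_to_workflow().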
         # make template-level inputs/resources think they are persistent:
-        wk_dummy = _DummyPersistentWorkflow()
-        param_src = {"type": "workflow_resources"}
-        for res_i in copy.deepcopy(template.resources):
-            res_i.make_persistent(wk_dummy, param_src)
-
-        template_js, template_sh = template.to_json_like(exclude=["tasks", "loops"])
-        template_js["tasks"] = []
-        template_js["loops"] = []
-
-        creation_info = {
-            "app_info": cls.app.get_info(),
-            "create_time": ts_utc.strftime(ts_fmt),
-            "id": str(uuid4()),
+        grabber = PersistenceGrabber()
+        param_src: ParamSource = {"type": "workflow_resources"}
+        for res_i_copy in template._get_resources_copy():
+            res_i_copy.make_persistent(grabber, param_src)
+
+        template_js_, template_sh = template.to_json_like(exclude={"tasks", "loops"})
+        template_js: TemplateMeta = {
+            **cast("TemplateMeta", template_js_),  # Trust me, bro!
+            "tasks": [],
+            "loops": [],
         }

         store_kwargs = store_kwargs if store_kwargs else template.store_kwargs
         store_cls = store_cls_from_str(store)
         store_cls.write_empty_workflow(
-            app=cls.app,
+            app=cls._app,
             template_js=template_js,
-            template_components_js=template_sh,
+            template_components_js=template_sh or {},
             wk_path=wk_path,
             fs=fs,
             name=name,
             replaced_wk=replaced_wk,
-            creation_info=creation_info,
+            creation_info={
+                "app_info": cls._app.get_info(),
+                "create_time": ts_utc.strftime(ts_fmt),
+                "id": str(uuid4()),
+            },
             ts_fmt=ts_fmt,
             ts_name_fmt=ts_name_fmt,
             **store_kwargs,
@@ -1796,7 +1878,7 @@ class Workflow:
         wk = cls(fs_path, store_fmt=store, fs_kwargs=fs_kwargs)

         # actually make template inputs/resources persistent, now the workflow exists:
-        wk_dummy.make_persistent(wk)
+        grabber.write_persistence_data_to_workflow(wk)

         if template.source_file:
             wk.artifacts_path.mkdir(exist_ok=False)
@@ -1807,11 +1889,12 @@ class Workflow:

     def zip(
         self,
-        path=".",
-        log=None,
-        overwrite=False,
-        include_execute=False,
-        include_rechunk_backups=False,
+        path: str = ".",
+        *,
+        log: str | None = None,
+        overwrite: bool = False,
+        include_execute: bool = False,
+        include_rechunk_backups: bool = False,
     ) -> str:
         """
         Convert the workflow to a zipped form.
@@ -1831,7 +1914,7 @@ class Workflow:
             include_rechunk_backups=include_rechunk_backups,
         )

-    def unzip(self, path=".", log=None) -> str:
+    def unzip(self, path: str = ".", *, log: str | None = None) -> str:
         """
         Convert the workflow to an unzipped form.

@@ -1844,50 +1927,68 @@ class Workflow:
         """
         return self._store.unzip(path=path, log=log)

-    def copy(self, path=None) -> str:
+    def copy(self, path: str | Path = ".") -> Path:
         """Copy the workflow to a new path and return the copied workflow path."""
         return self._store.copy(path)

-    def delete(self):
+    def delete(self) -> None:
         """
         Delete the persistent data.
         """
         self._store.delete()

-    def _delete_no_confirm(self):
+    def _delete_no_confirm(self) -> None:
         self._store.delete_no_confirm()

-    def get_parameters(
-        self, id_lst: Iterable[int], **kwargs: Dict
-    ) -> List[AnySParameter]:
+    def get_parameters(self, id_lst: Iterable[int], **kwargs) -> Sequence[StoreParameter]:
         """
         Get parameters known to the workflow.
+
+        Parameters
+        ----------
+        id_lst:
+            The indices of the parameters to retrieve.
+
+        Keyword Arguments
+        -----------------
+        dataset_copy: bool
+            For Zarr stores only. If True, copy arrays as NumPy arrays.
         """
         return self._store.get_parameters(id_lst, **kwargs)

     @TimeIt.decorator
-    def get_parameter_sources(self, id_lst: Iterable[int]) -> List[Dict]:
+    def get_parameter_sources(self, id_lst: Iterable[int]) -> list[ParamSource]:
         """
         Get parameter sources known to the workflow.
         """
         return self._store.get_parameter_sources(id_lst)

     @TimeIt.decorator
-    def get_parameter_set_statuses(self, id_lst: Iterable[int]) -> List[bool]:
+    def get_parameter_set_statuses(self, id_lst: Iterable[int]) -> list[bool]:
         """
         Get whether some parameters are set.
         """
         return self._store.get_parameter_set_statuses(id_lst)

     @TimeIt.decorator
-    def get_parameter(self, index: int, **kwargs: Dict) -> AnySParameter:
+    def get_parameter(self, index: int, **kwargs) -> StoreParameter:
         """
         Get a single parameter.
+
+        Parameters
+        ----------
+        index:
+            The index of the parameter to retrieve.
+
+        Keyword Arguments
+        -----------------
+        dataset_copy: bool
+            For Zarr stores only. If True, copy arrays as NumPy arrays.
         """
-        return self.get_parameters([index], **kwargs)[0]
+        return self.get_parameters((index,), **kwargs)[0]

     @TimeIt.decorator
-    def get_parameter_data(self, index: int, **kwargs: Dict) -> Any:
+    def get_parameter_data(self, index: int, **kwargs) -> Any:
         """
         Get the data relating to a parameter.
         """
@@ -1898,69 +1999,77 @@ class Workflow:
             return param.file

     @TimeIt.decorator
-    def get_parameter_source(self, index: int) -> Dict:
+    def get_parameter_source(self, index: int) -> ParamSource:
         """
         Get the source of a particular parameter.
         """
-        return self.get_parameter_sources([index])[0]
+        return self.get_parameter_sources((index,))[0]

     @TimeIt.decorator
     def is_parameter_set(self, index: int) -> bool:
         """
         Test if a particular parameter is set.
         """
-        return self.get_parameter_set_statuses([index])[0]
+        return self.get_parameter_set_statuses((index,))[0]

     @TimeIt.decorator
-    def get_all_parameters(self, **kwargs: Dict) -> List[AnySParameter]:
-        """Retrieve all persistent parameters."""
+    def get_all_parameters(self, **kwargs) -> list[StoreParameter]:
+        """
+        Retrieve all persistent parameters.
+
+        Keyword Arguments
+        -----------------
+        dataset_copy: bool
+            For Zarr stores only. If True, copy arrays as NumPy arrays.
+        """
         num_params = self._store._get_num_total_parameters()
-        id_lst = list(range(num_params))
-        return self._store.get_parameters(id_lst, **kwargs)
+        return self._store.get_parameters(range(num_params), **kwargs)

     @TimeIt.decorator
-    def get_all_parameter_sources(self, **kwargs: Dict) -> List[Dict]:
+    def get_all_parameter_sources(self, **kwargs) -> list[ParamSource]:
         """Retrieve all persistent parameter sources."""
         num_params = self._store._get_num_total_parameters()
-        id_lst = list(range(num_params))
-        return self._store.get_parameter_sources(id_lst, **kwargs)
+        return self._store.get_parameter_sources(range(num_params), **kwargs)

     @TimeIt.decorator
-    def get_all_parameter_data(self, **kwargs: Dict) -> Dict[int, Any]:
-        """Retrieve all workflow parameter data."""
-        params = self.get_all_parameters(**kwargs)
-        return {i.id_: (i.data if i.data is not None else i.file) for i in params}
+    def get_all_parameter_data(self, **kwargs) -> dict[int, Any]:
+        """
+        Retrieve all workflow parameter data.

-    def check_parameters_exist(
-        self, id_lst: Union[int, List[int]]
-    ) -> Union[bool, List[bool]]:
+        Keyword Arguments
+        -----------------
+        dataset_copy: bool
+            For Zarr stores only. If True, copy arrays as NumPy arrays.
         """
-        Check if parameters exist.
+        return {
+            param.id_: (param.data if param.data is not None else param.file)
+            for param in self.get_all_parameters(**kwargs)
+        }
+
+    def check_parameters_exist(self, id_lst: int | list[int]) -> bool:
+        """
+        Check if all the parameters exist.
         """
-        is_multi = True
         if isinstance(id_lst, int):
-            is_multi = False
-            id_lst = [id_lst]
-        exists = self._store.check_parameters_exist(id_lst)
-        if not is_multi:
-            exists = exists[0]
-        return exists
-
-    def _add_unset_parameter_data(self, source: Dict) -> int:
+            return next(iter(self._store.check_parameters_exist((id_lst,))))
+        return all(self._store.check_parameters_exist(id_lst))
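+        # e.g. check_parameters_exist(3) tests a single ID, whereas
+        # check_parameters_exist([3, 4, 5]) is True only if every ID exists.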
+
+    def _add_unset_parameter_data(self, source: ParamSource) -> int:
         # TODO: use this for unset files as well
         return self._store.add_unset_parameter(source)

-    def _add_parameter_data(self, data, source: Dict) -> int:
+    def _add_parameter_data(self, data, source: ParamSource) -> int:
         return self._store.add_set_parameter(data, source)

     def _add_file(
         self,
+        *,
         store_contents: bool,
         is_input: bool,
-        source: Dict,
+        source: ParamSource,
         path=None,
         contents=None,
-        filename: str = None,
+        filename: str,
     ) -> int:
         return self._store.add_file(
             store_contents=store_contents,
@@ -1973,16 +2082,16 @@ class Workflow:

     def _set_file(
         self,
-        param_id: int,
+        param_id: int | list[int] | None,
         store_contents: bool,
         is_input: bool,
-        path=None,
+        path: Path | str,
         contents=None,
-        filename: str = None,
+        filename: str | None = None,
         clean_up: bool = False,
-    ) -> int:
+    ) -> None:
         self._store.set_file(
-            param_id=param_id,
+            param_id=cast("int", param_id),
             store_contents=store_contents,
             is_input=is_input,
             path=path,
@@ -1991,9 +2100,19 @@ class Workflow:
             clean_up=clean_up,
         )

+    @overload
+    def get_task_unique_names(
+        self, map_to_insert_ID: Literal[False] = False
+    ) -> Sequence[str]:
+        ...
+
+    @overload
+    def get_task_unique_names(self, map_to_insert_ID: Literal[True]) -> Mapping[str, int]:
+        ...
+
     def get_task_unique_names(
         self, map_to_insert_ID: bool = False
-    ) -> Union[List[str], Dict[str, int]]:
+    ) -> Sequence[str] | Mapping[str, int]:
         """Return the unique names of all workflow tasks.

         Parameters
@@ -2003,21 +2122,20 @@ class Workflow:
             list.

         """
-        names = self.app.Task.get_task_unique_names(self.template.tasks)
+        names = self._app.Task.get_task_unique_names(self.template.tasks)
         if map_to_insert_ID:
-            insert_IDs = (i.insert_ID for i in self.template.tasks)
-            return dict(zip(names, insert_IDs))
+            return dict(zip(names, (task.insert_ID for task in self.template.tasks)))
         else:
             return names

-    def _get_new_task_unique_name(self, new_task: app.Task, new_index: int) -> str:
+    def _get_new_task_unique_name(self, new_task: Task, new_index: int) -> str:
         task_templates = list(self.template.tasks)
         task_templates.insert(new_index, new_task)
-        uniq_names = self.app.Task.get_task_unique_names(task_templates)
+        uniq_names = self._app.Task.get_task_unique_names(task_templates)

         return uniq_names[new_index]

-    def _get_empty_pending(self) -> Dict:
+    def _get_empty_pending(self) -> Pending:
         return {
             "template_components": {k: [] for k in TEMPLATE_COMP_TYPES},
             "tasks": [],  # list of int
@@ -2047,7 +2165,9 @@ class Workflow:
         for comp_type, comp_indices in self._pending["template_components"].items():
             for comp_idx in comp_indices[::-1]:
                 # iterate in reverse so the index references are correct
-                self.template_components[comp_type]._remove_object(comp_idx)
+                tc = self.__template_components[comp_type]
+                assert hasattr(tc, "_remove_object")
+                tc._remove_object(comp_idx)

         for loop_idx in self._pending["loops"][::-1]:
             # iterate in reverse so the index references are correct
@@ -2056,33 +2176,34 @@ class Workflow:

         for sub_idx in self._pending["submissions"][::-1]:
             # iterate in reverse so the index references are correct
+            assert self._submissions is not None
             self._submissions.pop(sub_idx)

         self._reset_pending()

     @property
-    def num_tasks(self):
+    def num_tasks(self) -> int:
         """
         The total number of tasks.
         """
         return self._store._get_num_total_tasks()

     @property
-    def num_submissions(self):
+    def num_submissions(self) -> int:
         """
         The total number of job submissions.
         """
         return self._store._get_num_total_submissions()

     @property
-    def num_elements(self):
+    def num_elements(self) -> int:
         """
         The total number of elements.
         """
         return self._store._get_num_total_elements()

     @property
-    def num_element_iterations(self):
+    def num_element_iterations(self) -> int:
         """
         The total number of element iterations.
         """
@@ -2090,7 +2211,7 @@ class Workflow:

     @property
     @TimeIt.decorator
-    def num_EARs(self):
+    def num_EARs(self) -> int:
         """
         The total number of element action runs.
         """
@@ -2104,7 +2225,7 @@ class Workflow:
         return self._store._get_num_total_loops()

     @property
-    def artifacts_path(self):
+    def artifacts_path(self) -> Path:
         """
         Path to artifacts of the workflow (temporary files, etc).
         """
@@ -2112,28 +2233,28 @@ class Workflow:
         return Path(self.path) / "artifacts"

     @property
-    def input_files_path(self):
+    def input_files_path(self) -> Path:
         """
         Path to input files for the workflow.
         """
         return self.artifacts_path / self._input_files_dir_name

     @property
-    def submissions_path(self):
+    def submissions_path(self) -> Path:
         """
         Path to submission data for this workflow.
         """
         return self.artifacts_path / "submissions"

     @property
-    def task_artifacts_path(self):
+    def task_artifacts_path(self) -> Path:
         """
         Path to artifacts of tasks.
         """
         return self.artifacts_path / "tasks"

     @property
-    def execution_path(self):
+    def execution_path(self) -> Path:
         """
         Path to the working directory for execution.
         """
@@ -2142,29 +2263,29 @@ class Workflow:
     @TimeIt.decorator
     def get_task_elements(
         self,
-        task: app.Task,
-        idx_lst: Optional[List[int]] = None,
-    ) -> List[app.Element]:
+        task: WorkflowTask,
+        idx_lst: list[int] | None = None,
+    ) -> list[Element]:
         """
         Get the elements of a task.
         """
         return [
-            self.app.Element(task=task, **{k: v for k, v in i.items() if k != "task_ID"})
-            for i in self._store.get_task_elements(task.insert_ID, idx_lst)
+            self._app.Element(
+                task=task, **{k: v for k, v in te.items() if k != "task_ID"}
+            )
+            for te in self._store.get_task_elements(task.insert_ID, idx_lst)
         ]

     def set_EAR_submission_index(self, EAR_ID: int, sub_idx: int) -> None:
         """Set the submission index of an EAR."""
-        with self._store.cached_load():
-            with self.batch_update():
-                self._store.set_EAR_submission_index(EAR_ID, sub_idx)
+        with self._store.cached_load(), self.batch_update():
+            self._store.set_EAR_submission_index(EAR_ID, sub_idx)
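+        # The comma-combined `with` is equivalent to the previously nested
+        # form: both context managers are entered left-to-right and exited in
+        # reverse order.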

     def set_EAR_start(self, EAR_ID: int) -> None:
         """Set the start time on an EAR."""
-        self.app.logger.debug(f"Setting start for EAR ID {EAR_ID!r}")
-        with self._store.cached_load():
-            with self.batch_update():
-                self._store.set_EAR_start(EAR_ID)
+        self._app.logger.debug(f"Setting start for EAR ID {EAR_ID!r}")
+        with self._store.cached_load(), self.batch_update():
+            self._store.set_EAR_start(EAR_ID)

     def set_EAR_end(
         self,
@@ -2179,11 +2300,11 @@ class Workflow:
         skipped. Also save any generated input/output files.

         """
-        self.app.logger.debug(
+        self._app.logger.debug(
             f"Setting end for EAR ID {EAR_ID!r} with exit code {exit_code!r}."
         )
         with self._store.cached_load():
-            EAR = self.get_EARs_from_IDs([EAR_ID])[0]
+            EAR = self.get_EARs_from_IDs(EAR_ID)
             with self.batch_update():
                 success = exit_code == 0  # TODO more sophisticated success heuristics
                 if EAR.action.abortable and exit_code == ABORT_EXIT_CODE:
@@ -2192,17 +2313,16 @@ class Workflow:

                 for IFG_i in EAR.action.input_file_generators:
                     inp_file = IFG_i.input_file
-                    self.app.logger.debug(
+                    self._app.logger.debug(
                         f"Saving EAR input file: {inp_file.label!r} for EAR ID "
                         f"{EAR_ID!r}."
                     )
                     param_id = EAR.data_idx[f"input_files.{inp_file.label}"]

                     file_paths = inp_file.value()
-                    if not isinstance(file_paths, list):
-                        file_paths = [file_paths]
-
-                    for path_i in file_paths:
+                    for path_i in (
+                        file_paths if isinstance(file_paths, list) else [file_paths]
+                    ):
                         self._set_file(
                             param_id=param_id,
                             store_contents=True,  # TODO: make optional according to IFG
@@ -2215,25 +2335,21 @@ class Workflow:

                 # Save action-level files: (TODO: refactor with below for OFPs)
                 for save_file_j in EAR.action.save_files:
-                    self.app.logger.debug(
+                    self._app.logger.debug(
                         f"Saving file: {save_file_j.label!r} for EAR ID " f"{EAR_ID!r}."
                     )
-                    try:
-                        param_id = EAR.data_idx[f"output_files.{save_file_j.label}"]
-                    except KeyError:
-                        # We might be saving a file that is not a defined
-                        # "output file"; this will avoid saving a reference in the
-                        # parameter data:
-                        param_id = None
+                    # We might be saving a file that is not a defined
+                    # "output file"; this will avoid saving a reference in the
+                    # parameter data in that case
+                    param_id_j = EAR.data_idx.get(f"output_files.{save_file_j.label}")

                     file_paths = save_file_j.value()
-                    self.app.logger.debug(f"Saving output file paths: {file_paths!r}")
-                    if not isinstance(file_paths, list):
-                        file_paths = [file_paths]
-
-                    for path_i in file_paths:
+                    self._app.logger.debug(f"Saving output file paths: {file_paths!r}")
+                    for path_i in (
+                        file_paths if isinstance(file_paths, list) else [file_paths]
+                    ):
                         self._set_file(
-                            param_id=param_id,
+                            param_id=param_id_j,
                             store_contents=True,
                             is_input=False,
                             path=Path(path_i).resolve(),
@@ -2241,29 +2357,25 @@ class Workflow:
                         )

                 for OFP_i in EAR.action.output_file_parsers:
-                    for save_file_j in OFP_i.save_files:
-                        self.app.logger.debug(
+                    for save_file_j in OFP_i._save_files:
+                        self._app.logger.debug(
                             f"Saving EAR output file: {save_file_j.label!r} for EAR ID "
                             f"{EAR_ID!r}."
                         )
-                        try:
-                            param_id = EAR.data_idx[f"output_files.{save_file_j.label}"]
-                        except KeyError:
-                            # We might be saving a file that is not a defined
-                            # "output file"; this will avoid saving a reference in the
-                            # parameter data:
-                            param_id = None
+                        # We might be saving a file that is not a defined
+                        # "output file"; this will avoid saving a reference in the
+                        # parameter data in that case
+                        param_id_j = EAR.data_idx.get(f"output_files.{save_file_j.label}")

                         file_paths = save_file_j.value()
-                        self.app.logger.debug(
+                        self._app.logger.debug(
                             f"Saving EAR output file paths: {file_paths!r}"
                         )
-                        if not isinstance(file_paths, list):
-                            file_paths = [file_paths]
-
-                        for path_i in file_paths:
+                        for path_i in (
+                            file_paths if isinstance(file_paths, list) else [file_paths]
+                        ):
                             self._set_file(
-                                param_id=param_id,
+                                param_id=param_id_j,
                                 store_contents=True,  # TODO: make optional according to OFP
                                 is_input=False,
                                 path=Path(path_i).resolve(),
@@ -2271,9 +2383,9 @@ class Workflow:
                             )

                 if not success:
-                    for EAR_dep_ID in EAR.get_dependent_EARs(as_objects=False):
+                    for EAR_dep_ID in EAR.get_dependent_EARs():
                         # TODO: this needs to be recursive?
-                        self.app.logger.debug(
+                        self._app.logger.debug(
                             f"Setting EAR ID {EAR_dep_ID!r} to skip because it depends on"
                             f" EAR ID {EAR_ID!r}, which exited with a non-zero exit code:"
                             f" {exit_code!r}."
@@ -2287,40 +2399,37 @@ class Workflow:
         Record that an EAR is to be skipped due to an upstream failure or loop
         termination condition being met.
         """
-        with self._store.cached_load():
-            with self.batch_update():
-                self._store.set_EAR_skip(EAR_ID)
+        with self._store.cached_load(), self.batch_update():
+            self._store.set_EAR_skip(EAR_ID)

-    def get_EAR_skipped(self, EAR_ID: int) -> None:
+    def get_EAR_skipped(self, EAR_ID: int) -> bool:
         """Check if an EAR is to be skipped."""
         with self._store.cached_load():
             return self._store.get_EAR_skipped(EAR_ID)

     @TimeIt.decorator
     def set_parameter_value(
-        self, param_id: int, value: Any, commit: bool = False
+        self, param_id: int | list[int], value: Any, commit: bool = False
     ) -> None:
         """
         Set the value of a parameter.
         """
-        with self._store.cached_load():
-            with self.batch_update():
-                self._store.set_parameter_value(param_id, value)
+        with self._store.cached_load(), self.batch_update():
+            self._store.set_parameter_value(cast("int", param_id), value)

         if commit:
             # force commit now:
             self._store._pending.commit_all()

-    def set_EARs_initialised(self, iter_ID: int):
+    def set_EARs_initialised(self, iter_ID: int) -> None:
         """
         Set :py:attr:`~hpcflow.app.ElementIteration.EARs_initialised` to True for the
         specified iteration.
         """
-        with self._store.cached_load():
-            with self.batch_update():
-                self._store.set_EARs_initialised(iter_ID)
+        with self._store.cached_load(), self.batch_update():
            self._store.set_EARs_initialised(iter_ID)

-    def elements(self) -> Iterator[app.Element]:
+    def elements(self) -> Iterator[Element]:
         """
         Get the elements of the workflow's tasks.
         """
@@ -2328,19 +2437,48 @@ class Workflow:
             for element in task.elements[:]:
                 yield element

+    @overload
+    def get_iteration_task_pathway(
+        self,
+        *,
+        ret_iter_IDs: Literal[False] = False,
+        ret_data_idx: Literal[False] = False,
+    ) -> Sequence[tuple[int, LoopIndex[str, int]]]:
+        ...
+
+    @overload
+    def get_iteration_task_pathway(
+        self, *, ret_iter_IDs: Literal[False] = False, ret_data_idx: Literal[True]
+    ) -> Sequence[tuple[int, LoopIndex[str, int], tuple[Mapping[str, int], ...]]]:
+        ...
+
+    @overload
+    def get_iteration_task_pathway(
+        self, *, ret_iter_IDs: Literal[True], ret_data_idx: Literal[False] = False
+    ) -> Sequence[tuple[int, LoopIndex[str, int], tuple[int, ...]]]:
+        ...
+
+    @overload
+    def get_iteration_task_pathway(
+        self, *, ret_iter_IDs: Literal[True], ret_data_idx: Literal[True]
+    ) -> Sequence[
+        tuple[int, LoopIndex[str, int], tuple[int, ...], tuple[Mapping[str, int], ...]]
+    ]:
+        ...
+
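+    # The overloads encode how each pathway tuple widens: the base form is
+    # (task_insert_ID, loop_index), with iteration IDs and/or per-iteration
+    # data indices appended when ret_iter_IDs/ret_data_idx are requested.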
     @TimeIt.decorator
-    def get_iteration_task_pathway(self, ret_iter_IDs=False, ret_data_idx=False):
+    def get_iteration_task_pathway(
+        self, ret_iter_IDs: bool = False, ret_data_idx: bool = False
+    ) -> Sequence[tuple]:
         """
         Get the iteration task pathway.
         """
-        # FIXME: I don't understand this concept, alas.
-        pathway = []
+        pathway: list[_Pathway] = []
         for task in self.tasks:
-            pathway.append((task.insert_ID, {}))
+            pathway.append(_Pathway(task.insert_ID))

-        added_loop_names = set()
+        added_loop_names: set[str] = set()
         for _ in range(self.num_loops):
-            to_add = None
             for loop in self.loops:
                 if loop.name in added_loop_names:
                     continue
@@ -2348,87 +2486,81 @@ class Workflow:
                 # add a loop only once their parents have been added:
                 to_add = loop
                 break
-
-            if to_add is None:
+            else:
                 raise RuntimeError(
                     "Failed to find a loop whose parents have already been added to the "
                     "iteration task pathway."
                 )

             iIDs = to_add.task_insert_IDs
-            relevant_idx = [idx for idx, i in enumerate(pathway) if i[0] in iIDs]
+            relevant_idx = (
+                idx for idx, path_i in enumerate(pathway) if path_i.id_ in iIDs
+            )

             for num_add_k, num_add in to_add.num_added_iterations.items():
-                parent_loop_idx = {
-                    to_add.parents[idx]: i for idx, i in enumerate(num_add_k)
-                }
-
-                repl = []
-                repl_idx = []
+                parent_loop_idx = list(zip(to_add.parents, num_add_k))
+                replacement: list[_Pathway] = []
+                repl_idx: list[int] = []
                 for i in range(num_add):
-                    for p_idx, p in enumerate(pathway):
-                        skip = False
-                        if p[0] not in iIDs:
-                            continue
-                        for k, v in parent_loop_idx.items():
-                            if p[1][k] != v:
-                                skip = True
-                                break
-                        if skip:
+                    for p_idx, path in enumerate(pathway):
+                        if path.id_ not in iIDs:
                             continue
-                        p = copy.deepcopy(p)
-                        p[1].update({to_add.name: i})
-                        repl_idx.append(p_idx)
-                        repl.append(p)
-
-                if repl:
-                    repl_start, repl_stop = min(repl_idx), max(repl_idx)
-                    pathway = replace_items(pathway, repl_start, repl_stop + 1, repl)
+                        if all(path.names[k] == v for k, v in parent_loop_idx):
+                            new_path = copy.deepcopy(path)
+                            new_path.names += {to_add.name: i}
+                            repl_idx.append(p_idx)
+                            replacement.append(new_path)
+
+                if replacement:
+                    pathway = replace_items(
+                        pathway, min(repl_idx), max(repl_idx) + 1, replacement
+                    )

             added_loop_names.add(to_add.name)

-        if added_loop_names != set(i.name for i in self.loops):
+        if added_loop_names != set(loop.name for loop in self.loops):
             raise RuntimeError(
                 "Not all loops have been considered in the iteration task pathway."
             )

         if ret_iter_IDs or ret_data_idx:
             all_iters = self.get_all_element_iterations()
-            for idx, i in enumerate(pathway):
-                i_iters = []
-                for iter_j in all_iters:
-                    if iter_j.task.insert_ID == i[0] and iter_j.loop_idx == i[1]:
-                        i_iters.append(iter_j)
-                new = list(i)
+            for path_i in pathway:
+                i_iters = [
+                    iter_j
+                    for iter_j in all_iters
+                    if (
+                        iter_j.task.insert_ID == path_i.id_
+                        and iter_j.loop_idx == path_i.names
+                    )
+                ]
                 if ret_iter_IDs:
-                    new += [tuple([j.id_ for j in i_iters])]
+                    path_i.iter_ids.extend(elit.id_ for elit in i_iters)
                 if ret_data_idx:
-                    new += [tuple(j.get_data_idx() for j in i_iters)]
-                    pathway[idx] = tuple(new)
+                    path_i.data_idx.extend(elit.get_data_idx() for elit in i_iters)

-        return pathway
+        return [
+            path.as_tuple(ret_iter_IDs=ret_iter_IDs, ret_data_idx=ret_data_idx)
+            for path in pathway
+        ]
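+        # Illustration: for two tasks where the second sits in a loop iterated
+        # twice, the basic pathway would look something like
+        #     [(0, {}), (1, {'my_loop': 0}), (1, {'my_loop': 1})]
+        # where 'my_loop' stands in for the loop's actual name.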

     @TimeIt.decorator
     def _submit(
         self,
-        status: Optional[Any] = None,
-        ignore_errors: Optional[bool] = False,
-        JS_parallelism: Optional[bool] = None,
-        print_stdout: Optional[bool] = False,
-        add_to_known: Optional[bool] = True,
-        tasks: Optional[List[int]] = None,
-    ) -> Tuple[List[Exception], Dict[int, int]]:
+        status: Status | None = None,
+        ignore_errors: bool = False,
+        JS_parallelism: bool | None = None,
+        print_stdout: bool = False,
+        add_to_known: bool = True,
+        tasks: Sequence[int] | None = None,
+    ) -> tuple[Sequence[SubmissionFailure], Mapping[int, Sequence[int]]]:
         """Submit outstanding EARs for execution."""

         # generate a new submission if there are no pending submissions:
-        pending = [i for i in self.submissions if i.needs_submit]
-        if not pending:
+        if not (pending := [sub for sub in self.submissions if sub.needs_submit]):
             if status:
                 status.update("Adding new submission...")
-            new_sub = self._add_submission(tasks=tasks, JS_parallelism=JS_parallelism)
-            if not new_sub:
-                if status:
-                    status.stop()
+            if not (new_sub := self._add_submission(tasks, JS_parallelism)):
                 raise ValueError("No pending element action runs to submit!")
             pending = [new_sub]

@@ -2443,8 +2575,8 @@ class Workflow:
             self._store._pending.commit_all()

         # submit all pending submissions:
-        exceptions = []
-        submitted_js = {}
+        exceptions: list[SubmissionFailure] = []
+        submitted_js: dict[int, list[int]] = {}
         for sub in pending:
             try:
                 if status:
@@ -2461,18 +2593,51 @@ class Workflow:

         return exceptions, submitted_js

+    @overload
+    def submit(
+        self,
+        *,
+        ignore_errors: bool = False,
+        JS_parallelism: bool | None = None,
+        print_stdout: bool = False,
+        wait: bool = False,
+        add_to_known: bool = True,
+        return_idx: Literal[True],
+        tasks: list[int] | None = None,
+        cancel: bool = False,
+        status: bool = True,
+    ) -> Mapping[int, Sequence[int]]:
+        ...
+
+    @overload
     def submit(
         self,
-        ignore_errors: Optional[bool] = False,
-        JS_parallelism: Optional[bool] = None,
-        print_stdout: Optional[bool] = False,
-        wait: Optional[bool] = False,
-        add_to_known: Optional[bool] = True,
-        return_idx: Optional[bool] = False,
-        tasks: Optional[List[int]] = None,
-        cancel: Optional[bool] = False,
-        status: Optional[bool] = True,
-    ) -> Dict[int, int]:
+        *,
+        ignore_errors: bool = False,
+        JS_parallelism: bool | None = None,
+        print_stdout: bool = False,
+        wait: bool = False,
+        add_to_known: bool = True,
+        return_idx: Literal[False] = False,
+        tasks: list[int] | None = None,
+        cancel: bool = False,
+        status: bool = True,
+    ) -> None:
+        ...
+
+    def submit(
+        self,
+        *,
+        ignore_errors: bool = False,
+        JS_parallelism: bool | None = None,
+        print_stdout: bool = False,
+        wait: bool = False,
+        add_to_known: bool = True,
+        return_idx: bool = False,
+        tasks: list[int] | None = None,
+        cancel: bool = False,
+        status: bool = True,
+    ) -> Mapping[int, Sequence[int]] | None:
         """Submit the workflow for execution.

         Parameters
@@ -2504,41 +2669,28 @@ class Workflow:
             If True, display a live status to track submission progress.
         """

-        if status:
-            console = rich.console.Console()
-            status = console.status("Submitting workflow...")
-            status.start()
-
-        with self._store.cached_load():
+        # Type hint for mypy
+        status_context: AbstractContextManager[Status] | AbstractContextManager[None] = (
+            rich.console.Console().status("Submitting workflow...")
+            if status
+            else nullcontext()
+        )
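+        # nullcontext() lets a single `with` statement below cover both the
+        # live-status and no-status cases; status_ is simply None when the
+        # status display is disabled.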
2678
+ with status_context as status_, self._store.cached_load():
2513
2679
  if not self._store.is_submittable:
2514
- if status:
2515
- status.stop()
2516
2680
  raise NotImplementedError("The workflow is not submittable.")
2517
- with self.batch_update():
2518
- # commit updates before raising exception:
2519
- try:
2520
- with self._store.cache_ctx():
2521
- exceptions, submitted_js = self._submit(
2522
- ignore_errors=ignore_errors,
2523
- JS_parallelism=JS_parallelism,
2524
- print_stdout=print_stdout,
2525
- status=status,
2526
- add_to_known=add_to_known,
2527
- tasks=tasks,
2528
- )
2529
- except Exception:
2530
- if status:
2531
- status.stop()
2532
- raise
2681
+ # commit updates before raising exception:
2682
+ with self.batch_update(), self._store.cache_ctx():
2683
+ exceptions, submitted_js = self._submit(
2684
+ ignore_errors=ignore_errors,
2685
+ JS_parallelism=JS_parallelism,
2686
+ print_stdout=print_stdout,
2687
+ status=status_,
2688
+ add_to_known=add_to_known,
2689
+ tasks=tasks,
2690
+ )
2533
2691
 
2534
2692
  if exceptions:
2535
- msg = "\n" + "\n\n".join([i.message for i in exceptions])
2536
- if status:
2537
- status.stop()
2538
- raise WorkflowSubmissionFailure(msg)
2539
-
2540
- if status:
2541
- status.stop()
2693
+ raise WorkflowSubmissionFailure(exceptions)
2542
2694
 
2543
2695
  if cancel:
2544
2696
  self.cancel()
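The status handling above swaps manual `status.start()`/`status.stop()` calls for a context manager that is either a live `rich` status or a `nullcontext()`, so every early `raise` exits the spinner automatically. A minimal sketch of that shape:

```python
from contextlib import AbstractContextManager, nullcontext

import rich.console


def run(show_status: bool) -> None:
    status_ctx: AbstractContextManager = (
        rich.console.Console().status("Working...")
        if show_status
        else nullcontext()
    )
    with status_ctx as status:
        # `status` is None when the nullcontext branch was taken
        if status is not None:
            status.update("Still working...")
        # raising here would stop the spinner (if any) via __exit__
```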
@@ -2548,58 +2700,69 @@ class Workflow:
2548
2700
 
2549
2701
  if return_idx:
2550
2702
  return submitted_js
2703
+ return None
2551
2704
 
2552
- def wait(self, sub_js: Optional[Dict] = None):
2553
- """Wait for the completion of specified/all submitted jobscripts."""
2705
+ @staticmethod
2706
+ def __wait_for_direct_jobscripts(jobscripts: list[Jobscript]):
2707
+ """Wait for the passed direct (i.e. non-scheduled) jobscripts to finish."""
2554
2708
 
2555
- # TODO: think about how this might work with remote workflow submission (via SSH)
2709
+ def callback(proc: psutil.Process) -> None:
2710
+ js = js_pids[proc.pid]
2711
+ assert hasattr(proc, "returncode")
2712
+ # TODO sometimes proc.returncode is None; maybe because multiple wait
2713
+ # calls?
2714
+ print(
2715
+ f"Jobscript {js.index} from submission {js.submission.index} "
2716
+ f"finished with exit code {proc.returncode}."
2717
+ )
2556
2718
 
2557
- def wait_for_direct_jobscripts(jobscripts: List[app.Jobscript]):
2558
- """Wait for the passed direct (i.e. non-scheduled) jobscripts to finish."""
2719
+ js_pids = {js.process_ID: js for js in jobscripts}
2720
+ process_refs = [
2721
+ (js.process_ID, js.submit_cmdline)
2722
+ for js in jobscripts
2723
+ if js.process_ID and js.submit_cmdline
2724
+ ]
2725
+ DirectScheduler.wait_for_jobscripts(process_refs, callback=callback)
2726
+
2727
+ def __wait_for_scheduled_jobscripts(self, jobscripts: list[Jobscript]):
2728
+ """Wait for the passed scheduled jobscripts to finish."""
2729
+ schedulers = self._app.Submission.get_unique_schedulers_of_jobscripts(jobscripts)
2730
+ threads: list[Thread] = []
2731
+ for js_indices, sched in schedulers:
2732
+ jobscripts_gen = (
2733
+ self.submissions[sub_idx].jobscripts[js_idx]
2734
+ for sub_idx, js_idx in js_indices
2735
+ )
2736
+ job_IDs = [
2737
+ js.scheduler_job_ID
2738
+ for js in jobscripts_gen
2739
+ if js.scheduler_job_ID is not None
2740
+ ]
2741
+ threads.append(Thread(target=sched.wait_for_jobscripts, args=(job_IDs,)))
2559
2742
 
2560
- def callback(proc):
2561
- js = js_pids[proc.pid]
2562
- # TODO sometimes proc.returncode is None; maybe because multiple wait
2563
- # calls?
2564
- print(
2565
- f"Jobscript {js.index} from submission {js.submission.index} "
2566
- f"finished with exit code {proc.returncode}."
2567
- )
2743
+ for thr in threads:
2744
+ thr.start()
2568
2745
 
2569
- js_pids = {i.process_ID: i for i in jobscripts}
2570
- process_refs = [(i.process_ID, i.submit_cmdline) for i in jobscripts]
2571
- DirectScheduler.wait_for_jobscripts(js_refs=process_refs, callback=callback)
2572
-
2573
- def wait_for_scheduled_jobscripts(jobscripts: List[app.Jobscript]):
2574
- """Wait for the passed scheduled jobscripts to finish."""
2575
- schedulers = app.Submission.get_unique_schedulers_of_jobscripts(jobscripts)
2576
- threads = []
2577
- for js_indices, sched in schedulers.items():
2578
- jobscripts = [
2579
- self.submissions[sub_idx].jobscripts[js_idx]
2580
- for sub_idx, js_idx in js_indices
2581
- ]
2582
- job_IDs = [i.scheduler_job_ID for i in jobscripts]
2583
- threads.append(Thread(target=sched.wait_for_jobscripts, args=(job_IDs,)))
2746
+ for thr in threads:
2747
+ thr.join()
2584
2748
 
2585
- for i in threads:
2586
- i.start()
2749
+ def wait(self, sub_js: Mapping[int, Sequence[int]] | None = None):
2750
+ """Wait for the completion of specified/all submitted jobscripts."""
2587
2751
 
2588
- for i in threads:
2589
- i.join()
2752
+ # TODO: think about how this might work with remote workflow submission (via SSH)
2590
2753
 
2591
2754
  # TODO: add a log file to the submission dir where we can log stuff (e.g. starting

2592
2755
  # a thread...)
2593
2756
 
2594
2757
  if not sub_js:
2595
2758
  # find any active jobscripts first:
2596
- sub_js = defaultdict(list)
2759
+ sub_js_: dict[int, list[int]] = defaultdict(list)
2597
2760
  for sub in self.submissions:
2598
- for js_idx in sub.get_active_jobscripts():
2599
- sub_js[sub.index].append(js_idx)
2761
+ sub_js_[sub.index].extend(sub.get_active_jobscripts())
2762
+ sub_js = sub_js_
2600
2763
 
2601
- js_direct = []
2602
- js_sched = []
2764
+ js_direct: list[Jobscript] = []
2765
+ js_sched: list[Jobscript] = []
2603
2766
  for sub_idx, all_js_idx in sub_js.items():
2604
2767
  for js_idx in all_js_idx:
2605
2768
  try:
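`__wait_for_scheduled_jobscripts` above starts one worker thread per scheduler and then joins them all, so waits on several queues proceed concurrently. A generic sketch of the fan-out/join shape; `wait_for` is a stand-in for a scheduler's blocking `wait_for_jobscripts` call:

```python
import time
from threading import Thread


def wait_for(job_ids: list[int]) -> None:
    # stand-in for a scheduler's blocking wait on a batch of job IDs
    time.sleep(0.1)
    print(f"jobs {job_ids} finished")


batches = [[101, 102], [201], [301, 302]]
threads = [Thread(target=wait_for, args=(batch,)) for batch in batches]
for thr in threads:
    thr.start()
for thr in threads:
    thr.join()  # block until every batch has completed
```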
@@ -2626,8 +2789,10 @@ class Workflow:
2626
2789
  return
2627
2790
 
2628
2791
  try:
2629
- t_direct = Thread(target=wait_for_direct_jobscripts, args=(js_direct,))
2630
- t_sched = Thread(target=wait_for_scheduled_jobscripts, args=(js_sched,))
2792
+ t_direct = Thread(target=self.__wait_for_direct_jobscripts, args=(js_direct,))
2793
+ t_sched = Thread(
2794
+ target=self.__wait_for_scheduled_jobscripts, args=(js_sched,)
2795
+ )
2631
2796
  t_direct.start()
2632
2797
  t_sched.start()
2633
2798
 
@@ -2646,16 +2811,16 @@ class Workflow:
2646
2811
  def get_running_elements(
2647
2812
  self,
2648
2813
  submission_idx: int = -1,
2649
- task_idx: Optional[int] = None,
2650
- task_insert_ID: Optional[int] = None,
2651
- ) -> List[app.Element]:
2814
+ task_idx: int | None = None,
2815
+ task_insert_ID: int | None = None,
2816
+ ) -> list[Element]:
2652
2817
  """Retrieve elements that are running according to the scheduler."""
2653
2818
 
2654
2819
  if task_idx is not None and task_insert_ID is not None:
2655
2820
  raise ValueError("Specify at most one of `task_insert_ID` and `task_idx`.")
2656
2821
 
2657
2822
  # keys are task_insert_IDs, values are element indices:
2658
- active_elems = defaultdict(set)
2823
+ active_elems: dict[int, set[int]] = defaultdict(set)
2659
2824
  sub = self.submissions[submission_idx]
2660
2825
  for js_idx, states in sub.get_active_jobscripts().items():
2661
2826
  js = sub.jobscripts[js_idx]
@@ -2667,14 +2832,14 @@ class Workflow:
2667
2832
  active_elems[task_iID].add(elem_idx)
2668
2833
 
2669
2834
  # retrieve Element objects:
2670
- out = []
2671
- for task_iID, elem_idx in active_elems.items():
2835
+ out: list[Element] = []
2836
+ for task_iID, elem_idxes in active_elems.items():
2672
2837
  if task_insert_ID is not None and task_iID != task_insert_ID:
2673
2838
  continue
2674
2839
  task = self.tasks.get(insert_ID=task_iID)
2675
2840
  if task_idx is not None and task_idx != task.index:
2676
2841
  continue
2677
- for idx_i in elem_idx:
2842
+ for idx_i in elem_idxes:
2678
2843
  out.append(task.elements[idx_i])
2679
2844
 
2680
2845
  return out
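`get_running_elements` accumulates active element indices per task with `defaultdict(set)`, which deduplicates repeated sightings and avoids explicit key-existence checks. The idiom in isolation:

```python
from collections import defaultdict

# (task_insert_ID, element_index) pairs, possibly repeated
observed = [(7, 0), (7, 1), (9, 0), (7, 1)]

active_elems: dict[int, set[int]] = defaultdict(set)
for task_id, elem_idx in observed:
    active_elems[task_id].add(elem_idx)  # inner set created on first access

print(dict(active_elems))  # {7: {0, 1}, 9: {0}}
```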
@@ -2682,10 +2847,10 @@ class Workflow:
2682
2847
  def get_running_runs(
2683
2848
  self,
2684
2849
  submission_idx: int = -1,
2685
- task_idx: Optional[int] = None,
2686
- task_insert_ID: Optional[int] = None,
2687
- element_idx: int = None,
2688
- ) -> List[app.ElementActionRun]:
2850
+ task_idx: int | None = None,
2851
+ task_insert_ID: int | None = None,
2852
+ element_idx: int | None = None,
2853
+ ) -> list[ElementActionRun]:
2689
2854
  """Retrieve runs that are running according to the scheduler."""
2690
2855
 
2691
2856
  elems = self.get_running_elements(
@@ -2705,7 +2870,7 @@ class Workflow:
2705
2870
  break # only one element action may be running at a time
2706
2871
  return out
2707
2872
 
2708
- def _abort_run_ID(self, submission_idx, run_ID: int):
2873
+ def _abort_run_ID(self, submission_idx: int, run_ID: int):
2709
2874
  """Modify the submission abort runs text file to signal that a run should be
2710
2875
  aborted."""
2711
2876
  self.submissions[submission_idx]._set_run_abort(run_ID)
@@ -2713,9 +2878,9 @@ class Workflow:
2713
2878
  def abort_run(
2714
2879
  self,
2715
2880
  submission_idx: int = -1,
2716
- task_idx: Optional[int] = None,
2717
- task_insert_ID: Optional[int] = None,
2718
- element_idx: int = None,
2881
+ task_idx: int | None = None,
2882
+ task_insert_ID: int | None = None,
2883
+ element_idx: int | None = None,
2719
2884
  ):
2720
2885
  """Abort the currently running action-run of the specified task/element.
2721
2886
 
@@ -2740,45 +2905,42 @@ class Workflow:
2740
2905
 
2741
2906
  elif len(running) > 1:
2742
2907
  if element_idx is None:
2743
- elem_idx = tuple(i.element.index for i in running)
2908
+ elem_idx = tuple(ear.element.index for ear in running)
2744
2909
  raise ValueError(
2745
2910
  f"Multiple elements are running (indices: {elem_idx!r}). Specify "
2746
- f"which element index you want to abort."
2911
+ "which element index you want to abort."
2747
2912
  )
2748
2913
  else:
2749
- raise RuntimeError(f"Multiple running runs.")
2914
+ raise RuntimeError("Multiple running runs.")
2750
2915
 
2751
2916
  run = running[0]
2752
2917
  if not run.action.abortable:
2753
- raise RunNotAbortableError(
2754
- "The run is not defined as abortable in the task schema, so it cannot "
2755
- "be aborted."
2756
- )
2918
+ raise RunNotAbortableError()
2757
2919
  self._abort_run_ID(submission_idx, run.id_)
2758
2920
 
2759
2921
  @TimeIt.decorator
2760
- def cancel(self, hard=False):
2922
+ def cancel(self, hard: bool = False):
2761
2923
  """Cancel any running jobscripts."""
2762
2924
  for sub in self.submissions:
2763
2925
  sub.cancel()
2764
2926
 
2765
2927
  def add_submission(
2766
- self, tasks: Optional[List[int]] = None, JS_parallelism: Optional[bool] = None
2767
- ) -> app.Submission:
2928
+ self, tasks: list[int] | None = None, JS_parallelism: bool | None = None
2929
+ ) -> Submission | None:
2768
2930
  """
2769
2931
  Add a job submission to this workflow.
2770
2932
  """
2771
- with self._store.cached_load():
2772
- with self.batch_update():
2773
- return self._add_submission(tasks, JS_parallelism)
2933
+ # JS_parallelism=None means guess
2934
+ with self._store.cached_load(), self.batch_update():
2935
+ return self._add_submission(tasks, JS_parallelism)
2774
2936
 
2775
2937
  @TimeIt.decorator
2776
2938
  def _add_submission(
2777
- self, tasks: Optional[List[int]] = None, JS_parallelism: Optional[bool] = None
2778
- ) -> app.Submission:
2939
+ self, tasks: Sequence[int] | None = None, JS_parallelism: bool | None = None
2940
+ ) -> Submission | None:
2779
2941
  new_idx = self.num_submissions
2780
2942
  _ = self.submissions # TODO: just to ensure `submissions` is loaded
2781
- sub_obj = self.app.Submission(
2943
+ sub_obj: Submission = self._app.Submission(
2782
2944
  index=new_idx,
2783
2945
  workflow=self,
2784
2946
  jobscripts=self.resolve_jobscripts(tasks),
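Several methods in this hunk collapse nested `with` blocks into a single multi-context statement; entry runs left to right and exit in reverse, exactly as in the nested form. A minimal sketch with two home-made context managers:

```python
from contextlib import contextmanager


@contextmanager
def ctx(name: str):
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")


# equivalent to: with ctx("load"):  with ctx("update"): ...
with ctx("load"), ctx("update"):
    print("body")
# prints: enter load, enter update, body, exit update, exit load
```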
@@ -2788,49 +2950,65 @@ class Workflow:
2788
2950
  all_EAR_ID = [i for js in sub_obj.jobscripts for i in js.EAR_ID.flatten()]
2789
2951
  if not all_EAR_ID:
2790
2952
  print(
2791
- f"There are no pending element action runs, so a new submission was not "
2792
- f"added."
2953
+ "There are no pending element action runs, so a new submission was not "
2954
+ "added."
2793
2955
  )
2794
- return
2956
+ return None
2795
2957
 
2796
- with self._store.cached_load():
2797
- with self.batch_update():
2798
- for i in all_EAR_ID:
2799
- self._store.set_EAR_submission_index(EAR_ID=i, sub_idx=new_idx)
2958
+ with self._store.cached_load(), self.batch_update():
2959
+ for id_ in all_EAR_ID:
2960
+ self._store.set_EAR_submission_index(EAR_ID=id_, sub_idx=new_idx)
2800
2961
 
2801
2962
  sub_obj_js, _ = sub_obj.to_json_like()
2963
+ assert self._submissions is not None
2802
2964
  self._submissions.append(sub_obj)
2803
2965
  self._pending["submissions"].append(new_idx)
2804
- with self._store.cached_load():
2805
- with self.batch_update():
2806
- self._store.add_submission(new_idx, sub_obj_js)
2966
+ with self._store.cached_load(), self.batch_update():
2967
+ self._store.add_submission(new_idx, sub_obj_js)
2807
2968
 
2808
2969
  return self.submissions[new_idx]
2809
2970
 
2810
2971
  @TimeIt.decorator
2811
- def resolve_jobscripts(
2812
- self, tasks: Optional[List[int]] = None
2813
- ) -> List[app.Jobscript]:
2972
+ def resolve_jobscripts(self, tasks: Sequence[int] | None = None) -> list[Jobscript]:
2814
2973
  """
2815
2974
  Resolve this workflow to a set of job scripts to run.
2816
2975
  """
2817
2976
  js, element_deps = self._resolve_singular_jobscripts(tasks)
2818
2977
  js_deps = resolve_jobscript_dependencies(js, element_deps)
2819
2978
 
2820
- for js_idx in js:
2979
+ for js_idx, jsca in js.items():
2821
2980
  if js_idx in js_deps:
2822
- js[js_idx]["dependencies"] = js_deps[js_idx]
2981
+ jsca["dependencies"] = js_deps[js_idx]
2823
2982
 
2824
2983
  js = merge_jobscripts_across_tasks(js)
2825
- js = jobscripts_to_list(js)
2826
- js_objs = [self.app.Jobscript(**i) for i in js]
2984
+ return [self._app.Jobscript(**jsca) for jsca in jobscripts_to_list(js)]
2827
2985
 
2828
- return js_objs
2986
+ def __EAR_obj_map(
2987
+ self,
2988
+ js_desc: JobScriptDescriptor,
2989
+ jsca: JobScriptCreationArguments,
2990
+ task: WorkflowTask,
2991
+ task_actions: Sequence[tuple[int, int, int]],
2992
+ EAR_map: NDArray,
2993
+ ) -> Mapping[int, ElementActionRun]:
2994
+ all_EAR_IDs: list[int] = []
2995
+ for js_elem_idx, (elem_idx, act_indices) in enumerate(
2996
+ js_desc["elements"].items()
2997
+ ):
2998
+ for act_idx in act_indices:
2999
+ EAR_ID_i: int = EAR_map[act_idx, elem_idx].item()
3000
+ all_EAR_IDs.append(EAR_ID_i)
3001
+ js_act_idx = task_actions.index((task.insert_ID, act_idx, 0))
3002
+ jsca["EAR_ID"][js_act_idx][js_elem_idx] = EAR_ID_i
3003
+ return dict(zip(all_EAR_IDs, self.get_EARs_from_IDs(all_EAR_IDs)))
2829
3004
 
2830
3005
  @TimeIt.decorator
2831
3006
  def _resolve_singular_jobscripts(
2832
- self, tasks: Optional[List[int]] = None
2833
- ) -> Tuple[Dict[int, Dict], Dict]:
3007
+ self, tasks: Sequence[int] | None = None
3008
+ ) -> tuple[
3009
+ Mapping[int, JobScriptCreationArguments],
3010
+ Mapping[int, Mapping[int, Sequence[int]]],
3011
+ ]:
2834
3012
  """
2835
3013
  We arrange EARs into `EARs` and `elements` so we can quickly look up membership
2836
3014
  by EAR idx in the `EARs` dict.
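`__EAR_obj_map` ends by pairing the collected IDs with one batched fetch: `dict(zip(all_EAR_IDs, self.get_EARs_from_IDs(all_EAR_IDs)))`. That gives a single store read plus O(1) lookups afterwards. A sketch with a hypothetical batched `fetch`:

```python
def fetch(ids: list[int]) -> list[str]:
    # hypothetical batched read standing in for get_EARs_from_IDs
    return [f"run-{i}" for i in ids]


ids = [10, 42, 7]
by_id = dict(zip(ids, fetch(ids)))  # one batched read; O(1) lookups after

print(by_id[42])  # 'run-42'
```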
@@ -2838,44 +3016,43 @@ class Workflow:
2838
3016
  Returns
2839
3017
  -------
2840
3018
  submission_jobscripts
3019
+ Information for making each jobscript.
2841
3020
  all_element_deps
2842
3021
  For a given jobscript index, for a given jobscript element index within that
2843
3022
  jobscript, this is a list of the EAR ID dependencies of that element.
2844
-
2845
3023
  """
2846
- if not tasks:
2847
- tasks = list(range(self.num_tasks))
3024
+ task_set = frozenset(tasks if tasks else range(self.num_tasks))
2848
3025
 
2849
3026
  if self._store.use_cache:
2850
3027
  # pre-cache parameter sources (used in `EAR.get_EAR_dependencies`):
2851
3028
  self.get_all_parameter_sources()
2852
3029
 
2853
- submission_jobscripts = {}
2854
- all_element_deps = {}
3030
+ submission_jobscripts: dict[int, JobScriptCreationArguments] = {}
3031
+ all_element_deps: dict[int, dict[int, list[int]]] = {}
2855
3032
 
2856
3033
  for task_iID, loop_idx_i in self.get_iteration_task_pathway():
2857
3034
  task = self.tasks.get(insert_ID=task_iID)
2858
- if task.index not in tasks:
3035
+ if task.index not in task_set:
2859
3036
  continue
2860
3037
  res, res_hash, res_map, EAR_map = generate_EAR_resource_map(task, loop_idx_i)
2861
3038
  jobscripts, _ = group_resource_map_into_jobscripts(res_map)
2862
3039
 
2863
3040
  for js_dat in jobscripts:
2864
3041
  # (insert ID, action_idx, index into task_loop_idx):
2865
- task_actions = [
2866
- [task.insert_ID, i, 0]
2867
- for i in sorted(
2868
- set(
2869
- act_idx_i
2870
- for act_idx in js_dat["elements"].values()
2871
- for act_idx_i in act_idx
2872
- )
2873
- )
2874
- ]
3042
+ task_actions = sorted(
3043
+ set(
3044
+ (task.insert_ID, act_idx_i, 0)
3045
+ for act_idx in js_dat["elements"].values()
3046
+ for act_idx_i in act_idx
3047
+ ),
3048
+ key=lambda x: x[1],
3049
+ )
3050
+ # Invert the mapping
3051
+ task_actions_inv = {k: idx for idx, k in enumerate(task_actions)}
2875
3052
  # task_elements: { JS_ELEM_IDX: [TASK_ELEM_IDX for each task insert ID]}
2876
3053
  task_elements = {
2877
3054
  js_elem_idx: [task_elem_idx]
2878
- for js_elem_idx, task_elem_idx in enumerate(js_dat["elements"].keys())
3055
+ for js_elem_idx, task_elem_idx in enumerate(js_dat["elements"])
2879
3056
  }
2880
3057
  EAR_idx_arr_shape = (
2881
3058
  len(task_actions),
@@ -2886,7 +3063,7 @@ class Workflow:
2886
3063
 
2887
3064
  new_js_idx = len(submission_jobscripts)
2888
3065
 
2889
- js_i = {
3066
+ js_i: JobScriptCreationArguments = {
2890
3067
  "task_insert_IDs": [task.insert_ID],
2891
3068
  "task_loop_idx": [loop_idx_i],
2892
3069
  "task_actions": task_actions, # map jobscript actions to task actions
@@ -2897,45 +3074,82 @@ class Workflow:
2897
3074
  "dependencies": {},
2898
3075
  }
2899
3076
 
2900
- all_EAR_IDs = []
2901
- for js_elem_idx, (elem_idx, act_indices) in enumerate(
2902
- js_dat["elements"].items()
2903
- ):
2904
- for act_idx in act_indices:
2905
- EAR_ID_i = EAR_map[act_idx, elem_idx].item()
2906
- all_EAR_IDs.append(EAR_ID_i)
2907
- js_act_idx = task_actions.index([task.insert_ID, act_idx, 0])
2908
- js_i["EAR_ID"][js_act_idx][js_elem_idx] = EAR_ID_i
2909
-
2910
- all_EAR_objs = dict(zip(all_EAR_IDs, self.get_EARs_from_IDs(all_EAR_IDs)))
3077
+ all_EAR_objs = self.__EAR_obj_map(
3078
+ js_dat, js_i, task, task_actions, EAR_map
3079
+ )
2911
3080
 
2912
3081
  for js_elem_idx, (elem_idx, act_indices) in enumerate(
2913
3082
  js_dat["elements"].items()
2914
3083
  ):
2915
- all_EAR_IDs = []
3084
+ all_EAR_IDs: list[int] = []
2916
3085
  for act_idx in act_indices:
2917
- EAR_ID_i = EAR_map[act_idx, elem_idx].item()
3086
+ EAR_ID_i: int = EAR_map[act_idx, elem_idx].item()
2918
3087
  all_EAR_IDs.append(EAR_ID_i)
2919
- js_act_idx = task_actions.index([task.insert_ID, act_idx, 0])
2920
- js_i["EAR_ID"][js_act_idx][js_elem_idx] = EAR_ID_i
3088
+ js_act_idx = task_actions_inv[task.insert_ID, act_idx, 0]
3089
+ EAR_ID_arr[js_act_idx][js_elem_idx] = EAR_ID_i
2921
3090
 
2922
3091
  # get indices of EARs that this element depends on:
2923
- EAR_objs = [all_EAR_objs[k] for k in all_EAR_IDs]
2924
- EAR_deps = [i.get_EAR_dependencies() for i in EAR_objs]
2925
- EAR_deps_flat = [j for i in EAR_deps for j in i]
2926
3092
  EAR_deps_EAR_idx = [
2927
- i for i in EAR_deps_flat if i not in js_i["EAR_ID"]
3093
+ dep_ear_id
3094
+ for main_ear_id in all_EAR_IDs
3095
+ for dep_ear_id in all_EAR_objs[main_ear_id].get_EAR_dependencies()
3096
+ if dep_ear_id not in EAR_ID_arr
2928
3097
  ]
2929
3098
  if EAR_deps_EAR_idx:
2930
- if new_js_idx not in all_element_deps:
2931
- all_element_deps[new_js_idx] = {}
2932
-
2933
- all_element_deps[new_js_idx][js_elem_idx] = EAR_deps_EAR_idx
3099
+ all_element_deps.setdefault(new_js_idx, {})[
3100
+ js_elem_idx
3101
+ ] = EAR_deps_EAR_idx
2934
3102
 
2935
3103
  submission_jobscripts[new_js_idx] = js_i
2936
3104
 
2937
3105
  return submission_jobscripts, all_element_deps
2938
3106
 
3107
+ def __get_commands(
3108
+ self, jobscript: Jobscript, JS_action_idx: int, ear: ElementActionRun
3109
+ ):
3110
+ try:
3111
+ commands, shell_vars = ear.compose_commands(jobscript, JS_action_idx)
3112
+ except OutputFileParserNoOutputError:
3113
+ # no commands to write but still need to write the file,
3114
+ # the jobscript is expecting it.
3115
+ return ""
3116
+
3117
+ self._app.persistence_logger.debug("need to write commands")
3118
+ pieces = [commands]
3119
+ for cmd_idx, var_dat in shell_vars.items():
3120
+ for param_name, shell_var_name, st_typ in var_dat:
3121
+ pieces.append(
3122
+ jobscript.shell.format_save_parameter(
3123
+ workflow_app_alias=jobscript.workflow_app_alias,
3124
+ param_name=param_name,
3125
+ shell_var_name=shell_var_name,
3126
+ EAR_ID=ear.id_,
3127
+ cmd_idx=cmd_idx,
3128
+ stderr=(st_typ == "stderr"),
3129
+ )
3130
+ )
3131
+ commands = jobscript.shell.wrap_in_subshell("".join(pieces), ear.action.abortable)
3132
+
3133
+ # add loop-check command if this is the last action of this loop iteration
3134
+ # for this element:
3135
+ if self.loops:
3136
+ final_runs = (
3137
+ # TODO: excessive reads here
3138
+ self.get_iteration_final_run_IDs(id_lst=jobscript.all_EAR_IDs)
3139
+ )
3140
+ self._app.persistence_logger.debug(f"final_runs: {final_runs!r}")
3141
+ pieces = []
3142
+ for loop_name, run_IDs in final_runs.items():
3143
+ if ear.id_ in run_IDs:
3144
+ loop_cmd = jobscript.shell.format_loop_check(
3145
+ workflow_app_alias=jobscript.workflow_app_alias,
3146
+ loop_name=loop_name,
3147
+ run_ID=ear.id_,
3148
+ )
3149
+ pieces.append(jobscript.shell.wrap_in_subshell(loop_cmd, False))
3150
+ commands += "".join(pieces)
3151
+ return commands
3152
+
2939
3153
  def write_commands(
2940
3154
  self,
2941
3155
  submission_idx: int,
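The dependency bookkeeping above now uses `dict.setdefault` to create the inner per-jobscript mapping on first use, replacing the explicit `if new_js_idx not in all_element_deps` check. In isolation:

```python
all_element_deps: dict[int, dict[int, list[int]]] = {}

# element 3 of jobscript 0 depends on runs 11 and 12
all_element_deps.setdefault(0, {})[3] = [11, 12]
# a second write to jobscript 0 reuses the existing inner dict
all_element_deps.setdefault(0, {})[5] = [20]

print(all_element_deps)  # {0: {3: [11, 12], 5: [20]}}
```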
@@ -2945,60 +3159,17 @@ class Workflow:
2945
3159
  ) -> None:
2946
3160
  """Write run-time commands for a given EAR."""
2947
3161
  with self._store.cached_load():
2948
- self.app.persistence_logger.debug("Workflow.write_commands")
2949
- self.app.persistence_logger.debug(
3162
+ self._app.persistence_logger.debug("Workflow.write_commands")
3163
+ self._app.persistence_logger.debug(
2950
3164
  f"loading jobscript (submission index: {submission_idx}; jobscript "
2951
3165
  f"index: {jobscript_idx})"
2952
3166
  )
2953
3167
  jobscript = self.submissions[submission_idx].jobscripts[jobscript_idx]
2954
- self.app.persistence_logger.debug(f"loading run {EAR_ID!r}")
2955
- EAR = self.get_EARs_from_IDs([EAR_ID])[0]
2956
- self.app.persistence_logger.debug(f"run {EAR_ID!r} loaded: {EAR!r}")
2957
- write_commands = True
2958
- try:
2959
- commands, shell_vars = EAR.compose_commands(jobscript, JS_action_idx)
2960
- except OutputFileParserNoOutputError:
2961
- # no commands to write
2962
- write_commands = False
2963
-
2964
- if write_commands:
2965
- self.app.persistence_logger.debug("need to write commands")
2966
- for cmd_idx, var_dat in shell_vars.items():
2967
- for param_name, shell_var_name, st_typ in var_dat:
2968
- commands += jobscript.shell.format_save_parameter(
2969
- workflow_app_alias=jobscript.workflow_app_alias,
2970
- param_name=param_name,
2971
- shell_var_name=shell_var_name,
2972
- EAR_ID=EAR_ID,
2973
- cmd_idx=cmd_idx,
2974
- stderr=(st_typ == "stderr"),
2975
- )
2976
- commands = jobscript.shell.wrap_in_subshell(
2977
- commands, EAR.action.abortable
2978
- )
2979
-
2980
- # add loop-check command if this is the last action of this loop iteration
2981
- # for this element:
2982
- if self.loops:
2983
- final_runs = (
2984
- self.get_iteration_final_run_IDs( # TODO: excessive reads here
2985
- id_lst=jobscript.all_EAR_IDs
2986
- )
2987
- )
2988
- self.app.persistence_logger.debug(f"final_runs: {final_runs!r}")
2989
- for loop_name, run_IDs in final_runs.items():
2990
- if EAR.id_ in run_IDs:
2991
- loop_cmd = jobscript.shell.format_loop_check(
2992
- workflow_app_alias=jobscript.workflow_app_alias,
2993
- loop_name=loop_name,
2994
- run_ID=EAR.id_,
2995
- )
2996
- commands += jobscript.shell.wrap_in_subshell(loop_cmd, False)
2997
- else:
2998
- # still need to write the file, the jobscript is expecting it.
2999
- commands = ""
3000
-
3001
- self.app.persistence_logger.debug(f"commands to write: {commands!r}")
3168
+ self._app.persistence_logger.debug(f"loading run {EAR_ID!r}")
3169
+ EAR = self.get_EARs_from_IDs(EAR_ID)
3170
+ self._app.persistence_logger.debug(f"run {EAR_ID!r} loaded: {EAR!r}")
3171
+ commands = self.__get_commands(jobscript, JS_action_idx, EAR)
3172
+ self._app.persistence_logger.debug(f"commands to write: {commands!r}")
3002
3173
  cmd_file_name = jobscript.get_commands_file_name(JS_action_idx)
3003
3174
  with Path(cmd_file_name).open("wt", newline="\n") as fp:
3004
3175
  # (assuming we have CD'd correctly to the element run directory)
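`write_commands` opens the commands file in text mode with `newline="\n"`, which suppresses platform newline translation so the script gets Unix line endings even when written on Windows. A sketch of the same write (the file name is illustrative):

```python
from pathlib import Path

commands = "echo step-1\necho step-2\n"

# newline="\n" prevents "\n" being rewritten as "\r\n" on Windows
with Path("js_0_act_0_commands.sh").open("wt", newline="\n") as fp:
    fp.write(commands)
```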
@@ -3009,11 +3180,10 @@ class Workflow:
3009
3180
  ) -> Any:
3010
3181
  """Process the shell stdout/stderr stream according to the associated Command
3011
3182
  object."""
3012
- with self._store.cached_load():
3013
- with self.batch_update():
3014
- EAR = self.get_EARs_from_IDs([EAR_ID])[0]
3015
- command = EAR.action.commands[cmd_idx]
3016
- return command.process_std_stream(name, value, stderr)
3183
+ with self._store.cached_load(), self.batch_update():
3184
+ EAR = self.get_EARs_from_IDs(EAR_ID)
3185
+ command = EAR.action.commands[cmd_idx]
3186
+ return command.process_std_stream(name, value, stderr)
3017
3187
 
3018
3188
  def save_parameter(
3019
3189
  self,
@@ -3024,15 +3194,14 @@ class Workflow:
3024
3194
  """
3025
3195
  Save a parameter where an EAR can find it.
3026
3196
  """
3027
- self.app.logger.info(f"save parameter {name!r} for EAR_ID {EAR_ID}.")
3028
- self.app.logger.debug(f"save parameter {name!r} value is {value!r}.")
3029
- with self._store.cached_load():
3030
- with self.batch_update():
3031
- EAR = self.get_EARs_from_IDs([EAR_ID])[0]
3032
- param_id = EAR.data_idx[name]
3033
- self.set_parameter_value(param_id, value)
3197
+ self._app.logger.info(f"save parameter {name!r} for EAR_ID {EAR_ID}.")
3198
+ self._app.logger.debug(f"save parameter {name!r} value is {value!r}.")
3199
+ with self._store.cached_load(), self.batch_update():
3200
+ EAR = self.get_EARs_from_IDs(EAR_ID)
3201
+ param_id = EAR.data_idx[name]
3202
+ self.set_parameter_value(param_id, value)
3034
3203
 
3035
- def show_all_EAR_statuses(self):
3204
+ def show_all_EAR_statuses(self) -> None:
3036
3205
  """
3037
3206
  Print a description of the status of every element action run in
3038
3207
  the workflow.
@@ -3061,7 +3230,7 @@ class Workflow:
3061
3230
  )
3062
3231
 
3063
3232
  def _resolve_input_source_task_reference(
3064
- self, input_source: app.InputSource, new_task_name: str
3233
+ self, input_source: InputSource, new_task_name: str
3065
3234
  ) -> None:
3066
3235
  """Normalise the input source task reference and convert a source to a local type
3067
3236
  if required."""
@@ -3070,107 +3239,102 @@ class Workflow:
3070
3239
 
3071
3240
  if isinstance(input_source.task_ref, str):
3072
3241
  if input_source.task_ref == new_task_name:
3073
- if input_source.task_source_type is self.app.TaskSourceType.OUTPUT:
3074
- raise InvalidInputSourceTaskReference(
3075
- f"Input source {input_source.to_string()!r} cannot refer to the "
3076
- f"outputs of its own task!"
3077
- )
3078
- else:
3079
- warn(
3080
- f"Changing input source {input_source.to_string()!r} to a local "
3081
- f"type, since the input source task reference refers to its own "
3082
- f"task."
3083
- )
3084
- # TODO: add an InputSource source_type setter to reset
3085
- # task_ref/source_type?
3086
- input_source.source_type = self.app.InputSourceType.LOCAL
3087
- input_source.task_ref = None
3088
- input_source.task_source_type = None
3242
+ if input_source.task_source_type is self._app.TaskSourceType.OUTPUT:
3243
+ raise InvalidInputSourceTaskReference(input_source)
3244
+ warn(
3245
+ f"Changing input source {input_source.to_string()!r} to a local "
3246
+ f"type, since the input source task reference refers to its own "
3247
+ f"task."
3248
+ )
3249
+ # TODO: add an InputSource source_type setter to reset
3250
+ # task_ref/source_type?
3251
+ input_source.source_type = self._app.InputSourceType.LOCAL
3252
+ input_source.task_ref = None
3253
+ input_source.task_source_type = None
3089
3254
  else:
3090
3255
  try:
3091
3256
  uniq_names_cur = self.get_task_unique_names(map_to_insert_ID=True)
3092
3257
  input_source.task_ref = uniq_names_cur[input_source.task_ref]
3093
3258
  except KeyError:
3094
3259
  raise InvalidInputSourceTaskReference(
3095
- f"Input source {input_source.to_string()!r} refers to a missing "
3096
- f"or inaccessible task: {input_source.task_ref!r}."
3260
+ input_source, input_source.task_ref
3097
3261
  )
3098
3262
 
3099
- def get_all_submission_run_IDs(self) -> List[int]:
3263
+ def get_all_submission_run_IDs(self) -> Iterable[int]:
3100
3264
  """
3101
3265
  Get the run IDs of all submissions.
3102
3266
  """
3103
- self.app.persistence_logger.debug("Workflow.get_all_submission_run_IDs")
3104
- id_lst = []
3267
+ self._app.persistence_logger.debug("Workflow.get_all_submission_run_IDs")
3105
3268
  for sub in self.submissions:
3106
- id_lst.extend(list(sub.all_EAR_IDs))
3107
- return id_lst
3269
+ yield from sub.all_EAR_IDs
3108
3270
 
3109
- def check_loop_termination(self, loop_name: str, run_ID: int) -> bool:
3271
+ def check_loop_termination(self, loop_name: str, run_ID: int) -> None:
3110
3272
  """Check if a loop should terminate, given the specified completed run, and if so,
3111
3273
  set downstream iteration runs to be skipped."""
3112
3274
  loop = self.loops.get(loop_name)
3113
- elem_iter = self.get_EARs_from_IDs([run_ID])[0].element_iteration
3275
+ elem_iter = self.get_EARs_from_IDs(run_ID).element_iteration
3114
3276
  if loop.test_termination(elem_iter):
3115
- to_skip = [] # run IDs of downstream iterations that can be skipped
3277
+ # run IDs of downstream iterations that can be skipped
3278
+ to_skip: set[int] = set()
3116
3279
  elem_id = elem_iter.element.id_
3117
3280
  loop_map = self.get_loop_map() # over all jobscripts
3118
3281
  for iter_idx, iter_dat in loop_map[loop_name][elem_id].items():
3119
3282
  if iter_idx > elem_iter.index:
3120
- to_skip.extend([i[0] for i in iter_dat])
3121
- self.app.logger.info(
3283
+ to_skip.update(itr_d.id_ for itr_d in iter_dat)
3284
+ self._app.logger.info(
3122
3285
  f"Loop {loop_name!r} termination condition met for run_ID {run_ID!r}."
3123
3286
  )
3124
3287
  for run_ID in to_skip:
3125
3288
  self.set_EAR_skip(run_ID)
3126
3289
 
3127
- def get_loop_map(self, id_lst: Optional[List[int]] = None):
3290
+ def get_loop_map(
3291
+ self, id_lst: Iterable[int] | None = None
3292
+ ) -> Mapping[str, Mapping[int, Mapping[int, Sequence[_IterationData]]]]:
3128
3293
  """
3129
3294
  Get a per-loop, per-element map of the runs in each iteration.
3130
3295
  """
3131
3296
  # TODO: test this works across multiple jobscripts
3132
- self.app.persistence_logger.debug("Workflow.get_loop_map")
3297
+ self._app.persistence_logger.debug("Workflow.get_loop_map")
3133
3298
  if id_lst is None:
3134
3299
  id_lst = self.get_all_submission_run_IDs()
3135
- loop_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
3136
- runs = self.get_EARs_from_IDs(id_lst)
3137
- for i in runs:
3138
- for loop_name, iter_idx in i.element_iteration.loop_idx.items():
3139
- act_idx = i.element_action.action_idx
3140
- loop_map[loop_name][i.element.id_][iter_idx].append((i.id_, act_idx))
3300
+ loop_map: dict[str, dict[int, dict[int, list[_IterationData]]]] = defaultdict(
3301
+ lambda: defaultdict(lambda: defaultdict(list))
3302
+ )
3303
+ for EAR in self.get_EARs_from_IDs(id_lst):
3304
+ for loop_name, iter_idx in EAR.element_iteration.loop_idx.items():
3305
+ act_idx = EAR.element_action.action_idx
3306
+ loop_map[loop_name][EAR.element.id_][iter_idx].append(
3307
+ _IterationData(EAR.id_, act_idx)
3308
+ )
3141
3309
  return loop_map
3142
3310
 
3143
3311
  def get_iteration_final_run_IDs(
3144
3312
  self,
3145
- loop_map: Optional[Dict] = None,
3146
- id_lst: Optional[List[int]] = None,
3147
- ) -> Dict[str, List[int]]:
3313
+ id_lst: Iterable[int] | None = None,
3314
+ ) -> Mapping[str, Sequence[int]]:
3148
3315
  """Retrieve the run IDs of those runs that correspond to the final action within
3149
3316
  a named loop iteration.
3150
3317
 
3151
3318
  These runs represent the final action of a given element-iteration; this is used to
3152
3319
  identify which commands file to append a loop-termination check to.
3153
-
3154
3320
  """
3155
- self.app.persistence_logger.debug("Workflow.get_iteration_final_run_IDs")
3321
+ self._app.persistence_logger.debug("Workflow.get_iteration_final_run_IDs")
3156
3322
 
3157
- loop_map = loop_map or self.get_loop_map(id_lst)
3323
+ loop_map = self.get_loop_map(id_lst)
3158
3324
 
3159
3325
  # find final EARs for each loop:
3160
- final_runs = defaultdict(list)
3326
+ final_runs: dict[str, list[int]] = defaultdict(list)
3161
3327
  for loop_name, dat in loop_map.items():
3162
- for _, elem_dat in dat.items():
3163
- for _, iter_dat in elem_dat.items():
3164
- # sort by largest action index first, so we get save only the final EAR
3165
- final = sorted(iter_dat, key=lambda x: x[1], reverse=True)[0]
3166
- final_runs[loop_name].append(final[0])
3167
- return dict(final_runs)
3328
+ for elem_dat in dat.values():
3329
+ for iter_dat in elem_dat.values():
3330
+ final_runs[loop_name].append(max(iter_dat, key=lambda x: x.idx).id_)
3331
+ return final_runs
3168
3332
 
3169
3333
  def rechunk_runs(
3170
3334
  self,
3171
- chunk_size: Optional[int] = None,
3172
- backup: Optional[bool] = True,
3173
- status: Optional[bool] = True,
3335
+ chunk_size: int | None = None,
3336
+ backup: bool = True,
3337
+ status: bool = True,
3174
3338
  ):
3175
3339
  """
3176
3340
  Reorganise the stored data chunks for EARs to be more efficient.
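`get_loop_map` now records each run as a small named tuple instead of a bare `(id_, act_idx)` pair, and `get_iteration_final_run_IDs` selects the final action with `max(..., key=...)` rather than a full reverse sort. A sketch assuming `_IterationData` has the two fields the hunk uses, `id_` and `idx`:

```python
from typing import NamedTuple


class IterationData(NamedTuple):
    # assumed shape of the diff's _IterationData
    id_: int  # run (EAR) ID
    idx: int  # action index within the iteration


iter_dat = [IterationData(40, 0), IterationData(41, 2), IterationData(42, 1)]

# run with the largest action index, found without sorting
final = max(iter_dat, key=lambda item: item.idx)
print(final.id_)  # 41
```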
@@ -3179,9 +3343,9 @@ class Workflow:
3179
3343
 
3180
3344
  def rechunk_parameter_base(
3181
3345
  self,
3182
- chunk_size: Optional[int] = None,
3183
- backup: Optional[bool] = True,
3184
- status: Optional[bool] = True,
3346
+ chunk_size: int | None = None,
3347
+ backup: bool = True,
3348
+ status: bool = True,
3185
3349
  ):
3186
3350
  """
3187
3351
  Reorganise the stored data chunks for parameters to be more efficient.
@@ -3192,9 +3356,9 @@ class Workflow:
3192
3356
 
3193
3357
  def rechunk(
3194
3358
  self,
3195
- chunk_size: Optional[int] = None,
3196
- backup: Optional[bool] = True,
3197
- status: Optional[bool] = True,
3359
+ chunk_size: int | None = None,
3360
+ backup: bool = True,
3361
+ status: bool = True,
3198
3362
  ):
3199
3363
  """
3200
3364
  Rechunk metadata/runs and parameters/base arrays, making them more efficient.