fractal-server 2.13.1__py3-none-any.whl → 2.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/__main__.py +3 -1
  3. fractal_server/app/models/linkusergroup.py +6 -2
  4. fractal_server/app/models/v2/__init__.py +7 -1
  5. fractal_server/app/models/v2/dataset.py +1 -11
  6. fractal_server/app/models/v2/history.py +78 -0
  7. fractal_server/app/models/v2/job.py +10 -3
  8. fractal_server/app/models/v2/task_group.py +2 -2
  9. fractal_server/app/models/v2/workflow.py +1 -1
  10. fractal_server/app/models/v2/workflowtask.py +1 -1
  11. fractal_server/app/routes/admin/v2/accounting.py +18 -28
  12. fractal_server/app/routes/admin/v2/task.py +1 -1
  13. fractal_server/app/routes/admin/v2/task_group.py +0 -17
  14. fractal_server/app/routes/api/__init__.py +1 -1
  15. fractal_server/app/routes/api/v2/__init__.py +8 -2
  16. fractal_server/app/routes/api/v2/_aux_functions.py +66 -0
  17. fractal_server/app/routes/api/v2/_aux_functions_history.py +166 -0
  18. fractal_server/app/routes/api/v2/dataset.py +0 -17
  19. fractal_server/app/routes/api/v2/history.py +544 -0
  20. fractal_server/app/routes/api/v2/images.py +31 -43
  21. fractal_server/app/routes/api/v2/job.py +30 -0
  22. fractal_server/app/routes/api/v2/project.py +1 -53
  23. fractal_server/app/routes/api/v2/{status.py → status_legacy.py} +6 -6
  24. fractal_server/app/routes/api/v2/submit.py +16 -14
  25. fractal_server/app/routes/api/v2/task.py +3 -10
  26. fractal_server/app/routes/api/v2/task_collection_custom.py +4 -9
  27. fractal_server/app/routes/api/v2/task_group.py +0 -17
  28. fractal_server/app/routes/api/v2/verify_image_types.py +61 -0
  29. fractal_server/app/routes/api/v2/workflow.py +28 -69
  30. fractal_server/app/routes/api/v2/workflowtask.py +53 -50
  31. fractal_server/app/routes/auth/group.py +0 -16
  32. fractal_server/app/routes/auth/oauth.py +5 -3
  33. fractal_server/app/routes/pagination.py +47 -0
  34. fractal_server/app/runner/components.py +0 -3
  35. fractal_server/app/runner/compress_folder.py +57 -29
  36. fractal_server/app/runner/exceptions.py +4 -0
  37. fractal_server/app/runner/executors/base_runner.py +157 -0
  38. fractal_server/app/runner/{v2/_local/_local_config.py → executors/local/get_local_config.py} +7 -9
  39. fractal_server/app/runner/executors/local/runner.py +248 -0
  40. fractal_server/app/runner/executors/{slurm → slurm_common}/_batching.py +1 -1
  41. fractal_server/app/runner/executors/{slurm → slurm_common}/_slurm_config.py +9 -7
  42. fractal_server/app/runner/executors/slurm_common/base_slurm_runner.py +868 -0
  43. fractal_server/app/runner/{v2/_slurm_common → executors/slurm_common}/get_slurm_config.py +48 -17
  44. fractal_server/app/runner/executors/{slurm → slurm_common}/remote.py +36 -47
  45. fractal_server/app/runner/executors/slurm_common/slurm_job_task_models.py +134 -0
  46. fractal_server/app/runner/executors/slurm_ssh/runner.py +268 -0
  47. fractal_server/app/runner/executors/slurm_sudo/__init__.py +0 -0
  48. fractal_server/app/runner/executors/{slurm/sudo → slurm_sudo}/_subprocess_run_as_user.py +2 -83
  49. fractal_server/app/runner/executors/slurm_sudo/runner.py +193 -0
  50. fractal_server/app/runner/extract_archive.py +1 -3
  51. fractal_server/app/runner/task_files.py +134 -87
  52. fractal_server/app/runner/v2/__init__.py +0 -399
  53. fractal_server/app/runner/v2/_local.py +88 -0
  54. fractal_server/app/runner/v2/{_slurm_ssh/__init__.py → _slurm_ssh.py} +20 -19
  55. fractal_server/app/runner/v2/{_slurm_sudo/__init__.py → _slurm_sudo.py} +17 -15
  56. fractal_server/app/runner/v2/db_tools.py +119 -0
  57. fractal_server/app/runner/v2/runner.py +206 -95
  58. fractal_server/app/runner/v2/runner_functions.py +488 -187
  59. fractal_server/app/runner/v2/runner_functions_low_level.py +40 -43
  60. fractal_server/app/runner/v2/submit_workflow.py +358 -0
  61. fractal_server/app/runner/v2/task_interface.py +31 -0
  62. fractal_server/app/schemas/_validators.py +13 -24
  63. fractal_server/app/schemas/user.py +10 -7
  64. fractal_server/app/schemas/user_settings.py +9 -21
  65. fractal_server/app/schemas/v2/__init__.py +9 -1
  66. fractal_server/app/schemas/v2/dataset.py +12 -94
  67. fractal_server/app/schemas/v2/dumps.py +26 -9
  68. fractal_server/app/schemas/v2/history.py +80 -0
  69. fractal_server/app/schemas/v2/job.py +15 -8
  70. fractal_server/app/schemas/v2/manifest.py +14 -7
  71. fractal_server/app/schemas/v2/project.py +9 -7
  72. fractal_server/app/schemas/v2/status_legacy.py +35 -0
  73. fractal_server/app/schemas/v2/task.py +72 -77
  74. fractal_server/app/schemas/v2/task_collection.py +14 -32
  75. fractal_server/app/schemas/v2/task_group.py +10 -9
  76. fractal_server/app/schemas/v2/workflow.py +10 -11
  77. fractal_server/app/schemas/v2/workflowtask.py +2 -21
  78. fractal_server/app/security/__init__.py +3 -3
  79. fractal_server/app/security/signup_email.py +2 -2
  80. fractal_server/config.py +41 -46
  81. fractal_server/images/tools.py +23 -0
  82. fractal_server/migrations/versions/47351f8c7ebc_drop_dataset_filters.py +50 -0
  83. fractal_server/migrations/versions/9db60297b8b2_set_ondelete.py +250 -0
  84. fractal_server/migrations/versions/c90a7c76e996_job_id_in_history_run.py +41 -0
  85. fractal_server/migrations/versions/e81103413827_add_job_type_filters.py +36 -0
  86. fractal_server/migrations/versions/f37aceb45062_make_historyunit_logfile_required.py +39 -0
  87. fractal_server/migrations/versions/fbce16ff4e47_new_history_items.py +120 -0
  88. fractal_server/ssh/_fabric.py +28 -14
  89. fractal_server/tasks/v2/local/collect.py +2 -2
  90. fractal_server/tasks/v2/ssh/collect.py +2 -2
  91. fractal_server/tasks/v2/templates/2_pip_install.sh +1 -1
  92. fractal_server/tasks/v2/templates/4_pip_show.sh +1 -1
  93. fractal_server/tasks/v2/utils_background.py +0 -19
  94. fractal_server/tasks/v2/utils_database.py +30 -17
  95. fractal_server/tasks/v2/utils_templates.py +6 -0
  96. {fractal_server-2.13.1.dist-info → fractal_server-2.14.0.dist-info}/METADATA +4 -4
  97. {fractal_server-2.13.1.dist-info → fractal_server-2.14.0.dist-info}/RECORD +106 -96
  98. {fractal_server-2.13.1.dist-info → fractal_server-2.14.0.dist-info}/WHEEL +1 -1
  99. fractal_server/app/runner/executors/slurm/ssh/_executor_wait_thread.py +0 -126
  100. fractal_server/app/runner/executors/slurm/ssh/_slurm_job.py +0 -116
  101. fractal_server/app/runner/executors/slurm/ssh/executor.py +0 -1386
  102. fractal_server/app/runner/executors/slurm/sudo/_check_jobs_status.py +0 -71
  103. fractal_server/app/runner/executors/slurm/sudo/_executor_wait_thread.py +0 -130
  104. fractal_server/app/runner/executors/slurm/sudo/executor.py +0 -1281
  105. fractal_server/app/runner/v2/_local/__init__.py +0 -132
  106. fractal_server/app/runner/v2/_local/_submit_setup.py +0 -52
  107. fractal_server/app/runner/v2/_local/executor.py +0 -100
  108. fractal_server/app/runner/v2/_slurm_ssh/_submit_setup.py +0 -83
  109. fractal_server/app/runner/v2/_slurm_sudo/_submit_setup.py +0 -83
  110. fractal_server/app/runner/v2/handle_failed_job.py +0 -59
  111. fractal_server/app/schemas/v2/status.py +0 -16
  112. /fractal_server/app/{runner/executors/slurm → history}/__init__.py +0 -0
  113. /fractal_server/app/runner/executors/{slurm/ssh → local}/__init__.py +0 -0
  114. /fractal_server/app/runner/executors/{slurm/sudo → slurm_common}/__init__.py +0 -0
  115. /fractal_server/app/runner/executors/{_job_states.py → slurm_common/_job_states.py} +0 -0
  116. /fractal_server/app/runner/executors/{slurm → slurm_common}/utils_executors.py +0 -0
  117. /fractal_server/app/runner/{v2/_slurm_common → executors/slurm_ssh}/__init__.py +0 -0
  118. {fractal_server-2.13.1.dist-info → fractal_server-2.14.0.dist-info}/LICENSE +0 -0
  119. {fractal_server-2.13.1.dist-info → fractal_server-2.14.0.dist-info}/entry_points.txt +0 -0
@@ -1,100 +1,112 @@
 import functools
-import logging
-import traceback
-from concurrent.futures import Executor
 from pathlib import Path
 from typing import Any
 from typing import Callable
 from typing import Literal
 from typing import Optional
 
-from pydantic import ValidationError
+from pydantic import BaseModel
+from pydantic import ConfigDict
 
 from ..exceptions import JobExecutionError
+from ..exceptions import TaskOutputValidationError
+from .db_tools import update_status_of_history_unit
 from .deduplicate_list import deduplicate_list
-from .merge_outputs import merge_outputs
 from .runner_functions_low_level import run_single_task
 from .task_interface import InitTaskOutput
 from .task_interface import TaskOutput
+from fractal_server.app.db import get_sync_db
+from fractal_server.app.models.v2 import HistoryUnit
 from fractal_server.app.models.v2 import TaskV2
 from fractal_server.app.models.v2 import WorkflowTaskV2
-from fractal_server.app.runner.components import _COMPONENT_KEY_
-from fractal_server.app.runner.components import _index_to_component
-
+from fractal_server.app.runner.executors.base_runner import BaseRunner
+from fractal_server.app.runner.task_files import enrich_task_files_multisubmit
+from fractal_server.app.runner.task_files import SUBMIT_PREFIX
+from fractal_server.app.runner.task_files import TaskFiles
+from fractal_server.app.runner.v2.db_tools import (
+    bulk_update_status_of_history_unit,
+)
+from fractal_server.app.runner.v2.db_tools import bulk_upsert_image_cache_fast
+from fractal_server.app.runner.v2.task_interface import (
+    _cast_and_validate_InitTaskOutput,
+)
+from fractal_server.app.runner.v2.task_interface import (
+    _cast_and_validate_TaskOutput,
+)
+from fractal_server.app.schemas.v2 import HistoryUnitStatus
+from fractal_server.logger import set_logger
 
 __all__ = [
-    "run_v2_task_non_parallel",
     "run_v2_task_parallel",
+    "run_v2_task_non_parallel",
     "run_v2_task_compound",
 ]
 
-MAX_PARALLELIZATION_LIST_SIZE = 20_000
 
+logger = set_logger(__name__)
 
-def _cast_and_validate_TaskOutput(
-    task_output: dict[str, Any]
-) -> Optional[TaskOutput]:
-    try:
-        validated_task_output = TaskOutput(**task_output)
-        return validated_task_output
-    except ValidationError as e:
-        raise JobExecutionError(
-            "Validation of task output failed.\n"
-            f"Original error: {str(e)}\n"
-            f"Original data: {task_output}."
-        )
 
+class SubmissionOutcome(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    task_output: TaskOutput | None = None
+    exception: BaseException | None = None
+    invalid_output: bool = False
+
+
+class InitSubmissionOutcome(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    task_output: InitTaskOutput | None = None
+    exception: BaseException | None = None
 
-def _cast_and_validate_InitTaskOutput(
-    init_task_output: dict[str, Any],
-) -> Optional[InitTaskOutput]:
-    try:
-        validated_init_task_output = InitTaskOutput(**init_task_output)
-        return validated_init_task_output
-    except ValidationError as e:
-        raise JobExecutionError(
-            "Validation of init-task output failed.\n"
-            f"Original error: {str(e)}\n"
-            f"Original data: {init_task_output}."
-        )
+
+MAX_PARALLELIZATION_LIST_SIZE = 20_000
 
 
-def no_op_submit_setup_call(
+def _process_task_output(
     *,
-    wftask: WorkflowTaskV2,
-    workflow_dir_local: Path,
-    workflow_dir_remote: Path,
-    which_type: Literal["non_parallel", "parallel"],
-) -> dict:
-    """
-    Default (no-operation) interface of submit_setup_call in V2.
-    """
-    return {}
+    result: dict[str, Any] | None = None,
+    exception: BaseException | None = None,
+) -> SubmissionOutcome:
+    invalid_output = False
+    if exception is not None:
+        task_output = None
+    else:
+        if result is None:
+            task_output = TaskOutput()
+        else:
+            try:
+                task_output = _cast_and_validate_TaskOutput(result)
+            except TaskOutputValidationError as e:
+                task_output = None
+                exception = e
+                invalid_output = True
+    return SubmissionOutcome(
+        task_output=task_output,
+        exception=exception,
+        invalid_output=invalid_output,
+    )
 
 
-# Backend-specific configuration
-def _get_executor_options(
+def _process_init_task_output(
     *,
-    wftask: WorkflowTaskV2,
-    workflow_dir_local: Path,
-    workflow_dir_remote: Path,
-    submit_setup_call: Callable,
-    which_type: Literal["non_parallel", "parallel"],
-) -> dict:
-    try:
-        options = submit_setup_call(
-            wftask=wftask,
-            workflow_dir_local=workflow_dir_local,
-            workflow_dir_remote=workflow_dir_remote,
-            which_type=which_type,
-        )
-    except Exception as e:
-        tb = "".join(traceback.format_tb(e.__traceback__))
-        raise RuntimeError(
-            f"{type(e)} error in {submit_setup_call=}\n"
-            f"Original traceback:\n{tb}"
-        )
-    return options
+    result: dict[str, Any] | None = None,
+    exception: BaseException | None = None,
+) -> SubmissionOutcome:
+    if exception is not None:
+        task_output = None
+    else:
+        if result is None:
+            task_output = InitTaskOutput()
+        else:
+            try:
+                task_output = _cast_and_validate_InitTaskOutput(result)
+            except TaskOutputValidationError as e:
+                task_output = None
+                exception = e
+    return InitSubmissionOutcome(
+        task_output=task_output,
+        exception=exception,
+    )
 
 
 def _check_parallelization_list_size(my_list):
@@ -113,51 +125,123 @@ def run_v2_task_non_parallel(
     task: TaskV2,
     wftask: WorkflowTaskV2,
     workflow_dir_local: Path,
-    workflow_dir_remote: Optional[Path] = None,
-    executor: Executor,
-    submit_setup_call: Callable = no_op_submit_setup_call,
-) -> tuple[TaskOutput, int]:
+    workflow_dir_remote: Path,
+    runner: BaseRunner,
+    get_runner_config: Callable[
+        [
+            WorkflowTaskV2,
+            Literal["non_parallel", "parallel"],
+            Optional[Path],
+            int,
+        ],
+        Any,
+    ],
+    dataset_id: int,
+    history_run_id: int,
+    task_type: Literal["non_parallel", "converter_non_parallel"],
+    user_id: int,
+) -> tuple[dict[int, SubmissionOutcome], int]:
     """
     This runs server-side (see `executor` argument)
     """
 
-    if workflow_dir_remote is None:
-        workflow_dir_remote = workflow_dir_local
-        logging.warning(
-            "In `run_single_task`, workflow_dir_remote=None. Is this right?"
+    if task_type not in ["non_parallel", "converter_non_parallel"]:
+        raise ValueError(
+            f"Invalid {task_type=} for `run_v2_task_non_parallel`."
         )
-        workflow_dir_remote = workflow_dir_local
 
-    executor_options = _get_executor_options(
+    # Get TaskFiles object
+    task_files = TaskFiles(
+        root_dir_local=workflow_dir_local,
+        root_dir_remote=workflow_dir_remote,
+        task_order=wftask.order,
+        task_name=wftask.task.name,
+        component="",
+        prefix=SUBMIT_PREFIX,
+    )
+
+    runner_config = get_runner_config(
         wftask=wftask,
-        workflow_dir_local=workflow_dir_local,
-        workflow_dir_remote=workflow_dir_remote,
-        submit_setup_call=submit_setup_call,
         which_type="non_parallel",
     )
 
-    function_kwargs = dict(
-        zarr_urls=[image["zarr_url"] for image in images],
-        zarr_dir=zarr_dir,
+    function_kwargs = {
+        "zarr_dir": zarr_dir,
         **(wftask.args_non_parallel or {}),
-    )
-    future = executor.submit(
+    }
+    if task_type == "non_parallel":
+        function_kwargs["zarr_urls"] = [img["zarr_url"] for img in images]
+
+    # Database History operations
+    with next(get_sync_db()) as db:
+        if task_type == "non_parallel":
+            zarr_urls = function_kwargs["zarr_urls"]
+        elif task_type == "converter_non_parallel":
+            zarr_urls = []
+
+        history_unit = HistoryUnit(
+            history_run_id=history_run_id,
+            status=HistoryUnitStatus.SUBMITTED,
+            logfile=task_files.log_file_local,
+            zarr_urls=zarr_urls,
+        )
+        db.add(history_unit)
+        db.commit()
+        db.refresh(history_unit)
+        logger.debug(
+            "[run_v2_task_non_parallel] Created `HistoryUnit` with "
+            f"{history_run_id=}."
+        )
+        history_unit_id = history_unit.id
+        bulk_upsert_image_cache_fast(
+            db=db,
+            list_upsert_objects=[
+                dict(
+                    workflowtask_id=wftask.id,
+                    dataset_id=dataset_id,
+                    zarr_url=zarr_url,
+                    latest_history_unit_id=history_unit_id,
+                )
+                for zarr_url in history_unit.zarr_urls
+            ],
+        )
+
+    result, exception = runner.submit(
         functools.partial(
             run_single_task,
-            wftask=wftask,
             command=task.command_non_parallel,
-            workflow_dir_local=workflow_dir_local,
-            workflow_dir_remote=workflow_dir_remote,
+            workflow_task_order=wftask.order,
+            workflow_task_id=wftask.task_id,
+            task_name=wftask.task.name,
         ),
-        function_kwargs,
-        **executor_options,
+        parameters=function_kwargs,
+        task_type=task_type,
+        task_files=task_files,
+        history_unit_id=history_unit_id,
+        config=runner_config,
+        user_id=user_id,
     )
-    output = future.result()
+
+    positional_index = 0
     num_tasks = 1
-    if output is None:
-        return (TaskOutput(), num_tasks)
-    else:
-        return (_cast_and_validate_TaskOutput(output), num_tasks)
+
+    outcome = {
+        positional_index: _process_task_output(
+            result=result,
+            exception=exception,
+        )
+    }
+    # NOTE: Here we don't have to handle the
+    # `outcome[0].exception is not None` branch, since for non_parallel
+    # tasks it was already handled within submit
+    if outcome[0].invalid_output:
+        with next(get_sync_db()) as db:
+            update_status_of_history_unit(
+                history_unit_id=history_unit_id,
+                status=HistoryUnitStatus.FAILED,
+                db_sync=db,
+            )
+    return outcome, num_tasks
 
 
 def run_v2_task_parallel(
@@ -165,59 +249,132 @@ def run_v2_task_parallel(
     images: list[dict[str, Any]],
     task: TaskV2,
     wftask: WorkflowTaskV2,
-    executor: Executor,
+    runner: BaseRunner,
     workflow_dir_local: Path,
-    workflow_dir_remote: Optional[Path] = None,
-    submit_setup_call: Callable = no_op_submit_setup_call,
-) -> tuple[TaskOutput, int]:
-
+    workflow_dir_remote: Path,
+    get_runner_config: Callable[
+        [
+            WorkflowTaskV2,
+            Literal["non_parallel", "parallel"],
+            Optional[Path],
+            int,
+        ],
+        Any,
+    ],
+    dataset_id: int,
+    history_run_id: int,
+    user_id: int,
+) -> tuple[dict[int, SubmissionOutcome], int]:
     if len(images) == 0:
-        return (TaskOutput(), 0)
+        return {}, 0
 
     _check_parallelization_list_size(images)
 
-    executor_options = _get_executor_options(
+    # Get TaskFiles object
+    task_files = TaskFiles(
+        root_dir_local=workflow_dir_local,
+        root_dir_remote=workflow_dir_remote,
+        task_order=wftask.order,
+        task_name=wftask.task.name,
+    )
+
+    runner_config = get_runner_config(
         wftask=wftask,
-        workflow_dir_local=workflow_dir_local,
-        workflow_dir_remote=workflow_dir_remote,
-        submit_setup_call=submit_setup_call,
         which_type="parallel",
+        tot_tasks=len(images),
    )
 
-    list_function_kwargs = []
-    for ind, image in enumerate(images):
-        list_function_kwargs.append(
+    list_function_kwargs = [
+        {
+            "zarr_url": image["zarr_url"],
+            **(wftask.args_parallel or {}),
+        }
+        for image in images
+    ]
+
+    list_task_files = enrich_task_files_multisubmit(
+        base_task_files=task_files,
+        tot_tasks=len(images),
+        batch_size=runner_config.batch_size,
+    )
+
+    history_units = [
+        HistoryUnit(
+            history_run_id=history_run_id,
+            status=HistoryUnitStatus.SUBMITTED,
+            logfile=list_task_files[ind].log_file_local,
+            zarr_urls=[image["zarr_url"]],
+        )
+        for ind, image in enumerate(images)
+    ]
+
+    with next(get_sync_db()) as db:
+        db.add_all(history_units)
+        db.commit()
+        logger.debug(
+            f"[run_v2_task_non_parallel] Created {len(history_units)} "
+            "`HistoryUnit`s."
+        )
+
+        for history_unit in history_units:
+            db.refresh(history_unit)
+        history_unit_ids = [history_unit.id for history_unit in history_units]
+
+        history_image_caches = [
             dict(
-                zarr_url=image["zarr_url"],
-                **(wftask.args_parallel or {}),
-            ),
+                workflowtask_id=wftask.id,
+                dataset_id=dataset_id,
+                zarr_url=history_unit.zarr_urls[0],
+                latest_history_unit_id=history_unit.id,
+            )
+            for history_unit in history_units
+        ]
+
+        bulk_upsert_image_cache_fast(
+            db=db, list_upsert_objects=history_image_caches
         )
-        list_function_kwargs[-1][_COMPONENT_KEY_] = _index_to_component(ind)
 
-    results_iterator = executor.map(
+    results, exceptions = runner.multisubmit(
         functools.partial(
             run_single_task,
-            wftask=wftask,
             command=task.command_parallel,
-            workflow_dir_local=workflow_dir_local,
-            workflow_dir_remote=workflow_dir_remote,
+            workflow_task_order=wftask.order,
+            workflow_task_id=wftask.task_id,
+            task_name=wftask.task.name,
        ),
-        list_function_kwargs,
-        **executor_options,
+        list_parameters=list_function_kwargs,
+        task_type="parallel",
+        list_task_files=list_task_files,
+        history_unit_ids=history_unit_ids,
+        config=runner_config,
+        user_id=user_id,
     )
-    # Explicitly iterate over the whole list, so that all futures are waited
-    outputs = list(results_iterator)
-
-    # Validate all non-None outputs
-    for ind, output in enumerate(outputs):
-        if output is None:
-            outputs[ind] = TaskOutput()
-        else:
-            outputs[ind] = _cast_and_validate_TaskOutput(output)
 
+    outcome = {}
+    for ind in range(len(list_function_kwargs)):
+        if ind not in results.keys() and ind not in exceptions.keys():
+            error_msg = (
+                f"Invalid branch: {ind=} is not in `results.keys()` "
+                "nor in `exceptions.keys()`."
+            )
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+        outcome[ind] = _process_task_output(
+            result=results.get(ind, None),
+            exception=exceptions.get(ind, None),
+        )
+        # NOTE: Here we don't have to handle the
+        # `outcome[ind].exception is not None` branch, since for parallel
+        # tasks it was already handled within multisubmit
+        if outcome[ind].invalid_output:
+            with next(get_sync_db()) as db:
+                update_status_of_history_unit(
+                    history_unit_id=history_unit_ids[ind],
+                    status=HistoryUnitStatus.FAILED,
+                    db_sync=db,
+                )
     num_tasks = len(images)
-    merged_output = merge_outputs(outputs)
-    return (merged_output, num_tasks)
+    return outcome, num_tasks
 
 
 def run_v2_task_compound(
@@ -226,92 +383,236 @@ def run_v2_task_compound(
     zarr_dir: str,
     task: TaskV2,
     wftask: WorkflowTaskV2,
-    executor: Executor,
+    runner: BaseRunner,
     workflow_dir_local: Path,
-    workflow_dir_remote: Optional[Path] = None,
-    submit_setup_call: Callable = no_op_submit_setup_call,
-) -> TaskOutput:
+    workflow_dir_remote: Path,
+    get_runner_config: Callable[
+        [
+            WorkflowTaskV2,
+            Literal["non_parallel", "parallel"],
+            Optional[Path],
+            int,
+        ],
+        Any,
+    ],
+    dataset_id: int,
+    history_run_id: int,
+    task_type: Literal["compound", "converter_compound"],
+    user_id: int,
+) -> tuple[dict[int, SubmissionOutcome], int]:
+    # Get TaskFiles object
+    task_files_init = TaskFiles(
+        root_dir_local=workflow_dir_local,
+        root_dir_remote=workflow_dir_remote,
+        task_order=wftask.order,
+        task_name=wftask.task.name,
+        component="",
+        prefix=SUBMIT_PREFIX,
+    )
 
-    executor_options_init = _get_executor_options(
+    runner_config_init = get_runner_config(
         wftask=wftask,
-        workflow_dir_local=workflow_dir_local,
-        workflow_dir_remote=workflow_dir_remote,
-        submit_setup_call=submit_setup_call,
         which_type="non_parallel",
     )
-    executor_options_compute = _get_executor_options(
-        wftask=wftask,
-        workflow_dir_local=workflow_dir_local,
-        workflow_dir_remote=workflow_dir_remote,
-        submit_setup_call=submit_setup_call,
-        which_type="parallel",
-    )
-
     # 3/A: non-parallel init task
-    function_kwargs = dict(
-        zarr_urls=[image["zarr_url"] for image in images],
-        zarr_dir=zarr_dir,
+    function_kwargs = {
+        "zarr_dir": zarr_dir,
         **(wftask.args_non_parallel or {}),
-    )
-    future = executor.submit(
+    }
+    if task_type == "compound":
+        function_kwargs["zarr_urls"] = [img["zarr_url"] for img in images]
+        input_image_zarr_urls = function_kwargs["zarr_urls"]
+    elif task_type == "converter_compound":
+        input_image_zarr_urls = []
+
+    # Create database History entries
+    with next(get_sync_db()) as db:
+        # Create a single `HistoryUnit` for the whole compound task
+        history_unit = HistoryUnit(
+            history_run_id=history_run_id,
+            status=HistoryUnitStatus.SUBMITTED,
+            logfile=task_files_init.log_file_local,
+            zarr_urls=input_image_zarr_urls,
+        )
+        db.add(history_unit)
+        db.commit()
+        db.refresh(history_unit)
+        init_history_unit_id = history_unit.id
+        logger.debug(
+            "[run_v2_task_compound] Created `HistoryUnit` with "
+            f"{init_history_unit_id=}."
+        )
+        # Create one `HistoryImageCache` for each input image
+        bulk_upsert_image_cache_fast(
+            db=db,
+            list_upsert_objects=[
+                dict(
+                    workflowtask_id=wftask.id,
+                    dataset_id=dataset_id,
+                    zarr_url=zarr_url,
+                    latest_history_unit_id=init_history_unit_id,
+                )
+                for zarr_url in input_image_zarr_urls
+            ],
+        )
+    result, exception = runner.submit(
         functools.partial(
             run_single_task,
-            wftask=wftask,
             command=task.command_non_parallel,
-            workflow_dir_local=workflow_dir_local,
-            workflow_dir_remote=workflow_dir_remote,
+            workflow_task_order=wftask.order,
+            workflow_task_id=wftask.task_id,
+            task_name=wftask.task.name,
         ),
-        function_kwargs,
-        **executor_options_init,
+        parameters=function_kwargs,
+        task_type=task_type,
+        task_files=task_files_init,
+        history_unit_id=init_history_unit_id,
+        config=runner_config_init,
+        user_id=user_id,
     )
-    output = future.result()
-    if output is None:
-        init_task_output = InitTaskOutput()
-    else:
-        init_task_output = _cast_and_validate_InitTaskOutput(output)
-    parallelization_list = init_task_output.parallelization_list
+
+    init_outcome = _process_init_task_output(
+        result=result,
+        exception=exception,
+    )
+    num_tasks = 1
+    if init_outcome.exception is not None:
+        positional_index = 0
+        return (
+            {
+                positional_index: SubmissionOutcome(
+                    exception=init_outcome.exception
+                )
+            },
+            num_tasks,
+        )
+
+    parallelization_list = init_outcome.task_output.parallelization_list
     parallelization_list = deduplicate_list(parallelization_list)
 
-    num_task = 1 + len(parallelization_list)
+    num_tasks = 1 + len(parallelization_list)
 
     # 3/B: parallel part of a compound task
     _check_parallelization_list_size(parallelization_list)
 
     if len(parallelization_list) == 0:
-        return (TaskOutput(), 0)
+        with next(get_sync_db()) as db:
+            update_status_of_history_unit(
+                history_unit_id=init_history_unit_id,
+                status=HistoryUnitStatus.DONE,
+                db_sync=db,
+            )
+        positional_index = 0
+        init_outcome = {
+            positional_index: _process_task_output(
+                result=None,
+                exception=None,
+            )
+        }
+        return init_outcome, num_tasks
+
+    runner_config_compute = get_runner_config(
+        wftask=wftask,
+        which_type="parallel",
+        tot_tasks=len(parallelization_list),
+    )
 
-    list_function_kwargs = []
-    for ind, parallelization_item in enumerate(parallelization_list):
-        list_function_kwargs.append(
-            dict(
-                zarr_url=parallelization_item.zarr_url,
-                init_args=parallelization_item.init_args,
-                **(wftask.args_parallel or {}),
-            ),
+    list_task_files = enrich_task_files_multisubmit(
+        base_task_files=TaskFiles(
+            root_dir_local=workflow_dir_local,
+            root_dir_remote=workflow_dir_remote,
+            task_order=wftask.order,
+            task_name=wftask.task.name,
+        ),
+        tot_tasks=len(parallelization_list),
+        batch_size=runner_config_compute.batch_size,
+    )
+
+    list_function_kwargs = [
+        {
+            "zarr_url": parallelization_item.zarr_url,
+            "init_args": parallelization_item.init_args,
+            **(wftask.args_parallel or {}),
+        }
+        for parallelization_item in parallelization_list
+    ]
+
+    # Create one `HistoryUnit` per parallelization item
+    history_units = [
+        HistoryUnit(
+            history_run_id=history_run_id,
+            status=HistoryUnitStatus.SUBMITTED,
+            logfile=list_task_files[ind].log_file_local,
+            zarr_urls=[parallelization_item.zarr_url],
         )
-        list_function_kwargs[-1][_COMPONENT_KEY_] = _index_to_component(ind)
+        for ind, parallelization_item in enumerate(parallelization_list)
+    ]
+    with next(get_sync_db()) as db:
+        db.add_all(history_units)
+        db.commit()
+        for history_unit in history_units:
+            db.refresh(history_unit)
+        logger.debug(
+            f"[run_v2_task_compound] Created {len(history_units)} "
+            "`HistoryUnit`s."
+        )
+        history_unit_ids = [history_unit.id for history_unit in history_units]
 
-    results_iterator = executor.map(
+    results, exceptions = runner.multisubmit(
         functools.partial(
             run_single_task,
-            wftask=wftask,
             command=task.command_parallel,
-            workflow_dir_local=workflow_dir_local,
-            workflow_dir_remote=workflow_dir_remote,
+            workflow_task_order=wftask.order,
+            workflow_task_id=wftask.task_id,
+            task_name=wftask.task.name,
         ),
-        list_function_kwargs,
-        **executor_options_compute,
+        list_parameters=list_function_kwargs,
+        task_type=task_type,
+        list_task_files=list_task_files,
+        history_unit_ids=history_unit_ids,
+        config=runner_config_compute,
+        user_id=user_id,
     )
-    # Explicitly iterate over the whole list, so that all futures are waited
-    outputs = list(results_iterator)
 
-    # Validate all non-None outputs
-    for ind, output in enumerate(outputs):
-        if output is None:
-            outputs[ind] = TaskOutput()
+    compute_outcomes: dict[int, SubmissionOutcome] = {}
+    failure = False
+    for ind in range(len(list_function_kwargs)):
+        if ind not in results.keys() and ind not in exceptions.keys():
+            # NOTE: see issue 2484
+            error_msg = (
+                f"Invalid branch: {ind=} is not in `results.keys()` "
+                "nor in `exceptions.keys()`."
+            )
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+        compute_outcomes[ind] = _process_task_output(
+            result=results.get(ind, None),
+            exception=exceptions.get(ind, None),
+        )
+        # NOTE: For compound task, `multisubmit` did not handle the
+        # `exception is not None` branch, therefore we have to include it here.
+        if (
+            compute_outcomes[ind].exception is not None
+            or compute_outcomes[ind].invalid_output
+        ):
+            failure = True
+
+    # NOTE: For compound tasks, we update `HistoryUnit.status` from here,
+    # rather than within the submit/multisubmit runner methods. This is
+    # to enforce the fact that either all units succeed or they all fail -
+    # at a difference with the parallel-task case.
+    with next(get_sync_db()) as db:
+        if failure:
+            bulk_update_status_of_history_unit(
+                history_unit_ids=history_unit_ids + [init_history_unit_id],
+                status=HistoryUnitStatus.FAILED,
+                db_sync=db,
+            )
         else:
-            validated_output = _cast_and_validate_TaskOutput(output)
-            outputs[ind] = validated_output
+            bulk_update_status_of_history_unit(
+                history_unit_ids=history_unit_ids + [init_history_unit_id],
+                status=HistoryUnitStatus.DONE,
+                db_sync=db,
+            )
 
-    merged_output = merge_outputs(outputs)
-    return (merged_output, num_task)
+    return compute_outcomes, num_tasks
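
Note on the refactor shown above: `run_v2_task_non_parallel`, `run_v2_task_parallel`, and `run_v2_task_compound` no longer return a merged `TaskOutput`. Each now returns a `dict[int, SubmissionOutcome]` keyed by positional index, plus a task count, and execution goes through a `BaseRunner` whose `submit` call returns a `(result, exception)` pair and whose `multisubmit` call returns `(results, exceptions)` dicts keyed by the same indices. The sketch below only illustrates that call/return contract with a hypothetical in-process runner; it is not the actual `BaseRunner` from `fractal_server/app/runner/executors/base_runner.py`, and it omits the `task_files`, `history_unit_id(s)`, `config`, and `user_id` arguments that the real methods accept.

# Toy illustration of the submit/multisubmit contract inferred from the
# diff above (hypothetical ToyRunner, not fractal-server's BaseRunner).
from typing import Any, Callable


class ToyRunner:
    def submit(
        self, func: Callable, *, parameters: dict[str, Any]
    ) -> tuple[Any, BaseException | None]:
        # Single task: exactly one of (result, exception) is meaningful.
        try:
            return func(**parameters), None
        except Exception as exc:
            return None, exc

    def multisubmit(
        self, func: Callable, *, list_parameters: list[dict[str, Any]]
    ) -> tuple[dict[int, Any], dict[int, BaseException]]:
        # Many tasks: results and exceptions are keyed by positional index,
        # and every index ends up in exactly one of the two dicts.
        results: dict[int, Any] = {}
        exceptions: dict[int, BaseException] = {}
        for ind, params in enumerate(list_parameters):
            try:
                results[ind] = func(**params)
            except Exception as exc:
                exceptions[ind] = exc
        return results, exceptions


if __name__ == "__main__":
    runner = ToyRunner()
    results, exceptions = runner.multisubmit(
        lambda zarr_url: {"processed": zarr_url},
        list_parameters=[{"zarr_url": "/a.zarr"}, {"zarr_url": "/b.zarr"}],
    )
    # Mirrors the outcome loop in run_v2_task_parallel: each index is looked
    # up in `results` or `exceptions`.
    for ind in range(2):
        print(ind, results.get(ind), exceptions.get(ind))

In the new helpers, each `(result, exception)` pair is then wrapped by `_process_task_output` (or `_process_init_task_output`), which is what produces the `SubmissionOutcome` objects returned to the caller and drives the `HistoryUnit` status updates.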