fractal-server 2.13.0__py3-none-any.whl → 2.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/__main__.py +3 -1
  3. fractal_server/app/models/linkusergroup.py +6 -2
  4. fractal_server/app/models/v2/__init__.py +11 -1
  5. fractal_server/app/models/v2/accounting.py +35 -0
  6. fractal_server/app/models/v2/dataset.py +1 -11
  7. fractal_server/app/models/v2/history.py +78 -0
  8. fractal_server/app/models/v2/job.py +10 -3
  9. fractal_server/app/models/v2/task_group.py +2 -2
  10. fractal_server/app/models/v2/workflow.py +1 -1
  11. fractal_server/app/models/v2/workflowtask.py +1 -1
  12. fractal_server/app/routes/admin/v2/__init__.py +4 -0
  13. fractal_server/app/routes/admin/v2/accounting.py +98 -0
  14. fractal_server/app/routes/admin/v2/impersonate.py +35 -0
  15. fractal_server/app/routes/admin/v2/job.py +5 -13
  16. fractal_server/app/routes/admin/v2/task.py +1 -1
  17. fractal_server/app/routes/admin/v2/task_group.py +4 -29
  18. fractal_server/app/routes/api/__init__.py +1 -1
  19. fractal_server/app/routes/api/v2/__init__.py +8 -2
  20. fractal_server/app/routes/api/v2/_aux_functions.py +66 -0
  21. fractal_server/app/routes/api/v2/_aux_functions_history.py +166 -0
  22. fractal_server/app/routes/api/v2/_aux_functions_task_lifecycle.py +3 -3
  23. fractal_server/app/routes/api/v2/dataset.py +0 -17
  24. fractal_server/app/routes/api/v2/history.py +544 -0
  25. fractal_server/app/routes/api/v2/images.py +31 -43
  26. fractal_server/app/routes/api/v2/job.py +30 -0
  27. fractal_server/app/routes/api/v2/project.py +1 -53
  28. fractal_server/app/routes/api/v2/{status.py → status_legacy.py} +6 -6
  29. fractal_server/app/routes/api/v2/submit.py +17 -14
  30. fractal_server/app/routes/api/v2/task.py +3 -10
  31. fractal_server/app/routes/api/v2/task_collection_custom.py +4 -9
  32. fractal_server/app/routes/api/v2/task_group.py +2 -22
  33. fractal_server/app/routes/api/v2/verify_image_types.py +61 -0
  34. fractal_server/app/routes/api/v2/workflow.py +28 -69
  35. fractal_server/app/routes/api/v2/workflowtask.py +53 -50
  36. fractal_server/app/routes/auth/group.py +0 -16
  37. fractal_server/app/routes/auth/oauth.py +5 -3
  38. fractal_server/app/routes/aux/__init__.py +0 -20
  39. fractal_server/app/routes/pagination.py +47 -0
  40. fractal_server/app/runner/components.py +0 -3
  41. fractal_server/app/runner/compress_folder.py +57 -29
  42. fractal_server/app/runner/exceptions.py +4 -0
  43. fractal_server/app/runner/executors/base_runner.py +157 -0
  44. fractal_server/app/runner/{v2/_local/_local_config.py → executors/local/get_local_config.py} +7 -9
  45. fractal_server/app/runner/executors/local/runner.py +248 -0
  46. fractal_server/app/runner/executors/{slurm → slurm_common}/_batching.py +1 -1
  47. fractal_server/app/runner/executors/{slurm → slurm_common}/_slurm_config.py +9 -7
  48. fractal_server/app/runner/executors/slurm_common/base_slurm_runner.py +868 -0
  49. fractal_server/app/runner/{v2/_slurm_common → executors/slurm_common}/get_slurm_config.py +48 -17
  50. fractal_server/app/runner/executors/{slurm → slurm_common}/remote.py +36 -47
  51. fractal_server/app/runner/executors/slurm_common/slurm_job_task_models.py +134 -0
  52. fractal_server/app/runner/executors/slurm_ssh/runner.py +268 -0
  53. fractal_server/app/runner/executors/slurm_sudo/__init__.py +0 -0
  54. fractal_server/app/runner/executors/{slurm/sudo → slurm_sudo}/_subprocess_run_as_user.py +2 -83
  55. fractal_server/app/runner/executors/slurm_sudo/runner.py +193 -0
  56. fractal_server/app/runner/extract_archive.py +1 -3
  57. fractal_server/app/runner/task_files.py +134 -87
  58. fractal_server/app/runner/v2/__init__.py +0 -395
  59. fractal_server/app/runner/v2/_local.py +88 -0
  60. fractal_server/app/runner/v2/{_slurm_ssh/__init__.py → _slurm_ssh.py} +22 -19
  61. fractal_server/app/runner/v2/{_slurm_sudo/__init__.py → _slurm_sudo.py} +19 -15
  62. fractal_server/app/runner/v2/db_tools.py +119 -0
  63. fractal_server/app/runner/v2/runner.py +219 -98
  64. fractal_server/app/runner/v2/runner_functions.py +491 -189
  65. fractal_server/app/runner/v2/runner_functions_low_level.py +40 -43
  66. fractal_server/app/runner/v2/submit_workflow.py +358 -0
  67. fractal_server/app/runner/v2/task_interface.py +31 -0
  68. fractal_server/app/schemas/_validators.py +13 -24
  69. fractal_server/app/schemas/user.py +10 -7
  70. fractal_server/app/schemas/user_settings.py +9 -21
  71. fractal_server/app/schemas/v2/__init__.py +10 -1
  72. fractal_server/app/schemas/v2/accounting.py +18 -0
  73. fractal_server/app/schemas/v2/dataset.py +12 -94
  74. fractal_server/app/schemas/v2/dumps.py +26 -9
  75. fractal_server/app/schemas/v2/history.py +80 -0
  76. fractal_server/app/schemas/v2/job.py +15 -8
  77. fractal_server/app/schemas/v2/manifest.py +14 -7
  78. fractal_server/app/schemas/v2/project.py +9 -7
  79. fractal_server/app/schemas/v2/status_legacy.py +35 -0
  80. fractal_server/app/schemas/v2/task.py +72 -77
  81. fractal_server/app/schemas/v2/task_collection.py +14 -32
  82. fractal_server/app/schemas/v2/task_group.py +10 -9
  83. fractal_server/app/schemas/v2/workflow.py +10 -11
  84. fractal_server/app/schemas/v2/workflowtask.py +2 -21
  85. fractal_server/app/security/__init__.py +3 -3
  86. fractal_server/app/security/signup_email.py +2 -2
  87. fractal_server/config.py +91 -90
  88. fractal_server/images/tools.py +23 -0
  89. fractal_server/migrations/versions/47351f8c7ebc_drop_dataset_filters.py +50 -0
  90. fractal_server/migrations/versions/9db60297b8b2_set_ondelete.py +250 -0
  91. fractal_server/migrations/versions/af1ef1c83c9b_add_accounting_tables.py +57 -0
  92. fractal_server/migrations/versions/c90a7c76e996_job_id_in_history_run.py +41 -0
  93. fractal_server/migrations/versions/e81103413827_add_job_type_filters.py +36 -0
  94. fractal_server/migrations/versions/f37aceb45062_make_historyunit_logfile_required.py +39 -0
  95. fractal_server/migrations/versions/fbce16ff4e47_new_history_items.py +120 -0
  96. fractal_server/ssh/_fabric.py +28 -14
  97. fractal_server/tasks/v2/local/collect.py +2 -2
  98. fractal_server/tasks/v2/ssh/collect.py +2 -2
  99. fractal_server/tasks/v2/templates/2_pip_install.sh +1 -1
  100. fractal_server/tasks/v2/templates/4_pip_show.sh +1 -1
  101. fractal_server/tasks/v2/utils_background.py +1 -20
  102. fractal_server/tasks/v2/utils_database.py +30 -17
  103. fractal_server/tasks/v2/utils_templates.py +6 -0
  104. {fractal_server-2.13.0.dist-info → fractal_server-2.14.0.dist-info}/METADATA +4 -4
  105. {fractal_server-2.13.0.dist-info → fractal_server-2.14.0.dist-info}/RECORD +114 -99
  106. {fractal_server-2.13.0.dist-info → fractal_server-2.14.0.dist-info}/WHEEL +1 -1
  107. fractal_server/app/runner/executors/slurm/ssh/_executor_wait_thread.py +0 -126
  108. fractal_server/app/runner/executors/slurm/ssh/_slurm_job.py +0 -116
  109. fractal_server/app/runner/executors/slurm/ssh/executor.py +0 -1386
  110. fractal_server/app/runner/executors/slurm/sudo/_check_jobs_status.py +0 -71
  111. fractal_server/app/runner/executors/slurm/sudo/_executor_wait_thread.py +0 -130
  112. fractal_server/app/runner/executors/slurm/sudo/executor.py +0 -1281
  113. fractal_server/app/runner/v2/_local/__init__.py +0 -129
  114. fractal_server/app/runner/v2/_local/_submit_setup.py +0 -52
  115. fractal_server/app/runner/v2/_local/executor.py +0 -100
  116. fractal_server/app/runner/v2/_slurm_ssh/_submit_setup.py +0 -83
  117. fractal_server/app/runner/v2/_slurm_sudo/_submit_setup.py +0 -83
  118. fractal_server/app/runner/v2/handle_failed_job.py +0 -59
  119. fractal_server/app/schemas/v2/status.py +0 -16
  120. /fractal_server/app/{runner/executors/slurm → history}/__init__.py +0 -0
  121. /fractal_server/app/runner/executors/{slurm/ssh → local}/__init__.py +0 -0
  122. /fractal_server/app/runner/executors/{slurm/sudo → slurm_common}/__init__.py +0 -0
  123. /fractal_server/app/runner/executors/{_job_states.py → slurm_common/_job_states.py} +0 -0
  124. /fractal_server/app/runner/executors/{slurm → slurm_common}/utils_executors.py +0 -0
  125. /fractal_server/app/runner/{v2/_slurm_common → executors/slurm_ssh}/__init__.py +0 -0
  126. {fractal_server-2.13.0.dist-info → fractal_server-2.14.0.dist-info}/LICENSE +0 -0
  127. {fractal_server-2.13.0.dist-info → fractal_server-2.14.0.dist-info}/entry_points.txt +0 -0
@@ -1,100 +1,112 @@
  import functools
- import logging
- import traceback
- from concurrent.futures import Executor
  from pathlib import Path
  from typing import Any
  from typing import Callable
  from typing import Literal
  from typing import Optional
 
- from pydantic import ValidationError
+ from pydantic import BaseModel
+ from pydantic import ConfigDict
 
  from ..exceptions import JobExecutionError
+ from ..exceptions import TaskOutputValidationError
+ from .db_tools import update_status_of_history_unit
  from .deduplicate_list import deduplicate_list
- from .merge_outputs import merge_outputs
  from .runner_functions_low_level import run_single_task
  from .task_interface import InitTaskOutput
  from .task_interface import TaskOutput
+ from fractal_server.app.db import get_sync_db
+ from fractal_server.app.models.v2 import HistoryUnit
  from fractal_server.app.models.v2 import TaskV2
  from fractal_server.app.models.v2 import WorkflowTaskV2
- from fractal_server.app.runner.components import _COMPONENT_KEY_
- from fractal_server.app.runner.components import _index_to_component
-
+ from fractal_server.app.runner.executors.base_runner import BaseRunner
+ from fractal_server.app.runner.task_files import enrich_task_files_multisubmit
+ from fractal_server.app.runner.task_files import SUBMIT_PREFIX
+ from fractal_server.app.runner.task_files import TaskFiles
+ from fractal_server.app.runner.v2.db_tools import (
+     bulk_update_status_of_history_unit,
+ )
+ from fractal_server.app.runner.v2.db_tools import bulk_upsert_image_cache_fast
+ from fractal_server.app.runner.v2.task_interface import (
+     _cast_and_validate_InitTaskOutput,
+ )
+ from fractal_server.app.runner.v2.task_interface import (
+     _cast_and_validate_TaskOutput,
+ )
+ from fractal_server.app.schemas.v2 import HistoryUnitStatus
+ from fractal_server.logger import set_logger
 
  __all__ = [
-     "run_v2_task_non_parallel",
      "run_v2_task_parallel",
+     "run_v2_task_non_parallel",
      "run_v2_task_compound",
  ]
 
- MAX_PARALLELIZATION_LIST_SIZE = 20_000
 
+ logger = set_logger(__name__)
 
- def _cast_and_validate_TaskOutput(
-     task_output: dict[str, Any]
- ) -> Optional[TaskOutput]:
-     try:
-         validated_task_output = TaskOutput(**task_output)
-         return validated_task_output
-     except ValidationError as e:
-         raise JobExecutionError(
-             "Validation of task output failed.\n"
-             f"Original error: {str(e)}\n"
-             f"Original data: {task_output}."
-         )
 
+ class SubmissionOutcome(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+     task_output: TaskOutput | None = None
+     exception: BaseException | None = None
+     invalid_output: bool = False
+
+
+ class InitSubmissionOutcome(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+     task_output: InitTaskOutput | None = None
+     exception: BaseException | None = None
 
- def _cast_and_validate_InitTaskOutput(
-     init_task_output: dict[str, Any],
- ) -> Optional[InitTaskOutput]:
-     try:
-         validated_init_task_output = InitTaskOutput(**init_task_output)
-         return validated_init_task_output
-     except ValidationError as e:
-         raise JobExecutionError(
-             "Validation of init-task output failed.\n"
-             f"Original error: {str(e)}\n"
-             f"Original data: {init_task_output}."
-         )
+
+ MAX_PARALLELIZATION_LIST_SIZE = 20_000
 
 
- def no_op_submit_setup_call(
+ def _process_task_output(
      *,
-     wftask: WorkflowTaskV2,
-     workflow_dir_local: Path,
-     workflow_dir_remote: Path,
-     which_type: Literal["non_parallel", "parallel"],
- ) -> dict:
-     """
-     Default (no-operation) interface of submit_setup_call in V2.
-     """
-     return {}
+     result: dict[str, Any] | None = None,
+     exception: BaseException | None = None,
+ ) -> SubmissionOutcome:
+     invalid_output = False
+     if exception is not None:
+         task_output = None
+     else:
+         if result is None:
+             task_output = TaskOutput()
+         else:
+             try:
+                 task_output = _cast_and_validate_TaskOutput(result)
+             except TaskOutputValidationError as e:
+                 task_output = None
+                 exception = e
+                 invalid_output = True
+     return SubmissionOutcome(
+         task_output=task_output,
+         exception=exception,
+         invalid_output=invalid_output,
+     )
 
 
- # Backend-specific configuration
- def _get_executor_options(
+ def _process_init_task_output(
      *,
-     wftask: WorkflowTaskV2,
-     workflow_dir_local: Path,
-     workflow_dir_remote: Path,
-     submit_setup_call: Callable,
-     which_type: Literal["non_parallel", "parallel"],
- ) -> dict:
-     try:
-         options = submit_setup_call(
-             wftask=wftask,
-             workflow_dir_local=workflow_dir_local,
-             workflow_dir_remote=workflow_dir_remote,
-             which_type=which_type,
-         )
-     except Exception as e:
-         tb = "".join(traceback.format_tb(e.__traceback__))
-         raise RuntimeError(
-             f"{type(e)} error in {submit_setup_call=}\n"
-             f"Original traceback:\n{tb}"
-         )
-     return options
+     result: dict[str, Any] | None = None,
+     exception: BaseException | None = None,
+ ) -> SubmissionOutcome:
+     if exception is not None:
+         task_output = None
+     else:
+         if result is None:
+             task_output = InitTaskOutput()
+         else:
+             try:
+                 task_output = _cast_and_validate_InitTaskOutput(result)
+             except TaskOutputValidationError as e:
+                 task_output = None
+                 exception = e
+     return InitSubmissionOutcome(
+         task_output=task_output,
+         exception=exception,
+     )
 
 
  def _check_parallelization_list_size(my_list):
@@ -113,51 +125,123 @@ def run_v2_task_non_parallel(
      task: TaskV2,
      wftask: WorkflowTaskV2,
      workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     executor: Executor,
-     logger_name: Optional[str] = None,
-     submit_setup_call: Callable = no_op_submit_setup_call,
- ) -> TaskOutput:
+     workflow_dir_remote: Path,
+     runner: BaseRunner,
+     get_runner_config: Callable[
+         [
+             WorkflowTaskV2,
+             Literal["non_parallel", "parallel"],
+             Optional[Path],
+             int,
+         ],
+         Any,
+     ],
+     dataset_id: int,
+     history_run_id: int,
+     task_type: Literal["non_parallel", "converter_non_parallel"],
+     user_id: int,
+ ) -> tuple[dict[int, SubmissionOutcome], int]:
      """
      This runs server-side (see `executor` argument)
      """
 
-     if workflow_dir_remote is None:
-         workflow_dir_remote = workflow_dir_local
-         logging.warning(
-             "In `run_single_task`, workflow_dir_remote=None. Is this right?"
+     if task_type not in ["non_parallel", "converter_non_parallel"]:
+         raise ValueError(
+             f"Invalid {task_type=} for `run_v2_task_non_parallel`."
          )
-         workflow_dir_remote = workflow_dir_local
 
-     executor_options = _get_executor_options(
+     # Get TaskFiles object
+     task_files = TaskFiles(
+         root_dir_local=workflow_dir_local,
+         root_dir_remote=workflow_dir_remote,
+         task_order=wftask.order,
+         task_name=wftask.task.name,
+         component="",
+         prefix=SUBMIT_PREFIX,
+     )
+
+     runner_config = get_runner_config(
          wftask=wftask,
-         workflow_dir_local=workflow_dir_local,
-         workflow_dir_remote=workflow_dir_remote,
-         submit_setup_call=submit_setup_call,
          which_type="non_parallel",
      )
 
-     function_kwargs = dict(
-         zarr_urls=[image["zarr_url"] for image in images],
-         zarr_dir=zarr_dir,
+     function_kwargs = {
+         "zarr_dir": zarr_dir,
          **(wftask.args_non_parallel or {}),
-     )
-     future = executor.submit(
+     }
+     if task_type == "non_parallel":
+         function_kwargs["zarr_urls"] = [img["zarr_url"] for img in images]
+
+     # Database History operations
+     with next(get_sync_db()) as db:
+         if task_type == "non_parallel":
+             zarr_urls = function_kwargs["zarr_urls"]
+         elif task_type == "converter_non_parallel":
+             zarr_urls = []
+
+         history_unit = HistoryUnit(
+             history_run_id=history_run_id,
+             status=HistoryUnitStatus.SUBMITTED,
+             logfile=task_files.log_file_local,
+             zarr_urls=zarr_urls,
+         )
+         db.add(history_unit)
+         db.commit()
+         db.refresh(history_unit)
+         logger.debug(
+             "[run_v2_task_non_parallel] Created `HistoryUnit` with "
+             f"{history_run_id=}."
+         )
+         history_unit_id = history_unit.id
+         bulk_upsert_image_cache_fast(
+             db=db,
+             list_upsert_objects=[
+                 dict(
+                     workflowtask_id=wftask.id,
+                     dataset_id=dataset_id,
+                     zarr_url=zarr_url,
+                     latest_history_unit_id=history_unit_id,
+                 )
+                 for zarr_url in history_unit.zarr_urls
+             ],
+         )
+
+     result, exception = runner.submit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_non_parallel,
-             workflow_dir_local=workflow_dir_local,
-             workflow_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
-         function_kwargs,
-         **executor_options,
+         parameters=function_kwargs,
+         task_type=task_type,
+         task_files=task_files,
+         history_unit_id=history_unit_id,
+         config=runner_config,
+         user_id=user_id,
      )
-     output = future.result()
-     if output is None:
-         return TaskOutput()
-     else:
-         return _cast_and_validate_TaskOutput(output)
+
+     positional_index = 0
+     num_tasks = 1
+
+     outcome = {
+         positional_index: _process_task_output(
+             result=result,
+             exception=exception,
+         )
+     }
+     # NOTE: Here we don't have to handle the
+     # `outcome[0].exception is not None` branch, since for non_parallel
+     # tasks it was already handled within submit
+     if outcome[0].invalid_output:
+         with next(get_sync_db()) as db:
+             update_status_of_history_unit(
+                 history_unit_id=history_unit_id,
+                 status=HistoryUnitStatus.FAILED,
+                 db_sync=db,
+             )
+     return outcome, num_tasks
 
 
  def run_v2_task_parallel(
@@ -165,59 +249,132 @@ def run_v2_task_parallel(
      images: list[dict[str, Any]],
      task: TaskV2,
      wftask: WorkflowTaskV2,
-     executor: Executor,
+     runner: BaseRunner,
      workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     logger_name: Optional[str] = None,
-     submit_setup_call: Callable = no_op_submit_setup_call,
- ) -> TaskOutput:
-
+     workflow_dir_remote: Path,
+     get_runner_config: Callable[
+         [
+             WorkflowTaskV2,
+             Literal["non_parallel", "parallel"],
+             Optional[Path],
+             int,
+         ],
+         Any,
+     ],
+     dataset_id: int,
+     history_run_id: int,
+     user_id: int,
+ ) -> tuple[dict[int, SubmissionOutcome], int]:
      if len(images) == 0:
-         return TaskOutput()
+         return {}, 0
 
      _check_parallelization_list_size(images)
 
-     executor_options = _get_executor_options(
+     # Get TaskFiles object
+     task_files = TaskFiles(
+         root_dir_local=workflow_dir_local,
+         root_dir_remote=workflow_dir_remote,
+         task_order=wftask.order,
+         task_name=wftask.task.name,
+     )
+
+     runner_config = get_runner_config(
          wftask=wftask,
-         workflow_dir_local=workflow_dir_local,
-         workflow_dir_remote=workflow_dir_remote,
-         submit_setup_call=submit_setup_call,
          which_type="parallel",
+         tot_tasks=len(images),
      )
 
-     list_function_kwargs = []
-     for ind, image in enumerate(images):
-         list_function_kwargs.append(
+     list_function_kwargs = [
+         {
+             "zarr_url": image["zarr_url"],
+             **(wftask.args_parallel or {}),
+         }
+         for image in images
+     ]
+
+     list_task_files = enrich_task_files_multisubmit(
+         base_task_files=task_files,
+         tot_tasks=len(images),
+         batch_size=runner_config.batch_size,
+     )
+
+     history_units = [
+         HistoryUnit(
+             history_run_id=history_run_id,
+             status=HistoryUnitStatus.SUBMITTED,
+             logfile=list_task_files[ind].log_file_local,
+             zarr_urls=[image["zarr_url"]],
+         )
+         for ind, image in enumerate(images)
+     ]
+
+     with next(get_sync_db()) as db:
+         db.add_all(history_units)
+         db.commit()
+         logger.debug(
+             f"[run_v2_task_non_parallel] Created {len(history_units)} "
+             "`HistoryUnit`s."
+         )
+
+         for history_unit in history_units:
+             db.refresh(history_unit)
+         history_unit_ids = [history_unit.id for history_unit in history_units]
+
+         history_image_caches = [
              dict(
-                 zarr_url=image["zarr_url"],
-                 **(wftask.args_parallel or {}),
-             ),
+                 workflowtask_id=wftask.id,
+                 dataset_id=dataset_id,
+                 zarr_url=history_unit.zarr_urls[0],
+                 latest_history_unit_id=history_unit.id,
+             )
+             for history_unit in history_units
+         ]
+
+         bulk_upsert_image_cache_fast(
+             db=db, list_upsert_objects=history_image_caches
          )
-         list_function_kwargs[-1][_COMPONENT_KEY_] = _index_to_component(ind)
 
-     results_iterator = executor.map(
+     results, exceptions = runner.multisubmit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_parallel,
-             workflow_dir_local=workflow_dir_local,
-             workflow_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
-         list_function_kwargs,
-         **executor_options,
+         list_parameters=list_function_kwargs,
+         task_type="parallel",
+         list_task_files=list_task_files,
+         history_unit_ids=history_unit_ids,
+         config=runner_config,
+         user_id=user_id,
      )
-     # Explicitly iterate over the whole list, so that all futures are waited
-     outputs = list(results_iterator)
-
-     # Validate all non-None outputs
-     for ind, output in enumerate(outputs):
-         if output is None:
-             outputs[ind] = TaskOutput()
-         else:
-             outputs[ind] = _cast_and_validate_TaskOutput(output)
 
-     merged_output = merge_outputs(outputs)
-     return merged_output
+     outcome = {}
+     for ind in range(len(list_function_kwargs)):
+         if ind not in results.keys() and ind not in exceptions.keys():
+             error_msg = (
+                 f"Invalid branch: {ind=} is not in `results.keys()` "
+                 "nor in `exceptions.keys()`."
+             )
+             logger.error(error_msg)
+             raise RuntimeError(error_msg)
+         outcome[ind] = _process_task_output(
+             result=results.get(ind, None),
+             exception=exceptions.get(ind, None),
+         )
+         # NOTE: Here we don't have to handle the
+         # `outcome[ind].exception is not None` branch, since for parallel
+         # tasks it was already handled within multisubmit
+         if outcome[ind].invalid_output:
+             with next(get_sync_db()) as db:
+                 update_status_of_history_unit(
+                     history_unit_id=history_unit_ids[ind],
+                     status=HistoryUnitStatus.FAILED,
+                     db_sync=db,
+                 )
+     num_tasks = len(images)
+     return outcome, num_tasks
 
 
  def run_v2_task_compound(
@@ -226,91 +383,236 @@ def run_v2_task_compound(
      zarr_dir: str,
      task: TaskV2,
      wftask: WorkflowTaskV2,
-     executor: Executor,
+     runner: BaseRunner,
      workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     logger_name: Optional[str] = None,
-     submit_setup_call: Callable = no_op_submit_setup_call,
- ) -> TaskOutput:
+     workflow_dir_remote: Path,
+     get_runner_config: Callable[
+         [
+             WorkflowTaskV2,
+             Literal["non_parallel", "parallel"],
+             Optional[Path],
+             int,
+         ],
+         Any,
+     ],
+     dataset_id: int,
+     history_run_id: int,
+     task_type: Literal["compound", "converter_compound"],
+     user_id: int,
+ ) -> tuple[dict[int, SubmissionOutcome], int]:
+     # Get TaskFiles object
+     task_files_init = TaskFiles(
+         root_dir_local=workflow_dir_local,
+         root_dir_remote=workflow_dir_remote,
+         task_order=wftask.order,
+         task_name=wftask.task.name,
+         component="",
+         prefix=SUBMIT_PREFIX,
+     )
 
-     executor_options_init = _get_executor_options(
+     runner_config_init = get_runner_config(
          wftask=wftask,
-         workflow_dir_local=workflow_dir_local,
-         workflow_dir_remote=workflow_dir_remote,
-         submit_setup_call=submit_setup_call,
          which_type="non_parallel",
      )
-     executor_options_compute = _get_executor_options(
-         wftask=wftask,
-         workflow_dir_local=workflow_dir_local,
-         workflow_dir_remote=workflow_dir_remote,
-         submit_setup_call=submit_setup_call,
-         which_type="parallel",
-     )
-
      # 3/A: non-parallel init task
-     function_kwargs = dict(
-         zarr_urls=[image["zarr_url"] for image in images],
-         zarr_dir=zarr_dir,
+     function_kwargs = {
+         "zarr_dir": zarr_dir,
          **(wftask.args_non_parallel or {}),
-     )
-     future = executor.submit(
+     }
+     if task_type == "compound":
+         function_kwargs["zarr_urls"] = [img["zarr_url"] for img in images]
+         input_image_zarr_urls = function_kwargs["zarr_urls"]
+     elif task_type == "converter_compound":
+         input_image_zarr_urls = []
+
+     # Create database History entries
+     with next(get_sync_db()) as db:
+         # Create a single `HistoryUnit` for the whole compound task
+         history_unit = HistoryUnit(
+             history_run_id=history_run_id,
+             status=HistoryUnitStatus.SUBMITTED,
+             logfile=task_files_init.log_file_local,
+             zarr_urls=input_image_zarr_urls,
+         )
+         db.add(history_unit)
+         db.commit()
+         db.refresh(history_unit)
+         init_history_unit_id = history_unit.id
+         logger.debug(
+             "[run_v2_task_compound] Created `HistoryUnit` with "
+             f"{init_history_unit_id=}."
+         )
+         # Create one `HistoryImageCache` for each input image
+         bulk_upsert_image_cache_fast(
+             db=db,
+             list_upsert_objects=[
+                 dict(
+                     workflowtask_id=wftask.id,
+                     dataset_id=dataset_id,
+                     zarr_url=zarr_url,
+                     latest_history_unit_id=init_history_unit_id,
+                 )
+                 for zarr_url in input_image_zarr_urls
+             ],
+         )
+     result, exception = runner.submit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_non_parallel,
-             workflow_dir_local=workflow_dir_local,
-             workflow_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
-         function_kwargs,
-         **executor_options_init,
+         parameters=function_kwargs,
+         task_type=task_type,
+         task_files=task_files_init,
+         history_unit_id=init_history_unit_id,
+         config=runner_config_init,
+         user_id=user_id,
      )
-     output = future.result()
-     if output is None:
-         init_task_output = InitTaskOutput()
-     else:
-         init_task_output = _cast_and_validate_InitTaskOutput(output)
-     parallelization_list = init_task_output.parallelization_list
+
+     init_outcome = _process_init_task_output(
+         result=result,
+         exception=exception,
+     )
+     num_tasks = 1
+     if init_outcome.exception is not None:
+         positional_index = 0
+         return (
+             {
+                 positional_index: SubmissionOutcome(
+                     exception=init_outcome.exception
+                 )
+             },
+             num_tasks,
+         )
+
+     parallelization_list = init_outcome.task_output.parallelization_list
      parallelization_list = deduplicate_list(parallelization_list)
 
+     num_tasks = 1 + len(parallelization_list)
+
      # 3/B: parallel part of a compound task
      _check_parallelization_list_size(parallelization_list)
 
      if len(parallelization_list) == 0:
-         return TaskOutput()
+         with next(get_sync_db()) as db:
+             update_status_of_history_unit(
+                 history_unit_id=init_history_unit_id,
+                 status=HistoryUnitStatus.DONE,
+                 db_sync=db,
+             )
+         positional_index = 0
+         init_outcome = {
+             positional_index: _process_task_output(
+                 result=None,
+                 exception=None,
+             )
+         }
+         return init_outcome, num_tasks
+
+     runner_config_compute = get_runner_config(
+         wftask=wftask,
+         which_type="parallel",
+         tot_tasks=len(parallelization_list),
+     )
 
-     list_function_kwargs = []
-     for ind, parallelization_item in enumerate(parallelization_list):
-         list_function_kwargs.append(
-             dict(
-                 zarr_url=parallelization_item.zarr_url,
-                 init_args=parallelization_item.init_args,
-                 **(wftask.args_parallel or {}),
-             ),
+     list_task_files = enrich_task_files_multisubmit(
+         base_task_files=TaskFiles(
+             root_dir_local=workflow_dir_local,
+             root_dir_remote=workflow_dir_remote,
+             task_order=wftask.order,
+             task_name=wftask.task.name,
+         ),
+         tot_tasks=len(parallelization_list),
+         batch_size=runner_config_compute.batch_size,
+     )
+
+     list_function_kwargs = [
+         {
+             "zarr_url": parallelization_item.zarr_url,
+             "init_args": parallelization_item.init_args,
+             **(wftask.args_parallel or {}),
+         }
+         for parallelization_item in parallelization_list
+     ]
+
+     # Create one `HistoryUnit` per parallelization item
+     history_units = [
+         HistoryUnit(
+             history_run_id=history_run_id,
+             status=HistoryUnitStatus.SUBMITTED,
+             logfile=list_task_files[ind].log_file_local,
+             zarr_urls=[parallelization_item.zarr_url],
          )
-         list_function_kwargs[-1][_COMPONENT_KEY_] = _index_to_component(ind)
+         for ind, parallelization_item in enumerate(parallelization_list)
+     ]
+     with next(get_sync_db()) as db:
+         db.add_all(history_units)
+         db.commit()
+         for history_unit in history_units:
+             db.refresh(history_unit)
+         logger.debug(
+             f"[run_v2_task_compound] Created {len(history_units)} "
+             "`HistoryUnit`s."
+         )
+         history_unit_ids = [history_unit.id for history_unit in history_units]
 
-     results_iterator = executor.map(
+     results, exceptions = runner.multisubmit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_parallel,
-             workflow_dir_local=workflow_dir_local,
-             workflow_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
-         list_function_kwargs,
-         **executor_options_compute,
+         list_parameters=list_function_kwargs,
+         task_type=task_type,
+         list_task_files=list_task_files,
+         history_unit_ids=history_unit_ids,
+         config=runner_config_compute,
+         user_id=user_id,
      )
-     # Explicitly iterate over the whole list, so that all futures are waited
-     outputs = list(results_iterator)
 
-     # Validate all non-None outputs
-     for ind, output in enumerate(outputs):
-         if output is None:
-             outputs[ind] = TaskOutput()
+     compute_outcomes: dict[int, SubmissionOutcome] = {}
+     failure = False
+     for ind in range(len(list_function_kwargs)):
+         if ind not in results.keys() and ind not in exceptions.keys():
+             # NOTE: see issue 2484
+             error_msg = (
+                 f"Invalid branch: {ind=} is not in `results.keys()` "
+                 "nor in `exceptions.keys()`."
+             )
+             logger.error(error_msg)
+             raise RuntimeError(error_msg)
+         compute_outcomes[ind] = _process_task_output(
+             result=results.get(ind, None),
+             exception=exceptions.get(ind, None),
+         )
+         # NOTE: For compound task, `multisubmit` did not handle the
+         # `exception is not None` branch, therefore we have to include it here.
+         if (
+             compute_outcomes[ind].exception is not None
+             or compute_outcomes[ind].invalid_output
+         ):
+             failure = True
+
+     # NOTE: For compound tasks, we update `HistoryUnit.status` from here,
+     # rather than within the submit/multisubmit runner methods. This is
+     # to enforce the fact that either all units succeed or they all fail -
+     # at a difference with the parallel-task case.
+     with next(get_sync_db()) as db:
+         if failure:
+             bulk_update_status_of_history_unit(
+                 history_unit_ids=history_unit_ids + [init_history_unit_id],
+                 status=HistoryUnitStatus.FAILED,
+                 db_sync=db,
+             )
          else:
-             validated_output = _cast_and_validate_TaskOutput(output)
-             outputs[ind] = validated_output
+             bulk_update_status_of_history_unit(
+                 history_unit_ids=history_unit_ids + [init_history_unit_id],
+                 status=HistoryUnitStatus.DONE,
+                 db_sync=db,
+             )
 
-     merged_output = merge_outputs(outputs)
-     return merged_output
+     return compute_outcomes, num_tasks
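For orientation, the sketch below restates, outside of fractal-server and with simplified stand-in classes, the outcome-handling pattern introduced in the diff above: each positional task index gets a `(result, exception)` pair, which is folded into a `SubmissionOutcome`, and an output that fails validation is flagged via `invalid_output` so the corresponding `HistoryUnit` can be marked as failed. Class names mirror the diff; the stub types and the sample data are illustrative only, not the real fractal_server models.

```python
# Illustrative stand-ins only -- not the fractal_server classes.
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class TaskOutput:
    image_list_updates: list[dict[str, Any]] = field(default_factory=list)


@dataclass
class SubmissionOutcome:
    task_output: Optional[TaskOutput] = None
    exception: Optional[BaseException] = None
    invalid_output: bool = False


def process_task_output(
    *,
    result: Optional[dict[str, Any]] = None,
    exception: Optional[BaseException] = None,
) -> SubmissionOutcome:
    # Mirrors `_process_task_output` in the diff: an exception takes priority,
    # a missing result becomes an empty TaskOutput, and a result that fails
    # validation is flagged as `invalid_output`.
    invalid_output = False
    task_output = None
    if exception is None:
        if result is None:
            task_output = TaskOutput()
        else:
            try:
                task_output = TaskOutput(**result)
            except TypeError as e:  # stand-in for TaskOutputValidationError
                exception = e
                invalid_output = True
    return SubmissionOutcome(
        task_output=task_output,
        exception=exception,
        invalid_output=invalid_output,
    )


# One outcome per positional index, as in the parallel/multisubmit branch.
results = {0: {"image_list_updates": []}, 2: None}
exceptions = {1: RuntimeError("task failed")}
outcome = {
    ind: process_task_output(
        result=results.get(ind),
        exception=exceptions.get(ind),
    )
    for ind in range(3)
}
print([ind for ind, o in outcome.items() if o.exception is not None])  # [1]
```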