fractal-server 2.14.0a10__py3-none-any.whl → 2.14.0a12__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in those registries.
Files changed (30)
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/app/routes/api/v2/submit.py +1 -1
  3. fractal_server/app/runner/components.py +0 -3
  4. fractal_server/app/runner/exceptions.py +4 -0
  5. fractal_server/app/runner/executors/base_runner.py +38 -17
  6. fractal_server/app/runner/executors/local/{_local_config.py → get_local_config.py} +0 -7
  7. fractal_server/app/runner/executors/local/runner.py +109 -59
  8. fractal_server/app/runner/executors/slurm_common/_check_jobs_status.py +4 -0
  9. fractal_server/app/runner/executors/slurm_ssh/executor.py +7 -5
  10. fractal_server/app/runner/executors/slurm_ssh/runner.py +6 -10
  11. fractal_server/app/runner/executors/slurm_sudo/runner.py +196 -99
  12. fractal_server/app/runner/task_files.py +8 -0
  13. fractal_server/app/runner/v2/__init__.py +0 -366
  14. fractal_server/app/runner/v2/_local.py +2 -2
  15. fractal_server/app/runner/v2/_slurm_ssh.py +2 -2
  16. fractal_server/app/runner/v2/_slurm_sudo.py +2 -2
  17. fractal_server/app/runner/v2/db_tools.py +87 -0
  18. fractal_server/app/runner/v2/runner.py +77 -81
  19. fractal_server/app/runner/v2/runner_functions.py +274 -436
  20. fractal_server/app/runner/v2/runner_functions_low_level.py +37 -39
  21. fractal_server/app/runner/v2/submit_workflow.py +366 -0
  22. fractal_server/app/runner/v2/task_interface.py +31 -0
  23. {fractal_server-2.14.0a10.dist-info → fractal_server-2.14.0a12.dist-info}/METADATA +1 -1
  24. {fractal_server-2.14.0a10.dist-info → fractal_server-2.14.0a12.dist-info}/RECORD +27 -28
  25. fractal_server/app/runner/executors/local/_submit_setup.py +0 -46
  26. fractal_server/app/runner/executors/slurm_common/_submit_setup.py +0 -84
  27. fractal_server/app/runner/v2/_db_tools.py +0 -48
  28. {fractal_server-2.14.0a10.dist-info → fractal_server-2.14.0a12.dist-info}/LICENSE +0 -0
  29. {fractal_server-2.14.0a10.dist-info → fractal_server-2.14.0a12.dist-info}/WHEEL +0 -0
  30. {fractal_server-2.14.0a10.dist-info → fractal_server-2.14.0a12.dist-info}/entry_points.txt +0 -0
fractal_server/app/runner/v2/runner_functions.py

@@ -1,16 +1,17 @@
  import functools
- import logging
  from pathlib import Path
  from typing import Any
+ from typing import Callable
  from typing import Literal
  from typing import Optional
 
- from pydantic import ValidationError
- from sqlmodel import update
+ from pydantic import BaseModel
+ from pydantic import ConfigDict
 
  from ..exceptions import JobExecutionError
+ from ..exceptions import TaskOutputValidationError
+ from .db_tools import update_status_of_history_unit
  from .deduplicate_list import deduplicate_list
- from .merge_outputs import merge_outputs
  from .runner_functions_low_level import run_single_task
  from .task_interface import InitTaskOutput
  from .task_interface import TaskOutput
@@ -18,64 +19,90 @@ from fractal_server.app.db import get_sync_db
  from fractal_server.app.models.v2 import HistoryUnit
  from fractal_server.app.models.v2 import TaskV2
  from fractal_server.app.models.v2 import WorkflowTaskV2
- from fractal_server.app.runner.components import _COMPONENT_KEY_
  from fractal_server.app.runner.components import _index_to_component
  from fractal_server.app.runner.executors.base_runner import BaseRunner
- from fractal_server.app.runner.v2._db_tools import bulk_upsert_image_cache_fast
+ from fractal_server.app.runner.task_files import TaskFiles
+ from fractal_server.app.runner.v2.db_tools import bulk_upsert_image_cache_fast
+ from fractal_server.app.runner.v2.task_interface import (
+     _cast_and_validate_InitTaskOutput,
+ )
+ from fractal_server.app.runner.v2.task_interface import (
+     _cast_and_validate_TaskOutput,
+ )
  from fractal_server.app.schemas.v2 import HistoryUnitStatus
-
+ from fractal_server.logger import set_logger
 
 
  __all__ = [
      "run_v2_task_parallel",
      "run_v2_task_non_parallel",
      "run_v2_task_compound",
-     "run_v2_task_converter_non_parallel",
-     "run_v2_task_converter_compound",
  ]
 
- # FIXME: Review whether we need 5 functions or 3 are enough
 
- MAX_PARALLELIZATION_LIST_SIZE = 20_000
+ logger = set_logger(__name__)
 
 
- def _cast_and_validate_TaskOutput(
-     task_output: dict[str, Any]
- ) -> Optional[TaskOutput]:
-     try:
-         validated_task_output = TaskOutput(**task_output)
-         return validated_task_output
-     except ValidationError as e:
-         raise JobExecutionError(
-             "Validation of task output failed.\n"
-             f"Original error: {str(e)}\n"
-             f"Original data: {task_output}."
-         )
+ class SubmissionOutcome(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+     task_output: TaskOutput | None = None
+     exception: BaseException | None = None
 
 
- def _cast_and_validate_InitTaskOutput(
-     init_task_output: dict[str, Any],
- ) -> Optional[InitTaskOutput]:
-     try:
-         validated_init_task_output = InitTaskOutput(**init_task_output)
-         return validated_init_task_output
-     except ValidationError as e:
-         raise JobExecutionError(
-             "Validation of init-task output failed.\n"
-             f"Original error: {str(e)}\n"
-             f"Original data: {init_task_output}."
-         )
+ class InitSubmissionOutcome(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+     task_output: InitTaskOutput | None = None
+     exception: BaseException | None = None
+
+
+ MAX_PARALLELIZATION_LIST_SIZE = 20_000
 
 
- def no_op_submit_setup_call(
+ def _process_task_output(
      *,
-     wftask: WorkflowTaskV2,
-     root_dir_local: Path,
-     which_type: Literal["non_parallel", "parallel"],
- ) -> dict[str, Any]:
-     """
-     Default (no-operation) interface of submit_setup_call in V2.
-     """
-     return {}
+     result: dict[str, Any] | None = None,
+     exception: BaseException | None = None,
+ ) -> SubmissionOutcome:
+     if exception is not None:
+         task_output = None
+     else:
+         if result is None:
+             task_output = TaskOutput()
+         else:
+             try:
+                 task_output = _cast_and_validate_TaskOutput(result)
+             except TaskOutputValidationError as e:
+                 # FIXME: This should correspond to some status="failed",
+                 # but it does not
+                 task_output = None
+                 exception = e
+     return SubmissionOutcome(
+         task_output=task_output,
+         exception=exception,
+     )
+
+
+ def _process_init_task_output(
+     *,
+     result: dict[str, Any] | None = None,
+     exception: BaseException | None = None,
+ ) -> SubmissionOutcome:
+     if exception is not None:
+         task_output = None
+     else:
+         if result is None:
+             task_output = InitTaskOutput()
+         else:
+             try:
+                 task_output = _cast_and_validate_InitTaskOutput(result)
+             except TaskOutputValidationError as e:
+                 # FIXME: This should correspond to some status="failed",
+                 # but it does not
+                 task_output = None
+                 exception = e
+     return InitSubmissionOutcome(
+         task_output=task_output,
+         exception=exception,
+     )
 
 
  def _check_parallelization_list_size(my_list):
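
The two `*Outcome` models added above replace the old pattern of returning a `TaskOutput` plus a separate exceptions dict: each submission now collapses into a single per-index object. A minimal, self-contained sketch of the pattern, with a stand-in `TaskOutput` and pydantic's own `ValidationError` standing in for `TaskOutputValidationError`; only the shape of `SubmissionOutcome` and `_process_task_output` mirrors the diff:

```python
from typing import Any

from pydantic import BaseModel, ConfigDict, ValidationError


class TaskOutput(BaseModel):  # stand-in for fractal's TaskOutput
    image_list_updates: list[dict] = []


class SubmissionOutcome(BaseModel):
    # arbitrary_types_allowed is what lets a raw BaseException live here
    model_config = ConfigDict(arbitrary_types_allowed=True)
    task_output: TaskOutput | None = None
    exception: BaseException | None = None


def _process_task_output(
    *,
    result: dict[str, Any] | None = None,
    exception: BaseException | None = None,
) -> SubmissionOutcome:
    if exception is not None:
        return SubmissionOutcome(exception=exception)
    if result is None:
        return SubmissionOutcome(task_output=TaskOutput())
    try:
        return SubmissionOutcome(task_output=TaskOutput(**result))
    except ValidationError as e:
        # Validation failure is folded into the same outcome shape
        return SubmissionOutcome(exception=e)


# Success, empty result, and failure all collapse into one shape:
print(_process_task_output(result={"image_list_updates": []}))
print(_process_task_output(result=None))
print(_process_task_output(exception=RuntimeError("boom")))
```
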
@@ -94,44 +121,59 @@ def run_v2_task_non_parallel(
      task: TaskV2,
      wftask: WorkflowTaskV2,
      workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
+     workflow_dir_remote: Path,
      runner: BaseRunner,
-     submit_setup_call: callable = no_op_submit_setup_call,
+     get_runner_config: Callable[
+         [
+             WorkflowTaskV2,
+             Literal["non_parallel", "parallel"],
+             Optional[Path],
+         ],
+         Any,
+     ],
      dataset_id: int,
      history_run_id: int,
- ) -> tuple[TaskOutput, int, dict[int, BaseException]]:
+     task_type: Literal["non_parallel", "converter_non_parallel"],
+ ) -> tuple[dict[int, SubmissionOutcome], int]:
      """
      This runs server-side (see `executor` argument)
      """
 
-     if workflow_dir_remote is None:
-         workflow_dir_remote = workflow_dir_local
-         logging.warning(
-             "In `run_single_task`, workflow_dir_remote=None. Is this right?"
+     if task_type not in ["non_parallel", "converter_non_parallel"]:
+         raise ValueError(
+             f"Invalid {task_type=} for `run_v2_task_non_parallel`."
          )
-         workflow_dir_remote = workflow_dir_local
 
-     executor_options = submit_setup_call(
-         wftask=wftask,
+     # Get TaskFiles object
+     task_files = TaskFiles(
          root_dir_local=workflow_dir_local,
          root_dir_remote=workflow_dir_remote,
-         which_type="non_parallel",
+         task_order=wftask.order,
+         task_name=wftask.task.name,
+         component=_index_to_component(0),
      )
 
+     runner_config = get_runner_config(wftask=wftask, which_type="non_parallel")
+
      function_kwargs = {
-         "zarr_urls": [image["zarr_url"] for image in images],
          "zarr_dir": zarr_dir,
-         _COMPONENT_KEY_: _index_to_component(0),
          **(wftask.args_non_parallel or {}),
      }
+     if task_type == "non_parallel":
+         function_kwargs["zarr_urls"] = [img["zarr_url"] for img in images]
 
      # Database History operations
      with next(get_sync_db()) as db:
+         if task_type == "non_parallel":
+             zarr_urls = function_kwargs["zarr_urls"]
+         elif task_type == "converter_non_parallel":
+             zarr_urls = []
+
          history_unit = HistoryUnit(
              history_run_id=history_run_id,
              status=HistoryUnitStatus.SUBMITTED,
-             logfile=None,  # FIXME
-             zarr_urls=function_kwargs["zarr_urls"],
+             logfile=task_files.log_file_local,
+             zarr_urls=zarr_urls,
          )
          db.add(history_unit)
          db.commit()
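
`get_runner_config` replaces the old `submit_setup_call` hook: instead of loose executor kwargs, each backend returns a config object for the given task phase. A hypothetical callable matching the annotated signature; the real implementations live in the local/SLURM config modules listed in the files above, and `MyRunnerConfig` plus the third parameter's name are invented for illustration:

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional


@dataclass
class MyRunnerConfig:
    # Hypothetical payload; each backend returns its own config type.
    parallel_tasks_per_job: int = 1


def get_runner_config(
    *,
    wftask,  # a WorkflowTaskV2 in fractal-server
    which_type: Literal["non_parallel", "parallel"],
    config_path: Optional[Path] = None,  # parameter name is a guess
) -> MyRunnerConfig:
    # Parallel phases may fan out over many images; init phases run once.
    if which_type == "parallel":
        return MyRunnerConfig(parallel_tasks_per_job=4)
    return MyRunnerConfig()


print(get_runner_config(wftask=None, which_type="parallel"))
```
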
@@ -153,122 +195,28 @@
      result, exception = runner.submit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_non_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
          parameters=function_kwargs,
-         task_type="non_parallel",
-         **executor_options,
+         task_type=task_type,
+         task_files=task_files,
+         history_unit_id=history_unit_id,
+         config=runner_config,
      )
 
+     positional_index = 0
      num_tasks = 1
-     with next(get_sync_db()) as db:
-         if exception is None:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
-             )
-             db.commit()
-             if result is None:
-                 return (TaskOutput(), num_tasks, {})
-             else:
-                 return (_cast_and_validate_TaskOutput(result), num_tasks, {})
-         else:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
-             )
-             db.commit()
-             return (TaskOutput(), num_tasks, {0: exception})
-
 
- def run_v2_task_converter_non_parallel(
-     *,
-     zarr_dir: str,
-     task: TaskV2,
-     wftask: WorkflowTaskV2,
-     workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     executor: BaseRunner,
-     submit_setup_call: callable = no_op_submit_setup_call,
-     dataset_id: int,
-     history_run_id: int,
- ) -> tuple[TaskOutput, int, dict[int, BaseException]]:
-     """
-     This runs server-side (see `executor` argument)
-     """
-
-     if workflow_dir_remote is None:
-         workflow_dir_remote = workflow_dir_local
-         logging.warning(
-             "In `run_single_task`, workflow_dir_remote=None. Is this right?"
+     outcome = {
+         positional_index: _process_task_output(
+             result=result,
+             exception=exception,
          )
-         workflow_dir_remote = workflow_dir_local
-
-     executor_options = submit_setup_call(
-         wftask=wftask,
-         root_dir_local=workflow_dir_local,
-         root_dir_remote=workflow_dir_remote,
-         which_type="non_parallel",
-     )
-
-     function_kwargs = {
-         "zarr_dir": zarr_dir,
-         _COMPONENT_KEY_: _index_to_component(0),
-         **(wftask.args_non_parallel or {}),
      }
-
-     # Database History operations
-     with next(get_sync_db()) as db:
-         history_unit = HistoryUnit(
-             history_run_id=history_run_id,
-             status=HistoryUnitStatus.SUBMITTED,
-             logfile=None,  # FIXME
-             zarr_urls=[],
-         )
-         db.add(history_unit)
-         db.commit()
-         db.refresh(history_unit)
-         history_unit_id = history_unit.id
-
-     result, exception = executor.submit(
-         functools.partial(
-             run_single_task,
-             wftask=wftask,
-             command=task.command_non_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
-         ),
-         task_type="converter_non_parallel",
-         parameters=function_kwargs,
-         **executor_options,
-     )
-
-     num_tasks = 1
-     with next(get_sync_db()) as db:
-         if exception is None:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
-             )
-             db.commit()
-             if result is None:
-                 return (TaskOutput(), num_tasks, {})
-             else:
-                 return (_cast_and_validate_TaskOutput(result), num_tasks, {})
-         else:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
-             )
-             db.commit()
-             return (TaskOutput(), num_tasks, {0: exception})
+     return outcome, num_tasks
 
 
  def run_v2_task_parallel(
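
For reference, the `runner.submit` contract this function now relies on: a single `(result, exception)` pair with at most one side set. A toy stand-in, not the `BaseRunner` implementation (which also handles task files, history units, and backend config):

```python
from typing import Any, Callable


def submit(
    func: Callable[..., Any],
    parameters: dict[str, Any],
) -> tuple[Any, BaseException | None]:
    # Run the task and fold success/failure into one return shape.
    try:
        return func(**parameters), None
    except Exception as e:
        return None, e


result, exception = submit(lambda zarr_dir: {"done": zarr_dir}, {"zarr_dir": "/data"})
print(result, exception)  # {'done': '/data'} None
result, exception = submit(lambda zarr_dir: 1 / 0, {"zarr_dir": "/data"})
print(result, exception)  # None division by zero
```
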
@@ -278,40 +226,59 @@ def run_v2_task_parallel(
      wftask: WorkflowTaskV2,
      runner: BaseRunner,
      workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     submit_setup_call: callable = no_op_submit_setup_call,
+     workflow_dir_remote: Path,
+     get_runner_config: Callable[
+         [
+             WorkflowTaskV2,
+             Literal["non_parallel", "parallel"],
+             Optional[Path],
+         ],
+         Any,
+     ],
      dataset_id: int,
      history_run_id: int,
- ) -> tuple[TaskOutput, int, dict[int, BaseException]]:
+ ) -> tuple[dict[int, SubmissionOutcome], int]:
      if len(images) == 0:
-         # FIXME: Do something with history units/images?
-         return (TaskOutput(), 0, {})
+         return {}, 0
 
      _check_parallelization_list_size(images)
 
-     executor_options = submit_setup_call(
-         wftask=wftask,
+     # Get TaskFiles object
+     task_files = TaskFiles(
          root_dir_local=workflow_dir_local,
          root_dir_remote=workflow_dir_remote,
+         task_order=wftask.order,
+         task_name=wftask.task.name,
+     )
+
+     runner_config = get_runner_config(
+         wftask=wftask,
          which_type="parallel",
      )
 
      list_function_kwargs = [
          {
              "zarr_url": image["zarr_url"],
-             _COMPONENT_KEY_: _index_to_component(ind),
              **(wftask.args_parallel or {}),
          }
-         for ind, image in enumerate(images)
+         for image in images
      ]
+     list_task_files = [
+         TaskFiles(
+             **task_files.model_dump(exclude={"component"}),
+             component=_index_to_component(ind),
+         )
+         for ind in range(len(images))
+     ]
+
      history_units = [
          HistoryUnit(
              history_run_id=history_run_id,
              status=HistoryUnitStatus.SUBMITTED,
-             logfile=None,  # FIXME
+             logfile=list_task_files[ind].log_file_local,
              zarr_urls=[image["zarr_url"]],
          )
-         for image in images
+         for ind, image in enumerate(images)
      ]
 
      with next(get_sync_db()) as db:
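
The per-image fan-out above clones a template `TaskFiles` and swaps in one `component` per image. A minimal sketch of that pydantic clone-and-override pattern, using a stand-in model (the real `TaskFiles` has more fields and computed paths, and the exact string format produced by `_index_to_component` is not shown in this diff):

```python
from pydantic import BaseModel


class TaskFiles(BaseModel):  # stand-in; the real model lives in task_files.py
    root_dir_local: str
    task_name: str
    component: str | None = None


template = TaskFiles(root_dir_local="/tmp/job", task_name="my-task")
list_task_files = [
    # Dump every field except `component`, then set it per item.
    TaskFiles(**template.model_dump(exclude={"component"}), component=str(ind))
    for ind in range(3)
]
print([tf.component for tf in list_task_files])  # ['0', '1', '2']
```

`template.model_copy(update={"component": ...})` would be the other idiomatic spelling; the dump-and-revalidate form used here re-runs validation on every clone.
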
@@ -339,50 +306,43 @@
      results, exceptions = runner.multisubmit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
          list_parameters=list_function_kwargs,
          task_type="parallel",
-         **executor_options,
+         list_task_files=list_task_files,
+         history_unit_ids=history_unit_ids,
+         config=runner_config,
      )
 
-     outputs = []
-     history_unit_ids_done: list[int] = []
-     history_unit_ids_failed: list[int] = []
+     outcome = {}
      for ind in range(len(list_function_kwargs)):
-         if ind in results.keys():
-             result = results[ind]
-             if result is None:
-                 output = TaskOutput()
-             else:
-                 output = _cast_and_validate_TaskOutput(result)
-             outputs.append(output)
-             history_unit_ids_done.append(history_unit_ids[ind])
-         elif ind in exceptions.keys():
-             print(f"Bad: {exceptions[ind]}")
-             history_unit_ids_failed.append(history_unit_ids[ind])
-         else:
-             print("VERY BAD - should have not reached this point")
-
-     with next(get_sync_db()) as db:
-         db.execute(
-             update(HistoryUnit)
-             .where(HistoryUnit.id.in_(history_unit_ids_done))
-             .values(status=HistoryUnitStatus.DONE)
-         )
-         db.execute(
-             update(HistoryUnit)
-             .where(HistoryUnit.id.in_(history_unit_ids_failed))
-             .values(status=HistoryUnitStatus.FAILED)
+         if ind not in results.keys() and ind not in exceptions.keys():
+             # FIXME: Could we avoid this branch?
+             error_msg = (
+                 f"Invalid branch: {ind=} is not in `results.keys()` "
+                 "nor in `exceptions.keys()`."
+             )
+             logger.error(error_msg)
+             raise RuntimeError(error_msg)
+         outcome[ind] = _process_task_output(
+             result=results.get(ind, None),
+             exception=exceptions.get(ind, None),
          )
-         db.commit()
 
      num_tasks = len(images)
-     merged_output = merge_outputs(outputs)
-     return (merged_output, num_tasks, exceptions)
+     return outcome, num_tasks
+
+
+ # FIXME: THIS FOR CONVERTERS:
+ # if task_type in ["converter_non_parallel"]:
+ #     run = db.get(HistoryRun, history_run_id)
+ #     run.status = HistoryUnitStatus.DONE
+ #     db.merge(run)
+ #     db.commit()
 
 
  def run_v2_task_compound(
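
The loop above depends on the `multisubmit` contract: two dicts keyed by positional index, with every index landing in exactly one of them (the `RuntimeError` branch guards that invariant). A toy stand-in for `BaseRunner.multisubmit`, illustration only:

```python
from typing import Any, Callable


def multisubmit(
    func: Callable[..., Any],
    list_parameters: list[dict[str, Any]],
) -> tuple[dict[int, Any], dict[int, BaseException]]:
    # Every positional index ends up in exactly one of the two dicts.
    results: dict[int, Any] = {}
    exceptions: dict[int, BaseException] = {}
    for ind, params in enumerate(list_parameters):
        try:
            results[ind] = func(**params)
        except Exception as e:
            exceptions[ind] = e
    return results, exceptions


results, exceptions = multisubmit(
    lambda zarr_url: zarr_url.upper(),
    [{"zarr_url": "plate.zarr/A/01"}, {"zarr_url": 42}],  # second call fails
)
print(results)     # {0: 'PLATE.ZARR/A/01'}
print(exceptions)  # {1: AttributeError(...)}
```
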
@@ -391,45 +351,58 @@
      zarr_dir: str,
      task: TaskV2,
      wftask: WorkflowTaskV2,
-     executor: BaseRunner,
+     runner: BaseRunner,
      workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     submit_setup_call: callable = no_op_submit_setup_call,
+     workflow_dir_remote: Path,
+     get_runner_config: Callable[
+         [
+             WorkflowTaskV2,
+             Literal["non_parallel", "parallel"],
+             Optional[Path],
+         ],
+         Any,
+     ],
      dataset_id: int,
      history_run_id: int,
- ) -> tuple[TaskOutput, int, dict[int, BaseException]]:
-     # FIXME: Add task_files as a required argument, rather than a kwargs
-     # through executor_options_init
-
-     executor_options_init = submit_setup_call(
-         wftask=wftask,
+     task_type: Literal["compound", "converter_compound"],
+ ) -> tuple[dict[int, SubmissionOutcome], int]:
+     # Get TaskFiles object
+     task_files_init = TaskFiles(
          root_dir_local=workflow_dir_local,
          root_dir_remote=workflow_dir_remote,
+         task_order=wftask.order,
+         task_name=wftask.task.name,
+         component=f"init_{_index_to_component(0)}",
+     )
+
+     runner_config_init = get_runner_config(
+         wftask=wftask,
          which_type="non_parallel",
      )
-     executor_options_compute = submit_setup_call(
+     runner_config_compute = get_runner_config(
          wftask=wftask,
-         root_dir_local=workflow_dir_local,
-         root_dir_remote=workflow_dir_remote,
          which_type="parallel",
      )
 
      # 3/A: non-parallel init task
      function_kwargs = {
-         "zarr_urls": [image["zarr_url"] for image in images],
          "zarr_dir": zarr_dir,
-         _COMPONENT_KEY_: f"init_{_index_to_component(0)}",
          **(wftask.args_non_parallel or {}),
      }
+     if task_type == "compound":
+         function_kwargs["zarr_urls"] = [img["zarr_url"] for img in images]
+         input_image_zarr_urls = function_kwargs["zarr_urls"]
+     elif task_type == "converter_compound":
+         input_image_zarr_urls = []
 
      # Create database History entries
-     input_image_zarr_urls = function_kwargs["zarr_urls"]
      with next(get_sync_db()) as db:
          # Create a single `HistoryUnit` for the whole compound task
          history_unit = HistoryUnit(
              history_run_id=history_run_id,
              status=HistoryUnitStatus.SUBMITTED,
-             logfile=None,  # FIXME
+             # FIXME: What about compute-task logs?
+             logfile=task_files_init.log_file_local,
              zarr_urls=input_image_zarr_urls,
          )
          db.add(history_unit)
@@ -449,37 +422,38 @@
              for zarr_url in input_image_zarr_urls
          ],
      )
-
-     result, exception = executor.submit(
+     result, exception = runner.submit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_non_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
          parameters=function_kwargs,
-         task_type="compound",
-         **executor_options_init,
+         task_type=task_type,
+         task_files=task_files_init,
+         history_unit_id=history_unit_id,
+         config=runner_config_init,
      )
 
+     init_outcome = _process_init_task_output(
+         result=result,
+         exception=exception,
+     )
      num_tasks = 1
-     if exception is None:
-         if result is None:
-             init_task_output = InitTaskOutput()
-         else:
-             init_task_output = _cast_and_validate_InitTaskOutput(result)
-     else:
-         with next(get_sync_db()) as db:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
-             )
-             db.commit()
-             return (TaskOutput(), num_tasks, {0: exception})
+     if init_outcome.exception is not None:
+         positional_index = 0
+         return (
+             {
+                 positional_index: SubmissionOutcome(
+                     exception=init_outcome.exception
+                 )
+             },
+             num_tasks,
+         )
 
-     parallelization_list = init_task_output.parallelization_list
+     parallelization_list = init_outcome.task_output.parallelization_list
      parallelization_list = deduplicate_list(parallelization_list)
 
      num_tasks = 1 + len(parallelization_list)
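
`deduplicate_list` (imported at the top of the file) trims repeated parallelization items before the fan-out. An order-preserving sketch of what such a helper can look like for pydantic models, which compare field by field; the packaged implementation may differ:

```python
from pydantic import BaseModel


class InitArgsItem(BaseModel):  # stand-in parallelization item
    zarr_url: str
    init_args: dict = {}


def deduplicate_list(this_list: list[BaseModel]) -> list[BaseModel]:
    # Pydantic models compare by field values, so `in` finds duplicates.
    new_list: list[BaseModel] = []
    for item in this_list:
        if item not in new_list:
            new_list.append(item)
    return new_list


items = [InitArgsItem(zarr_url="a"), InitArgsItem(zarr_url="a"), InitArgsItem(zarr_url="b")]
print(len(deduplicate_list(items)))  # 2
```
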
@@ -489,220 +463,84 @@
 
      if len(parallelization_list) == 0:
          with next(get_sync_db()) as db:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
+             update_status_of_history_unit(
+                 history_unit_id=history_unit_id,
+                 status=HistoryUnitStatus.DONE,
+                 db_sync=db,
+             )
+         positional_index = 0
+         init_outcome = {
+             positional_index: _process_task_output(
+                 result=None,
+                 exception=None,
              )
-             db.commit()
-         return (TaskOutput(), 0, {})
-
-     list_function_kwargs = [
-         {
-             "zarr_url": parallelization_item.zarr_url,
-             "init_args": parallelization_item.init_args,
-             _COMPONENT_KEY_: f"compute_{_index_to_component(ind)}",
-             **(wftask.args_parallel or {}),
          }
-         for ind, parallelization_item in enumerate(parallelization_list)
-     ]
+         return init_outcome, num_tasks
 
-     results, exceptions = executor.multisubmit(
-         functools.partial(
-             run_single_task,
-             wftask=wftask,
-             command=task.command_parallel,
+     list_task_files = [
+         TaskFiles(
              root_dir_local=workflow_dir_local,
              root_dir_remote=workflow_dir_remote,
-         ),
-         list_parameters=list_function_kwargs,
-         task_type="compound",
-         **executor_options_compute,
-     )
-
-     outputs = []
-     failure = False
-     for ind in range(len(list_function_kwargs)):
-         if ind in results.keys():
-             result = results[ind]
-             if result is None:
-                 output = TaskOutput()
-             else:
-                 output = _cast_and_validate_TaskOutput(result)
-             outputs.append(output)
-
-         elif ind in exceptions.keys():
-             print(f"Bad: {exceptions[ind]}")
-             failure = True
-         else:
-             print("VERY BAD - should have not reached this point")
-
-     with next(get_sync_db()) as db:
-         if failure:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
-             )
-         else:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
-             )
-         db.commit()
-
-     merged_output = merge_outputs(outputs)
-     return (merged_output, num_tasks, exceptions)
-
-
- def run_v2_task_converter_compound(
-     *,
-     zarr_dir: str,
-     task: TaskV2,
-     wftask: WorkflowTaskV2,
-     executor: BaseRunner,
-     workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     submit_setup_call: callable = no_op_submit_setup_call,
-     dataset_id: int,
-     history_run_id: int,
- ) -> tuple[TaskOutput, int, dict[int, BaseException]]:
-     executor_options_init = submit_setup_call(
-         wftask=wftask,
-         root_dir_local=workflow_dir_local,
-         root_dir_remote=workflow_dir_remote,
-         which_type="non_parallel",
-     )
-     executor_options_compute = submit_setup_call(
-         wftask=wftask,
-         root_dir_local=workflow_dir_local,
-         root_dir_remote=workflow_dir_remote,
-         which_type="parallel",
-     )
-
-     # 3/A: non-parallel init task
-     function_kwargs = {
-         "zarr_dir": zarr_dir,
-         _COMPONENT_KEY_: f"init_{_index_to_component(0)}",
-         **(wftask.args_non_parallel or {}),
-     }
-
-     # Create database History entries
-     with next(get_sync_db()) as db:
-         # Create a single `HistoryUnit` for the whole compound task
-         history_unit = HistoryUnit(
-             history_run_id=history_run_id,
-             status=HistoryUnitStatus.SUBMITTED,
-             logfile=None,  # FIXME
-             zarr_urls=[],
+             task_order=wftask.order,
+             task_name=wftask.task.name,
+             component=f"compute_{_index_to_component(ind)}",
          )
-         db.add(history_unit)
-         db.commit()
-         db.refresh(history_unit)
-         history_unit_id = history_unit.id
-
-     result, exception = executor.submit(
-         functools.partial(
-             run_single_task,
-             wftask=wftask,
-             command=task.command_non_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
-         ),
-         parameters=function_kwargs,
-         task_type="converter_compound",
-         **executor_options_init,
-     )
-
-     num_tasks = 1
-     if exception is None:
-         if result is None:
-             init_task_output = InitTaskOutput()
-         else:
-             init_task_output = _cast_and_validate_InitTaskOutput(result)
-     else:
-         with next(get_sync_db()) as db:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
-             )
-             db.commit()
-             return (TaskOutput(), num_tasks, {0: exception})
-
-     parallelization_list = init_task_output.parallelization_list
-     parallelization_list = deduplicate_list(parallelization_list)
-
-     num_tasks = 1 + len(parallelization_list)
-
-     # 3/B: parallel part of a compound task
-     _check_parallelization_list_size(parallelization_list)
-
-     if len(parallelization_list) == 0:
-         with next(get_sync_db()) as db:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
-             )
-             db.commit()
-         return (TaskOutput(), 0, {})
-
+         for ind in range(len(parallelization_list))
+     ]
      list_function_kwargs = [
          {
              "zarr_url": parallelization_item.zarr_url,
              "init_args": parallelization_item.init_args,
-             _COMPONENT_KEY_: f"compute_{_index_to_component(ind)}",
              **(wftask.args_parallel or {}),
          }
-         for ind, parallelization_item in enumerate(parallelization_list)
+         for parallelization_item in parallelization_list
      ]
 
-     results, exceptions = executor.multisubmit(
+     results, exceptions = runner.multisubmit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
          list_parameters=list_function_kwargs,
-         task_type="converter_compound",
-         **executor_options_compute,
+         task_type=task_type,
+         list_task_files=list_task_files,
+         history_unit_ids=[history_unit_id],
+         config=runner_config_compute,
      )
 
-     outputs = []
+     init_outcome = {}
      failure = False
      for ind in range(len(list_function_kwargs)):
-         if ind in results.keys():
-             result = results[ind]
-             if result is None:
-                 output = TaskOutput()
-             else:
-                 output = _cast_and_validate_TaskOutput(result)
-             outputs.append(output)
-
-         elif ind in exceptions.keys():
-             print(f"Bad: {exceptions[ind]}")
-             failure = True
-         else:
-             print("VERY BAD - should have not reached this point")
+         if ind not in results.keys() and ind not in exceptions.keys():
+             # FIXME: Could we avoid this branch?
+             error_msg = (
+                 f"Invalid branch: {ind=} is not in `results.keys()` "
+                 "nor in `exceptions.keys()`."
+             )
+             logger.error(error_msg)
+             raise RuntimeError(error_msg)
+         init_outcome[ind] = _process_task_output(
+             result=results.get(ind, None),
+             exception=exceptions.get(ind, None),
+         )
 
+     # FIXME: In this case, we are performing db updates from here, rather
+     # than at lower level.
      with next(get_sync_db()) as db:
          if failure:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
+             update_status_of_history_unit(
+                 history_unit_id=history_unit_id,
+                 status=HistoryUnitStatus.FAILED,
+                 db_sync=db,
              )
          else:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
+             update_status_of_history_unit(
+                 history_unit_id=history_unit_id,
+                 status=HistoryUnitStatus.DONE,
+                 db_sync=db,
              )
-         db.commit()
 
-     merged_output = merge_outputs(outputs)
-     return (merged_output, num_tasks, exceptions)
+     return init_outcome, num_tasks
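
`update_status_of_history_unit` comes from the new `db_tools.py` module (file 17 in the list above), whose body is not part of this hunk. A plausible minimal sketch against a stand-in table, assuming a sync SQLModel session; the packaged helper may differ:

```python
from sqlmodel import Field, Session, SQLModel, create_engine, update


class HistoryUnit(SQLModel, table=True):  # stand-in for the real model
    id: int | None = Field(default=None, primary_key=True)
    status: str = "submitted"


def update_status_of_history_unit(
    *, history_unit_id: int, status: str, db_sync: Session
) -> None:
    # Single bulk UPDATE plus commit, mirroring the call sites above.
    db_sync.execute(
        update(HistoryUnit)
        .where(HistoryUnit.id == history_unit_id)
        .values(status=status)
    )
    db_sync.commit()


engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
    db.add(HistoryUnit())
    db.commit()
    update_status_of_history_unit(history_unit_id=1, status="done", db_sync=db)
    print(db.get(HistoryUnit, 1).status)  # done
```
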