fractal-server 2.14.0a9__py3-none-any.whl → 2.14.0a11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/app/models/v2/dataset.py +0 -10
  3. fractal_server/app/models/v2/job.py +3 -0
  4. fractal_server/app/routes/api/v2/__init__.py +2 -0
  5. fractal_server/app/routes/api/v2/history.py +14 -9
  6. fractal_server/app/routes/api/v2/images.py +5 -2
  7. fractal_server/app/routes/api/v2/submit.py +16 -14
  8. fractal_server/app/routes/api/v2/verify_image_types.py +64 -0
  9. fractal_server/app/routes/api/v2/workflow.py +11 -7
  10. fractal_server/app/runner/components.py +0 -3
  11. fractal_server/app/runner/exceptions.py +4 -0
  12. fractal_server/app/runner/executors/base_runner.py +16 -17
  13. fractal_server/app/runner/executors/local/{_local_config.py → get_local_config.py} +0 -7
  14. fractal_server/app/runner/executors/local/runner.py +117 -58
  15. fractal_server/app/runner/executors/{slurm_sudo → slurm_common}/_check_jobs_status.py +4 -0
  16. fractal_server/app/runner/executors/slurm_ssh/_check_job_status_ssh.py +67 -0
  17. fractal_server/app/runner/executors/slurm_ssh/executor.py +7 -5
  18. fractal_server/app/runner/executors/slurm_ssh/runner.py +707 -0
  19. fractal_server/app/runner/executors/slurm_sudo/runner.py +265 -114
  20. fractal_server/app/runner/task_files.py +8 -0
  21. fractal_server/app/runner/v2/__init__.py +0 -365
  22. fractal_server/app/runner/v2/_local.py +4 -2
  23. fractal_server/app/runner/v2/_slurm_ssh.py +4 -2
  24. fractal_server/app/runner/v2/_slurm_sudo.py +4 -2
  25. fractal_server/app/runner/v2/db_tools.py +87 -0
  26. fractal_server/app/runner/v2/runner.py +83 -89
  27. fractal_server/app/runner/v2/runner_functions.py +279 -436
  28. fractal_server/app/runner/v2/runner_functions_low_level.py +37 -39
  29. fractal_server/app/runner/v2/submit_workflow.py +366 -0
  30. fractal_server/app/runner/v2/task_interface.py +31 -0
  31. fractal_server/app/schemas/v2/dataset.py +4 -71
  32. fractal_server/app/schemas/v2/dumps.py +6 -5
  33. fractal_server/app/schemas/v2/job.py +6 -3
  34. fractal_server/migrations/versions/47351f8c7ebc_drop_dataset_filters.py +50 -0
  35. fractal_server/migrations/versions/e81103413827_add_job_type_filters.py +36 -0
  36. {fractal_server-2.14.0a9.dist-info → fractal_server-2.14.0a11.dist-info}/METADATA +1 -1
  37. {fractal_server-2.14.0a9.dist-info → fractal_server-2.14.0a11.dist-info}/RECORD +40 -36
  38. fractal_server/app/runner/executors/local/_submit_setup.py +0 -46
  39. fractal_server/app/runner/executors/slurm_common/_submit_setup.py +0 -84
  40. fractal_server/app/runner/v2/_db_tools.py +0 -48
  41. {fractal_server-2.14.0a9.dist-info → fractal_server-2.14.0a11.dist-info}/LICENSE +0 -0
  42. {fractal_server-2.14.0a9.dist-info → fractal_server-2.14.0a11.dist-info}/WHEEL +0 -0
  43. {fractal_server-2.14.0a9.dist-info → fractal_server-2.14.0a11.dist-info}/entry_points.txt +0 -0
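The single file diffed below is, judging from its imports and the `run_v2_task_*` helpers it defines, `fractal_server/app/runner/v2/runner_functions.py` (item 27 above). The main shape change is in the return values: instead of merging per-image results into one `TaskOutput` plus a `dict[int, BaseException]`, each positional index is now mapped to a `SubmissionOutcome` that holds either a validated task output or the exception raised for that unit. A minimal, self-contained sketch of that pattern (using a simplified stand-in for `TaskOutput` and plain pydantic validation rather than the package's `_cast_and_validate_TaskOutput`):

from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, ValidationError


class TaskOutput(BaseModel):
    # Simplified stand-in for fractal-server's TaskOutput model
    image_list_updates: list[dict[str, Any]] = []


class SubmissionOutcome(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    task_output: Optional[TaskOutput] = None
    exception: Optional[BaseException] = None


def process_task_output(
    *,
    result: Optional[dict[str, Any]] = None,
    exception: Optional[BaseException] = None,
) -> SubmissionOutcome:
    # Loosely mirrors the diff's `_process_task_output`: keep the exception of
    # a failed submission, turn an empty result into an empty TaskOutput, and
    # turn a result that fails validation into an exception.
    task_output = None
    if exception is None:
        if result is None:
            task_output = TaskOutput()
        else:
            try:
                task_output = TaskOutput(**result)
            except ValidationError as e:
                exception = e
    return SubmissionOutcome(task_output=task_output, exception=exception)


if __name__ == "__main__":
    # One outcome per positional index, as in the rewritten runner helpers
    outcome = {0: process_task_output(result={"image_list_updates": []})}
    print(outcome[0].exception is None)  # True

In the actual module the validation failure is a `TaskOutputValidationError` rather than a bare pydantic `ValidationError`.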
@@ -1,16 +1,17 @@
  import functools
- import logging
  from pathlib import Path
  from typing import Any
+ from typing import Callable
  from typing import Literal
  from typing import Optional

- from pydantic import ValidationError
- from sqlmodel import update
+ from pydantic import BaseModel
+ from pydantic import ConfigDict

  from ..exceptions import JobExecutionError
+ from ..exceptions import TaskOutputValidationError
+ from .db_tools import update_status_of_history_unit
  from .deduplicate_list import deduplicate_list
- from .merge_outputs import merge_outputs
  from .runner_functions_low_level import run_single_task
  from .task_interface import InitTaskOutput
  from .task_interface import TaskOutput
@@ -18,62 +19,90 @@ from fractal_server.app.db import get_sync_db
  from fractal_server.app.models.v2 import HistoryUnit
  from fractal_server.app.models.v2 import TaskV2
  from fractal_server.app.models.v2 import WorkflowTaskV2
- from fractal_server.app.runner.components import _COMPONENT_KEY_
  from fractal_server.app.runner.components import _index_to_component
  from fractal_server.app.runner.executors.base_runner import BaseRunner
- from fractal_server.app.runner.v2._db_tools import bulk_upsert_image_cache_fast
+ from fractal_server.app.runner.task_files import TaskFiles
+ from fractal_server.app.runner.v2.db_tools import bulk_upsert_image_cache_fast
+ from fractal_server.app.runner.v2.task_interface import (
+     _cast_and_validate_InitTaskOutput,
+ )
+ from fractal_server.app.runner.v2.task_interface import (
+     _cast_and_validate_TaskOutput,
+ )
  from fractal_server.app.schemas.v2 import HistoryUnitStatus
-
+ from fractal_server.logger import set_logger

  __all__ = [
      "run_v2_task_parallel",
      "run_v2_task_non_parallel",
      "run_v2_task_compound",
-     "run_v2_task_converter_non_parallel",
-     "run_v2_task_converter_compound",
  ]

- MAX_PARALLELIZATION_LIST_SIZE = 20_000

+ logger = set_logger(__name__)

- def _cast_and_validate_TaskOutput(
-     task_output: dict[str, Any]
- ) -> Optional[TaskOutput]:
-     try:
-         validated_task_output = TaskOutput(**task_output)
-         return validated_task_output
-     except ValidationError as e:
-         raise JobExecutionError(
-             "Validation of task output failed.\n"
-             f"Original error: {str(e)}\n"
-             f"Original data: {task_output}."
-         )

+ class SubmissionOutcome(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+     task_output: TaskOutput | None = None
+     exception: BaseException | None = None

- def _cast_and_validate_InitTaskOutput(
-     init_task_output: dict[str, Any],
- ) -> Optional[InitTaskOutput]:
-     try:
-         validated_init_task_output = InitTaskOutput(**init_task_output)
-         return validated_init_task_output
-     except ValidationError as e:
-         raise JobExecutionError(
-             "Validation of init-task output failed.\n"
-             f"Original error: {str(e)}\n"
-             f"Original data: {init_task_output}."
-         )
+
+ class InitSubmissionOutcome(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+     task_output: InitTaskOutput | None = None
+     exception: BaseException | None = None


- def no_op_submit_setup_call(
+ MAX_PARALLELIZATION_LIST_SIZE = 20_000
+
+
+ def _process_task_output(
      *,
-     wftask: WorkflowTaskV2,
-     root_dir_local: Path,
-     which_type: Literal["non_parallel", "parallel"],
- ) -> dict[str, Any]:
-     """
-     Default (no-operation) interface of submit_setup_call in V2.
-     """
-     return {}
+     result: dict[str, Any] | None = None,
+     exception: BaseException | None = None,
+ ) -> SubmissionOutcome:
+     if exception is not None:
+         task_output = None
+     else:
+         if result is None:
+             task_output = TaskOutput()
+         else:
+             try:
+                 task_output = _cast_and_validate_TaskOutput(result)
+             except TaskOutputValidationError as e:
+                 # FIXME: This should correspond to some status="failed",
+                 # but it does not
+                 task_output = None
+                 exception = e
+     return SubmissionOutcome(
+         task_output=task_output,
+         exception=exception,
+     )
+
+
+ def _process_init_task_output(
+     *,
+     result: dict[str, Any] | None = None,
+     exception: BaseException | None = None,
+ ) -> SubmissionOutcome:
+     if exception is not None:
+         task_output = None
+     else:
+         if result is None:
+             task_output = InitTaskOutput()
+         else:
+             try:
+                 task_output = _cast_and_validate_InitTaskOutput(result)
+             except TaskOutputValidationError as e:
+                 # FIXME: This should correspond to some status="failed",
+                 # but it does not
+                 task_output = None
+                 exception = e
+     return InitSubmissionOutcome(
+         task_output=task_output,
+         exception=exception,
+     )


  def _check_parallelization_list_size(my_list):
@@ -92,44 +121,59 @@ def run_v2_task_non_parallel(
      task: TaskV2,
      wftask: WorkflowTaskV2,
      workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     executor: BaseRunner,
-     submit_setup_call: callable = no_op_submit_setup_call,
+     workflow_dir_remote: Path,
+     runner: BaseRunner,
+     get_runner_config: Callable[
+         [
+             WorkflowTaskV2,
+             Literal["non_parallel", "parallel"],
+             Optional[Path],
+         ],
+         Any,
+     ],
      dataset_id: int,
      history_run_id: int,
- ) -> tuple[TaskOutput, int, dict[int, BaseException]]:
+     task_type: Literal["non_parallel", "converter_non_parallel"],
+ ) -> tuple[dict[int, SubmissionOutcome], int]:
      """
      This runs server-side (see `executor` argument)
      """

-     if workflow_dir_remote is None:
-         workflow_dir_remote = workflow_dir_local
-         logging.warning(
-             "In `run_single_task`, workflow_dir_remote=None. Is this right?"
+     if task_type not in ["non_parallel", "converter_non_parallel"]:
+         raise ValueError(
+             f"Invalid {task_type=} for `run_v2_task_non_parallel`."
          )
-         workflow_dir_remote = workflow_dir_local

-     executor_options = submit_setup_call(
-         wftask=wftask,
+     # Get TaskFiles object
+     task_files = TaskFiles(
          root_dir_local=workflow_dir_local,
          root_dir_remote=workflow_dir_remote,
-         which_type="non_parallel",
+         task_order=wftask.order,
+         task_name=wftask.task.name,
+         component=_index_to_component(0),
      )

+     runner_config = get_runner_config(wftask=wftask, which_type="non_parallel")
+
      function_kwargs = {
-         "zarr_urls": [image["zarr_url"] for image in images],
          "zarr_dir": zarr_dir,
-         _COMPONENT_KEY_: _index_to_component(0),
          **(wftask.args_non_parallel or {}),
      }
+     if task_type == "non_parallel":
+         function_kwargs["zarr_urls"] = [img["zarr_url"] for img in images]

      # Database History operations
      with next(get_sync_db()) as db:
+         if task_type == "non_parallel":
+             zarr_urls = function_kwargs["zarr_urls"]
+         elif task_type == "converter_non_parallel":
+             zarr_urls = []
+
          history_unit = HistoryUnit(
              history_run_id=history_run_id,
              status=HistoryUnitStatus.SUBMITTED,
-             logfile=None,  # FIXME
-             zarr_urls=function_kwargs["zarr_urls"],
+             logfile=task_files.log_file_local,
+             zarr_urls=zarr_urls,
          )
          db.add(history_unit)
          db.commit()
@@ -148,125 +192,31 @@ def run_v2_task_non_parallel(
              ],
      )

-     result, exception = executor.submit(
+     result, exception = runner.submit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_non_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
          parameters=function_kwargs,
-         task_type="non_parallel",
-         **executor_options,
+         task_type=task_type,
+         task_files=task_files,
+         history_unit_id=history_unit_id,
+         config=runner_config,
      )

+     positional_index = 0
      num_tasks = 1
-     with next(get_sync_db()) as db:
-         if exception is None:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
-             )
-             db.commit()
-             if result is None:
-                 return (TaskOutput(), num_tasks, {})
-             else:
-                 return (_cast_and_validate_TaskOutput(result), num_tasks, {})
-         else:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
-             )
-             db.commit()
-             return (TaskOutput(), num_tasks, {0: exception})
-

- def run_v2_task_converter_non_parallel(
-     *,
-     zarr_dir: str,
-     task: TaskV2,
-     wftask: WorkflowTaskV2,
-     workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     executor: BaseRunner,
-     submit_setup_call: callable = no_op_submit_setup_call,
-     dataset_id: int,
-     history_run_id: int,
- ) -> tuple[TaskOutput, int, dict[int, BaseException]]:
-     """
-     This runs server-side (see `executor` argument)
-     """
-
-     if workflow_dir_remote is None:
-         workflow_dir_remote = workflow_dir_local
-         logging.warning(
-             "In `run_single_task`, workflow_dir_remote=None. Is this right?"
+     outcome = {
+         positional_index: _process_task_output(
+             result=result,
+             exception=exception,
          )
-         workflow_dir_remote = workflow_dir_local
-
-     executor_options = submit_setup_call(
-         wftask=wftask,
-         root_dir_local=workflow_dir_local,
-         root_dir_remote=workflow_dir_remote,
-         which_type="non_parallel",
-     )
-
-     function_kwargs = {
-         "zarr_dir": zarr_dir,
-         _COMPONENT_KEY_: _index_to_component(0),
-         **(wftask.args_non_parallel or {}),
      }
-
-     # Database History operations
-     with next(get_sync_db()) as db:
-         history_unit = HistoryUnit(
-             history_run_id=history_run_id,
-             status=HistoryUnitStatus.SUBMITTED,
-             logfile=None,  # FIXME
-             zarr_urls=[],
-         )
-         db.add(history_unit)
-         db.commit()
-         db.refresh(history_unit)
-         history_unit_id = history_unit.id
-
-     result, exception = executor.submit(
-         functools.partial(
-             run_single_task,
-             wftask=wftask,
-             command=task.command_non_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
-         ),
-         task_type="converter_non_parallel",
-         parameters=function_kwargs,
-         **executor_options,
-     )
-
-     num_tasks = 1
-     with next(get_sync_db()) as db:
-         if exception is None:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
-             )
-             db.commit()
-             if result is None:
-                 return (TaskOutput(), num_tasks, {})
-             else:
-                 return (_cast_and_validate_TaskOutput(result), num_tasks, {})
-         else:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
-             )
-             db.commit()
-             return (TaskOutput(), num_tasks, {0: exception})
+     return outcome, num_tasks


  def run_v2_task_parallel(
@@ -274,42 +224,61 @@ def run_v2_task_parallel(
      images: list[dict[str, Any]],
      task: TaskV2,
      wftask: WorkflowTaskV2,
-     executor: BaseRunner,
+     runner: BaseRunner,
      workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     submit_setup_call: callable = no_op_submit_setup_call,
+     workflow_dir_remote: Path,
+     get_runner_config: Callable[
+         [
+             WorkflowTaskV2,
+             Literal["non_parallel", "parallel"],
+             Optional[Path],
+         ],
+         Any,
+     ],
      dataset_id: int,
      history_run_id: int,
- ) -> tuple[TaskOutput, int, dict[int, BaseException]]:
+ ) -> tuple[dict[int, SubmissionOutcome], int]:
      if len(images) == 0:
-         # FIXME: Do something with history units/images?
-         return (TaskOutput(), 0, {})
+         return {}, 0

      _check_parallelization_list_size(images)

-     executor_options = submit_setup_call(
-         wftask=wftask,
+     # Get TaskFiles object
+     task_files = TaskFiles(
          root_dir_local=workflow_dir_local,
          root_dir_remote=workflow_dir_remote,
+         task_order=wftask.order,
+         task_name=wftask.task.name,
+     )
+
+     runner_config = get_runner_config(
+         wftask=wftask,
          which_type="parallel",
      )

      list_function_kwargs = [
          {
              "zarr_url": image["zarr_url"],
-             _COMPONENT_KEY_: _index_to_component(ind),
              **(wftask.args_parallel or {}),
          }
-         for ind, image in enumerate(images)
+         for image in images
+     ]
+     list_task_files = [
+         TaskFiles(
+             **task_files.model_dump(exclude={"component"}),
+             component=_index_to_component(ind),
+         )
+         for ind in range(len(images))
      ]
+
      history_units = [
          HistoryUnit(
              history_run_id=history_run_id,
              status=HistoryUnitStatus.SUBMITTED,
-             logfile=None,  # FIXME
+             logfile=list_task_files[ind].log_file_local,
              zarr_urls=[image["zarr_url"]],
          )
-         for image in images
+         for ind, image in enumerate(images)
      ]

      with next(get_sync_db()) as db:
@@ -334,53 +303,46 @@ def run_v2_task_parallel(
              db=db, list_upsert_objects=history_image_caches
          )

-     results, exceptions = executor.multisubmit(
+     results, exceptions = runner.multisubmit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
          list_parameters=list_function_kwargs,
          task_type="parallel",
-         **executor_options,
+         list_task_files=list_task_files,
+         history_unit_ids=history_unit_ids,
+         config=runner_config,
      )

-     outputs = []
-     history_unit_ids_done: list[int] = []
-     history_unit_ids_failed: list[int] = []
+     outcome = {}
      for ind in range(len(list_function_kwargs)):
-         if ind in results.keys():
-             result = results[ind]
-             if result is None:
-                 output = TaskOutput()
-             else:
-                 output = _cast_and_validate_TaskOutput(result)
-             outputs.append(output)
-             history_unit_ids_done.append(history_unit_ids[ind])
-         elif ind in exceptions.keys():
-             print(f"Bad: {exceptions[ind]}")
-             history_unit_ids_failed.append(history_unit_ids[ind])
-         else:
-             print("VERY BAD - should have not reached this point")
-
-     with next(get_sync_db()) as db:
-         db.execute(
-             update(HistoryUnit)
-             .where(HistoryUnit.id.in_(history_unit_ids_done))
-             .values(status=HistoryUnitStatus.DONE)
-         )
-         db.execute(
-             update(HistoryUnit)
-             .where(HistoryUnit.id.in_(history_unit_ids_failed))
-             .values(status=HistoryUnitStatus.FAILED)
+         if ind not in results.keys() and ind not in exceptions.keys():
+             # FIXME: Could we avoid this branch?
+             error_msg = (
+                 f"Invalid branch: {ind=} is not in `results.keys()` "
+                 "nor in `exceptions.keys()`."
+             )
+             logger.error(error_msg)
+             raise RuntimeError(error_msg)
+         outcome[ind] = _process_task_output(
+             result=results.get(ind, None),
+             exception=exceptions.get(ind, None),
          )
-         db.commit()

      num_tasks = len(images)
-     merged_output = merge_outputs(outputs)
-     return (merged_output, num_tasks, exceptions)
+     return outcome, num_tasks
+
+
+ # FIXME: THIS FOR CONVERTERS:
+ # if task_type in ["converter_non_parallel"]:
+ # run = db.get(HistoryRun, history_run_id)
+ # run.status = HistoryUnitStatus.DONE
+ # db.merge(run)
+ # db.commit()


  def run_v2_task_compound(
@@ -389,42 +351,58 @@ def run_v2_task_compound(
      zarr_dir: str,
      task: TaskV2,
      wftask: WorkflowTaskV2,
-     executor: BaseRunner,
+     runner: BaseRunner,
      workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     submit_setup_call: callable = no_op_submit_setup_call,
+     workflow_dir_remote: Path,
+     get_runner_config: Callable[
+         [
+             WorkflowTaskV2,
+             Literal["non_parallel", "parallel"],
+             Optional[Path],
+         ],
+         Any,
+     ],
      dataset_id: int,
      history_run_id: int,
- ) -> tuple[TaskOutput, int, dict[int, BaseException]]:
-     executor_options_init = submit_setup_call(
-         wftask=wftask,
+     task_type: Literal["compound", "converter_compound"],
+ ) -> tuple[dict[int, SubmissionOutcome], int]:
+     # Get TaskFiles object
+     task_files_init = TaskFiles(
          root_dir_local=workflow_dir_local,
          root_dir_remote=workflow_dir_remote,
+         task_order=wftask.order,
+         task_name=wftask.task.name,
+         component=f"init_{_index_to_component(0)}",
+     )
+
+     runner_config_init = get_runner_config(
+         wftask=wftask,
          which_type="non_parallel",
      )
-     executor_options_compute = submit_setup_call(
+     runner_config_compute = get_runner_config(
          wftask=wftask,
-         root_dir_local=workflow_dir_local,
-         root_dir_remote=workflow_dir_remote,
          which_type="parallel",
      )

      # 3/A: non-parallel init task
      function_kwargs = {
-         "zarr_urls": [image["zarr_url"] for image in images],
          "zarr_dir": zarr_dir,
-         _COMPONENT_KEY_: f"init_{_index_to_component(0)}",
          **(wftask.args_non_parallel or {}),
      }
+     if task_type == "compound":
+         function_kwargs["zarr_urls"] = [img["zarr_url"] for img in images]
+         input_image_zarr_urls = function_kwargs["zarr_urls"]
+     elif task_type == "converter_compound":
+         input_image_zarr_urls = []

      # Create database History entries
-     input_image_zarr_urls = function_kwargs["zarr_urls"]
      with next(get_sync_db()) as db:
          # Create a single `HistoryUnit` for the whole compound task
          history_unit = HistoryUnit(
              history_run_id=history_run_id,
              status=HistoryUnitStatus.SUBMITTED,
-             logfile=None,  # FIXME
+             # FIXME: What about compute-task logs?
+             logfile=task_files_init.log_file_local,
              zarr_urls=input_image_zarr_urls,
          )
          db.add(history_unit)
@@ -444,37 +422,38 @@ def run_v2_task_compound(
                  for zarr_url in input_image_zarr_urls
              ],
      )
-
-     result, exception = executor.submit(
+     result, exception = runner.submit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_non_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
          parameters=function_kwargs,
-         task_type="compound",
-         **executor_options_init,
+         task_type=task_type,
+         task_files=task_files_init,
+         history_unit_id=history_unit_id,
+         config=runner_config_init,
      )

+     init_outcome = _process_init_task_output(
+         result=result,
+         exception=exception,
+     )
      num_tasks = 1
-     if exception is None:
-         if result is None:
-             init_task_output = InitTaskOutput()
-         else:
-             init_task_output = _cast_and_validate_InitTaskOutput(result)
-     else:
-         with next(get_sync_db()) as db:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
-             )
-             db.commit()
-         return (TaskOutput(), num_tasks, {0: exception})
+     if init_outcome.exception is not None:
+         positional_index = 0
+         return (
+             {
+                 positional_index: SubmissionOutcome(
+                     exception=init_outcome.exception
+                 )
+             },
+             num_tasks,
+         )

-     parallelization_list = init_task_output.parallelization_list
+     parallelization_list = init_outcome.task_output.parallelization_list
      parallelization_list = deduplicate_list(parallelization_list)

      num_tasks = 1 + len(parallelization_list)
@@ -484,220 +463,84 @@ def run_v2_task_compound(

      if len(parallelization_list) == 0:
          with next(get_sync_db()) as db:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
+             update_status_of_history_unit(
+                 history_unit_id=history_unit_id,
+                 status=HistoryUnitStatus.DONE,
+                 db_sync=db,
+             )
+         positional_index = 0
+         init_outcome = {
+             positional_index: _process_task_output(
+                 result=None,
+                 exception=None,
              )
-             db.commit()
-         return (TaskOutput(), 0, {})
-
-     list_function_kwargs = [
-         {
-             "zarr_url": parallelization_item.zarr_url,
-             "init_args": parallelization_item.init_args,
-             _COMPONENT_KEY_: f"compute_{_index_to_component(ind)}",
-             **(wftask.args_parallel or {}),
          }
-         for ind, parallelization_item in enumerate(parallelization_list)
-     ]
+         return init_outcome, num_tasks

-     results, exceptions = executor.multisubmit(
-         functools.partial(
-             run_single_task,
-             wftask=wftask,
-             command=task.command_parallel,
+     list_task_files = [
+         TaskFiles(
              root_dir_local=workflow_dir_local,
              root_dir_remote=workflow_dir_remote,
-         ),
-         list_parameters=list_function_kwargs,
-         task_type="compound",
-         **executor_options_compute,
-     )
-
-     outputs = []
-     failure = False
-     for ind in range(len(list_function_kwargs)):
-         if ind in results.keys():
-             result = results[ind]
-             if result is None:
-                 output = TaskOutput()
-             else:
-                 output = _cast_and_validate_TaskOutput(result)
-             outputs.append(output)
-
-         elif ind in exceptions.keys():
-             print(f"Bad: {exceptions[ind]}")
-             failure = True
-         else:
-             print("VERY BAD - should have not reached this point")
-
-     with next(get_sync_db()) as db:
-         if failure:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
-             )
-         else:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
-             )
-         db.commit()
-
-     merged_output = merge_outputs(outputs)
-     return (merged_output, num_tasks, exceptions)
-
-
- def run_v2_task_converter_compound(
-     *,
-     zarr_dir: str,
-     task: TaskV2,
-     wftask: WorkflowTaskV2,
-     executor: BaseRunner,
-     workflow_dir_local: Path,
-     workflow_dir_remote: Optional[Path] = None,
-     submit_setup_call: callable = no_op_submit_setup_call,
-     dataset_id: int,
-     history_run_id: int,
- ) -> tuple[TaskOutput, int, dict[int, BaseException]]:
-     executor_options_init = submit_setup_call(
-         wftask=wftask,
-         root_dir_local=workflow_dir_local,
-         root_dir_remote=workflow_dir_remote,
-         which_type="non_parallel",
-     )
-     executor_options_compute = submit_setup_call(
-         wftask=wftask,
-         root_dir_local=workflow_dir_local,
-         root_dir_remote=workflow_dir_remote,
-         which_type="parallel",
-     )
-
-     # 3/A: non-parallel init task
-     function_kwargs = {
-         "zarr_dir": zarr_dir,
-         _COMPONENT_KEY_: f"init_{_index_to_component(0)}",
-         **(wftask.args_non_parallel or {}),
-     }
-
-     # Create database History entries
-     with next(get_sync_db()) as db:
-         # Create a single `HistoryUnit` for the whole compound task
-         history_unit = HistoryUnit(
-             history_run_id=history_run_id,
-             status=HistoryUnitStatus.SUBMITTED,
-             logfile=None,  # FIXME
-             zarr_urls=[],
+             task_order=wftask.order,
+             task_name=wftask.task.name,
+             component=f"compute_{_index_to_component(ind)}",
          )
-         db.add(history_unit)
-         db.commit()
-         db.refresh(history_unit)
-         history_unit_id = history_unit.id
-
-     result, exception = executor.submit(
-         functools.partial(
-             run_single_task,
-             wftask=wftask,
-             command=task.command_non_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
-         ),
-         parameters=function_kwargs,
-         task_type="converter_compound",
-         **executor_options_init,
-     )
-
-     num_tasks = 1
-     if exception is None:
-         if result is None:
-             init_task_output = InitTaskOutput()
-         else:
-             init_task_output = _cast_and_validate_InitTaskOutput(result)
-     else:
-         with next(get_sync_db()) as db:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
-             )
-             db.commit()
-         return (TaskOutput(), num_tasks, {0: exception})
-
-     parallelization_list = init_task_output.parallelization_list
-     parallelization_list = deduplicate_list(parallelization_list)
-
-     num_tasks = 1 + len(parallelization_list)
-
-     # 3/B: parallel part of a compound task
-     _check_parallelization_list_size(parallelization_list)
-
-     if len(parallelization_list) == 0:
-         with next(get_sync_db()) as db:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
-             )
-             db.commit()
-         return (TaskOutput(), 0, {})
-
+         for ind in range(len(parallelization_list))
+     ]
      list_function_kwargs = [
          {
              "zarr_url": parallelization_item.zarr_url,
              "init_args": parallelization_item.init_args,
-             _COMPONENT_KEY_: f"compute_{_index_to_component(ind)}",
              **(wftask.args_parallel or {}),
          }
-         for ind, parallelization_item in enumerate(parallelization_list)
+         for parallelization_item in parallelization_list
      ]

-     results, exceptions = executor.multisubmit(
+     results, exceptions = runner.multisubmit(
          functools.partial(
              run_single_task,
-             wftask=wftask,
              command=task.command_parallel,
-             root_dir_local=workflow_dir_local,
-             root_dir_remote=workflow_dir_remote,
+             workflow_task_order=wftask.order,
+             workflow_task_id=wftask.task_id,
+             task_name=wftask.task.name,
          ),
          list_parameters=list_function_kwargs,
-         task_type="converter_compound",
-         **executor_options_compute,
+         task_type=task_type,
+         list_task_files=list_task_files,
+         history_unit_ids=[history_unit_id],
+         config=runner_config_compute,
      )

-     outputs = []
+     init_outcome = {}
      failure = False
      for ind in range(len(list_function_kwargs)):
-         if ind in results.keys():
-             result = results[ind]
-             if result is None:
-                 output = TaskOutput()
-             else:
-                 output = _cast_and_validate_TaskOutput(result)
-             outputs.append(output)
-
-         elif ind in exceptions.keys():
-             print(f"Bad: {exceptions[ind]}")
-             failure = True
-         else:
-             print("VERY BAD - should have not reached this point")
+         if ind not in results.keys() and ind not in exceptions.keys():
+             # FIXME: Could we avoid this branch?
+             error_msg = (
+                 f"Invalid branch: {ind=} is not in `results.keys()` "
+                 "nor in `exceptions.keys()`."
+             )
+             logger.error(error_msg)
+             raise RuntimeError(error_msg)
+         init_outcome[ind] = _process_task_output(
+             result=results.get(ind, None),
+             exception=exceptions.get(ind, None),
+         )

+     # FIXME: In this case, we are performing db updates from here, rather
+     # than at lower level.
      with next(get_sync_db()) as db:
          if failure:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.FAILED)
+             update_status_of_history_unit(
+                 history_unit_id=history_unit_id,
+                 status=HistoryUnitStatus.FAILED,
+                 db_sync=db,
              )
          else:
-             db.execute(
-                 update(HistoryUnit)
-                 .where(HistoryUnit.id == history_unit_id)
-                 .values(status=HistoryUnitStatus.DONE)
+             update_status_of_history_unit(
+                 history_unit_id=history_unit_id,
+                 status=HistoryUnitStatus.DONE,
+                 db_sync=db,
              )
-         db.commit()

-     merged_output = merge_outputs(outputs)
-     return (merged_output, num_tasks, exceptions)
+     return init_outcome, num_tasks
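
A related pattern visible throughout the new code: the inline `db.execute(update(HistoryUnit)...)` / `db.commit()` blocks are replaced by calls to `update_status_of_history_unit(history_unit_id=..., status=..., db_sync=db)` from the new `db_tools` module (item 25 above). The diff only shows the call sites, so the following is a hypothetical sketch of a helper with that call shape, built on a stand-in SQLAlchemy model rather than fractal-server's actual `HistoryUnit`:

from sqlalchemy import Column, Integer, String, create_engine, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class HistoryUnit(Base):
    # Minimal stand-in table for the sketch
    __tablename__ = "history_unit"
    id = Column(Integer, primary_key=True)
    status = Column(String, default="submitted")


def update_status_of_history_unit(
    *, history_unit_id: int, status: str, db_sync: Session
) -> None:
    # One UPDATE ... WHERE id = :id statement, committed immediately
    db_sync.execute(
        update(HistoryUnit)
        .where(HistoryUnit.id == history_unit_id)
        .values(status=status)
    )
    db_sync.commit()


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as db:
        db.add(HistoryUnit(id=1))
        db.commit()
        update_status_of_history_unit(history_unit_id=1, status="done", db_sync=db)
        print(db.get(HistoryUnit, 1).status)  # done

Keeping the UPDATE and the commit in one helper means every call site performs the status transition the same way.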