fractal-server 2.14.12__py3-none-any.whl → 2.14.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25):
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/app/routes/api/v2/_aux_task_group_disambiguation.py +163 -0
  3. fractal_server/app/routes/api/v2/pre_submission_checks.py +3 -2
  4. fractal_server/app/routes/api/v2/task.py +5 -4
  5. fractal_server/app/routes/api/v2/task_group.py +52 -4
  6. fractal_server/app/routes/api/v2/task_version_update.py +18 -10
  7. fractal_server/app/routes/api/v2/workflow_import.py +3 -70
  8. fractal_server/app/routes/api/v2/workflowtask.py +6 -5
  9. fractal_server/app/runner/executors/base_runner.py +38 -17
  10. fractal_server/app/runner/executors/local/runner.py +14 -14
  11. fractal_server/app/runner/executors/slurm_common/base_slurm_runner.py +12 -14
  12. fractal_server/app/runner/v2/runner.py +19 -8
  13. fractal_server/app/runner/v2/runner_functions.py +12 -8
  14. fractal_server/app/schemas/v2/__init__.py +1 -0
  15. fractal_server/app/schemas/v2/dumps.py +2 -2
  16. fractal_server/app/schemas/v2/manifest.py +2 -9
  17. fractal_server/app/schemas/v2/task.py +18 -14
  18. fractal_server/app/schemas/v2/workflowtask.py +2 -2
  19. fractal_server/exceptions.py +2 -0
  20. fractal_server/utils.py +0 -49
  21. {fractal_server-2.14.12.dist-info → fractal_server-2.14.14.dist-info}/METADATA +1 -1
  22. {fractal_server-2.14.12.dist-info → fractal_server-2.14.14.dist-info}/RECORD +25 -23
  23. {fractal_server-2.14.12.dist-info → fractal_server-2.14.14.dist-info}/LICENSE +0 -0
  24. {fractal_server-2.14.12.dist-info → fractal_server-2.14.14.dist-info}/WHEEL +0 -0
  25. {fractal_server-2.14.12.dist-info → fractal_server-2.14.14.dist-info}/entry_points.txt +0 -0
@@ -1 +1 @@
1
- __VERSION__ = "2.14.12"
1
+ __VERSION__ = "2.14.14"
@@ -0,0 +1,163 @@
1
+ import itertools
2
+
3
+ from sqlmodel import select
4
+
5
+ from fractal_server.app.db import AsyncSession
6
+ from fractal_server.app.models import LinkUserGroup
7
+ from fractal_server.app.models.v2 import TaskGroupV2
8
+ from fractal_server.exceptions import UnreachableBranchError
9
+ from fractal_server.logger import set_logger
10
+
11
+
12
+ logger = set_logger(__name__)
13
+
14
+
15
+ async def _disambiguate_task_groups(
16
+ *,
17
+ matching_task_groups: list[TaskGroupV2],
18
+ user_id: int,
19
+ default_group_id: int,
20
+ db: AsyncSession,
21
+ ) -> TaskGroupV2 | None:
22
+ """
23
+ Find ownership-based top-priority task group, if any.
24
+
25
+ Args:
26
+ matching_task_groups:
27
+ user_id:
28
+ default_group_id:
29
+ db:
30
+
31
+ Returns:
32
+ The task group or `None`.
33
+ """
34
+
35
+ # Highest priority: task groups created by user
36
+ list_user_ids = [tg.user_id for tg in matching_task_groups]
37
+ try:
38
+ ind_user_id = list_user_ids.index(user_id)
39
+ task_group = matching_task_groups[ind_user_id]
40
+ logger.debug(
41
+ "[_disambiguate_task_groups] "
42
+ f"Found task group {task_group.id} with {user_id=}, return."
43
+ )
44
+ return task_group
45
+ except ValueError:
46
+ logger.debug(
47
+ "[_disambiguate_task_groups] "
48
+ f"No task group with {user_id=}, continue."
49
+ )
50
+
51
+ # Medium priority: task groups owned by default user group
52
+ list_user_group_ids = [tg.user_group_id for tg in matching_task_groups]
53
+ try:
54
+ ind_user_group_id = list_user_group_ids.index(default_group_id)
55
+ task_group = matching_task_groups[ind_user_group_id]
56
+ logger.debug(
57
+ "[_disambiguate_task_groups] "
58
+ f"Found task group {task_group.id} with {user_id=}, return."
59
+ )
60
+ return task_group
61
+ except ValueError:
62
+ logger.debug(
63
+ "[_disambiguate_task_groups] "
64
+ "No task group with user_group_id="
65
+ f"{default_group_id}, continue."
66
+ )
67
+
68
+ # Lowest priority: task groups owned by other groups, sorted
69
+ # according to age of the user/usergroup link
70
+ logger.debug(
71
+ "[_disambiguate_task_groups] "
72
+ "Sort remaining task groups by oldest-user-link."
73
+ )
74
+ stm = (
75
+ select(LinkUserGroup.group_id)
76
+ .where(LinkUserGroup.user_id == user_id)
77
+ .where(LinkUserGroup.group_id.in_(list_user_group_ids))
78
+ .order_by(LinkUserGroup.timestamp_created.asc())
79
+ )
80
+ res = await db.execute(stm)
81
+ oldest_user_group_id = res.scalars().first()
82
+ logger.debug(
83
+ "[_disambiguate_task_groups] " f"Result: {oldest_user_group_id=}."
84
+ )
85
+ task_group = next(
86
+ iter(
87
+ task_group
88
+ for task_group in matching_task_groups
89
+ if task_group.user_group_id == oldest_user_group_id
90
+ ),
91
+ None,
92
+ )
93
+ return task_group
94
+
95
+
96
+ async def _disambiguate_task_groups_not_none(
97
+ *,
98
+ matching_task_groups: list[TaskGroupV2],
99
+ user_id: int,
100
+ default_group_id: int,
101
+ db: AsyncSession,
102
+ ) -> TaskGroupV2:
103
+ """
104
+ Find ownership-based top-priority task group, and fail otherwise.
105
+
106
+ Args:
107
+ matching_task_groups:
108
+ user_id:
109
+ default_group_id:
110
+ db:
111
+
112
+ Returns:
113
+ The top-priority task group.
114
+ """
115
+ task_group = await _disambiguate_task_groups(
116
+ matching_task_groups=matching_task_groups,
117
+ user_id=user_id,
118
+ default_group_id=default_group_id,
119
+ db=db,
120
+ )
121
+ if task_group is None:
122
+ error_msg = (
123
+ "[_disambiguate_task_groups_not_none] Could not find a task "
124
+ f"group ({user_id=}, {default_group_id=})."
125
+ )
126
+ logger.error(f"UnreachableBranchError {error_msg}")
127
+ raise UnreachableBranchError(error_msg)
128
+ else:
129
+ return task_group
130
+
131
+
132
+ async def remove_duplicate_task_groups(
133
+ *,
134
+ task_groups: list[TaskGroupV2],
135
+ user_id: int,
136
+ default_group_id: int,
137
+ db: AsyncSession,
138
+ ) -> list[TaskGroupV2]:
139
+ """
140
+ Extract an item for each `version` from a *sorted* task-group list.
141
+
142
+ Args:
143
+ task_groups: A list of task groups with identical `pkg_name`
144
+ user_id: User ID
145
+
146
+ Returns:
147
+ New list of task groups with no duplicate `(pkg_name,version)` entries
148
+ """
149
+
150
+ new_task_groups = [
151
+ (
152
+ await _disambiguate_task_groups_not_none(
153
+ matching_task_groups=list(groups),
154
+ user_id=user_id,
155
+ default_group_id=default_group_id,
156
+ db=db,
157
+ )
158
+ )
159
+ for version, groups in itertools.groupby(
160
+ task_groups, key=lambda tg: tg.version
161
+ )
162
+ ]
163
+ return new_task_groups
@@ -16,6 +16,7 @@ from fractal_server.app.models.v2 import HistoryImageCache
16
16
  from fractal_server.app.models.v2 import HistoryUnit
17
17
  from fractal_server.app.routes.auth import current_active_user
18
18
  from fractal_server.app.schemas.v2 import HistoryUnitStatus
19
+ from fractal_server.app.schemas.v2 import TaskType
19
20
  from fractal_server.images.tools import aggregate_types
20
21
  from fractal_server.images.tools import filter_image_list
21
22
  from fractal_server.types import AttributeFilters
@@ -105,8 +106,8 @@ async def check_workflowtask(
105
106
  # Skip check if previous task has non-trivial `output_types`
106
107
  return JSONResponse(status_code=200, content=[])
107
108
  elif previous_wft.task.type in [
108
- "converter_compound",
109
- "converter_non_parallel",
109
+ TaskType.CONVERTER_COMPOUND,
110
+ TaskType.CONVERTER_NON_PARALLEL,
110
111
  ]:
111
112
  # Skip check if previous task is converter
112
113
  return JSONResponse(status_code=200, content=[])
@@ -25,6 +25,7 @@ from fractal_server.app.routes.auth import current_active_verified_user
25
25
  from fractal_server.app.schemas.v2 import TaskCreateV2
26
26
  from fractal_server.app.schemas.v2 import TaskGroupV2OriginEnum
27
27
  from fractal_server.app.schemas.v2 import TaskReadV2
28
+ from fractal_server.app.schemas.v2 import TaskType
28
29
  from fractal_server.app.schemas.v2 import TaskUpdateV2
29
30
  from fractal_server.logger import set_logger
30
31
 
@@ -109,12 +110,12 @@ async def patch_task(
109
110
  update = task_update.model_dump(exclude_unset=True)
110
111
 
111
112
  # Forbid changes that set a previously unset command
112
- if db_task.type == "non_parallel" and "command_parallel" in update:
113
+ if db_task.type == TaskType.NON_PARALLEL and "command_parallel" in update:
113
114
  raise HTTPException(
114
115
  status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
115
116
  detail="Cannot set an unset `command_parallel`.",
116
117
  )
117
- if db_task.type == "parallel" and "command_non_parallel" in update:
118
+ if db_task.type == TaskType.PARALLEL and "command_non_parallel" in update:
118
119
  raise HTTPException(
119
120
  status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
120
121
  detail="Cannot set an unset `command_non_parallel`.",
@@ -151,7 +152,7 @@ async def create_task(
151
152
  db=db,
152
153
  )
153
154
 
154
- if task.type == "parallel" and (
155
+ if task.type == TaskType.PARALLEL and (
155
156
  task.args_schema_non_parallel is not None
156
157
  or task.meta_non_parallel is not None
157
158
  ):
@@ -162,7 +163,7 @@ async def create_task(
162
163
  "`TaskV2.args_schema_non_parallel` if TaskV2 is parallel"
163
164
  ),
164
165
  )
165
- elif task.type == "non_parallel" and (
166
+ elif task.type == TaskType.NON_PARALLEL and (
166
167
  task.args_schema_parallel is not None or task.meta_parallel is not None
167
168
  ):
168
169
  raise HTTPException(
@@ -1,8 +1,13 @@
1
+ import itertools
2
+
1
3
  from fastapi import APIRouter
2
4
  from fastapi import Depends
3
5
  from fastapi import HTTPException
4
6
  from fastapi import Response
5
7
  from fastapi import status
8
+ from packaging.version import InvalidVersion
9
+ from packaging.version import parse
10
+ from packaging.version import Version
6
11
  from pydantic.types import AwareDatetime
7
12
  from sqlmodel import or_
8
13
  from sqlmodel import select
@@ -10,6 +15,7 @@ from sqlmodel import select
10
15
  from ._aux_functions_tasks import _get_task_group_full_access
11
16
  from ._aux_functions_tasks import _get_task_group_read_access
12
17
  from ._aux_functions_tasks import _verify_non_duplication_group_constraint
18
+ from ._aux_task_group_disambiguation import remove_duplicate_task_groups
13
19
  from fractal_server.app.db import AsyncSession
14
20
  from fractal_server.app.db import get_async_db
15
21
  from fractal_server.app.models import LinkUserGroup
@@ -18,6 +24,7 @@ from fractal_server.app.models.v2 import TaskGroupActivityV2
18
24
  from fractal_server.app.models.v2 import TaskGroupV2
19
25
  from fractal_server.app.models.v2 import WorkflowTaskV2
20
26
  from fractal_server.app.routes.auth import current_active_user
27
+ from fractal_server.app.routes.auth._aux_auth import _get_default_usergroup_id
21
28
  from fractal_server.app.routes.auth._aux_auth import (
22
29
  _verify_user_belongs_to_group,
23
30
  )
@@ -33,6 +40,26 @@ router = APIRouter()
33
40
  logger = set_logger(__name__)
34
41
 
35
42
 
43
+ def _version_sort_key(
44
+ task_group: TaskGroupV2,
45
+ ) -> tuple[int, Version | str | None]:
46
+ """
47
+ Returns a tuple used as (reverse) ordering key for TaskGroups in
48
+ `get_task_group_list`.
49
+ The TaskGroups with a parsable versions are the first in order,
50
+ sorted according to the sorting rules of packaging.version.Version.
51
+ Next in order we have the TaskGroups with non-null non-parsable versions,
52
+ sorted alphabetically.
53
+ Last we have the TaskGroups with null version.
54
+ """
55
+ if task_group.version is None:
56
+ return (0, task_group.version)
57
+ try:
58
+ return (2, parse(task_group.version))
59
+ except InvalidVersion:
60
+ return (1, task_group.version)
61
+
62
+
36
63
  @router.get("/activity/", response_model=list[TaskGroupActivityV2Read])
37
64
  async def get_task_group_activity_list(
38
65
  task_group_activity_id: int | None = None,
@@ -97,14 +124,14 @@ async def get_task_group_activity(
97
124
  return activity
98
125
 
99
126
 
100
- @router.get("/", response_model=list[TaskGroupReadV2])
127
+ @router.get("/", response_model=list[tuple[str, list[TaskGroupReadV2]]])
101
128
  async def get_task_group_list(
102
129
  user: UserOAuth = Depends(current_active_user),
103
130
  db: AsyncSession = Depends(get_async_db),
104
131
  only_active: bool = False,
105
132
  only_owner: bool = False,
106
133
  args_schema: bool = True,
107
- ) -> list[TaskGroupReadV2]:
134
+ ) -> list[tuple[str, list[TaskGroupReadV2]]]:
108
135
  """
109
136
  Get all accessible TaskGroups
110
137
  """
@@ -119,7 +146,7 @@ async def get_task_group_list(
119
146
  )
120
147
  ),
121
148
  )
122
- stm = select(TaskGroupV2).where(condition)
149
+ stm = select(TaskGroupV2).where(condition).order_by(TaskGroupV2.pkg_name)
123
150
  if only_active:
124
151
  stm = stm.where(TaskGroupV2.active)
125
152
 
@@ -132,7 +159,28 @@ async def get_task_group_list(
132
159
  setattr(task, "args_schema_non_parallel", None)
133
160
  setattr(task, "args_schema_parallel", None)
134
161
 
135
- return task_groups
162
+ default_group_id = await _get_default_usergroup_id(db)
163
+ grouped_result = [
164
+ (
165
+ pkg_name,
166
+ (
167
+ await remove_duplicate_task_groups(
168
+ task_groups=sorted(
169
+ list(groups),
170
+ key=_version_sort_key,
171
+ reverse=True,
172
+ ),
173
+ user_id=user.id,
174
+ default_group_id=default_group_id,
175
+ db=db,
176
+ )
177
+ ),
178
+ )
179
+ for pkg_name, groups in itertools.groupby(
180
+ task_groups, key=lambda tg: tg.pkg_name
181
+ )
182
+ ]
183
+ return grouped_result
136
184
 
137
185
 
138
186
  @router.get("/{task_group_id}/", response_model=TaskGroupReadV2)
@@ -24,21 +24,23 @@ from ._aux_functions_tasks import _get_task_read_access
24
24
  from fractal_server.app.models import UserOAuth
25
25
  from fractal_server.app.models.v2 import TaskGroupV2
26
26
  from fractal_server.app.routes.auth import current_active_user
27
+ from fractal_server.app.schemas.v2 import TaskType
27
28
  from fractal_server.app.schemas.v2 import WorkflowTaskReadV2
28
29
  from fractal_server.app.schemas.v2 import WorkflowTaskReplaceV2
29
30
 
30
-
31
31
  router = APIRouter()
32
32
 
33
33
 
34
34
  VALID_TYPE_UPDATES = {
35
- ("non_parallel", "converter_non_parallel"),
36
- ("compound", "converter_compound"),
37
- ("converter_non_parallel", "converter_non_parallel"),
38
- ("converter_compound", "converter_compound"),
39
- ("non_parallel", "non_parallel"),
40
- ("compound", "compound"),
41
- ("parallel", "parallel"),
35
+ # Transform into converter
36
+ (TaskType.NON_PARALLEL, TaskType.CONVERTER_NON_PARALLEL),
37
+ (TaskType.COMPOUND, TaskType.CONVERTER_COMPOUND),
38
+ # Keep the same
39
+ (TaskType.CONVERTER_NON_PARALLEL, TaskType.CONVERTER_NON_PARALLEL),
40
+ (TaskType.CONVERTER_COMPOUND, TaskType.CONVERTER_COMPOUND),
41
+ (TaskType.NON_PARALLEL, TaskType.NON_PARALLEL),
42
+ (TaskType.COMPOUND, TaskType.COMPOUND),
43
+ (TaskType.PARALLEL, TaskType.PARALLEL),
42
44
  }
43
45
 
44
46
 
@@ -208,12 +210,18 @@ async def replace_workflowtask(
208
210
  ),
209
211
  )
210
212
 
211
- if replace.args_non_parallel is not None and new_task.type == "parallel":
213
+ if (
214
+ replace.args_non_parallel is not None
215
+ and new_task.type == TaskType.PARALLEL
216
+ ):
212
217
  raise HTTPException(
213
218
  status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
214
219
  detail="Cannot set 'args_non_parallel' for parallel task.",
215
220
  )
216
- if replace.args_parallel is not None and new_task.type == "non_parallel":
221
+ if (
222
+ replace.args_parallel is not None
223
+ and new_task.type == TaskType.NON_PARALLEL
224
+ ):
217
225
  raise HTTPException(
218
226
  status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
219
227
  detail="Cannot set 'args_parallel' for non-parallel task.",
@@ -21,6 +21,9 @@ from ._aux_functions_tasks import _check_type_filters_compatibility
21
21
  from fractal_server.app.models import LinkUserGroup
22
22
  from fractal_server.app.models import UserOAuth
23
23
  from fractal_server.app.models.v2 import TaskGroupV2
24
+ from fractal_server.app.routes.api.v2._aux_task_group_disambiguation import (
25
+ _disambiguate_task_groups,
26
+ )
24
27
  from fractal_server.app.routes.auth import current_active_user
25
28
  from fractal_server.app.routes.auth._aux_auth import _get_default_usergroup_id
26
29
  from fractal_server.app.schemas.v2 import TaskImportV2
@@ -85,76 +88,6 @@ async def _get_task_by_source(
85
88
  return task_id
86
89
 
87
90
 
88
- async def _disambiguate_task_groups(
89
- *,
90
- matching_task_groups: list[TaskGroupV2],
91
- user_id: int,
92
- db: AsyncSession,
93
- default_group_id: int,
94
- ) -> TaskV2 | None:
95
- """
96
- Disambiguate task groups based on ownership information.
97
- """
98
- # Highest priority: task groups created by user
99
- for task_group in matching_task_groups:
100
- if task_group.user_id == user_id:
101
- logger.info(
102
- "[_disambiguate_task_groups] "
103
- f"Found task group {task_group.id} with {user_id=}, return."
104
- )
105
- return task_group
106
- logger.info(
107
- "[_disambiguate_task_groups] "
108
- f"No task group found with {user_id=}, continue."
109
- )
110
-
111
- # Medium priority: task groups owned by default user group
112
- for task_group in matching_task_groups:
113
- if task_group.user_group_id == default_group_id:
114
- logger.info(
115
- "[_disambiguate_task_groups] "
116
- f"Found task group {task_group.id} with user_group_id="
117
- f"{default_group_id}, return."
118
- )
119
- return task_group
120
- logger.info(
121
- "[_disambiguate_task_groups] "
122
- "No task group found with user_group_id="
123
- f"{default_group_id}, continue."
124
- )
125
-
126
- # Lowest priority: task groups owned by other groups, sorted
127
- # according to age of the user/usergroup link
128
- logger.info(
129
- "[_disambiguate_task_groups] "
130
- "Now sorting remaining task groups by oldest-user-link."
131
- )
132
- user_group_ids = [
133
- task_group.user_group_id for task_group in matching_task_groups
134
- ]
135
- stm = (
136
- select(LinkUserGroup.group_id)
137
- .where(LinkUserGroup.user_id == user_id)
138
- .where(LinkUserGroup.group_id.in_(user_group_ids))
139
- .order_by(LinkUserGroup.timestamp_created.asc())
140
- )
141
- res = await db.execute(stm)
142
- oldest_user_group_id = res.scalars().first()
143
- logger.info(
144
- "[_disambiguate_task_groups] "
145
- f"Result of sorting: {oldest_user_group_id=}."
146
- )
147
- task_group = next(
148
- iter(
149
- task_group
150
- for task_group in matching_task_groups
151
- if task_group.user_group_id == oldest_user_group_id
152
- ),
153
- None,
154
- )
155
- return task_group
156
-
157
-
158
91
  async def _get_task_by_taskimport(
159
92
  *,
160
93
  task_import: TaskImportV2,
@@ -15,6 +15,7 @@ from ._aux_functions_tasks import _check_type_filters_compatibility
15
15
  from ._aux_functions_tasks import _get_task_read_access
16
16
  from fractal_server.app.models import UserOAuth
17
17
  from fractal_server.app.routes.auth import current_active_user
18
+ from fractal_server.app.schemas.v2 import TaskType
18
19
  from fractal_server.app.schemas.v2 import WorkflowTaskCreateV2
19
20
  from fractal_server.app.schemas.v2 import WorkflowTaskReadV2
20
21
  from fractal_server.app.schemas.v2 import WorkflowTaskUpdateV2
@@ -47,7 +48,7 @@ async def create_workflowtask(
47
48
  task_id=task_id, user_id=user.id, db=db, require_active=True
48
49
  )
49
50
 
50
- if task.type == "parallel":
51
+ if task.type == TaskType.PARALLEL:
51
52
  if (
52
53
  wftask.meta_non_parallel is not None
53
54
  or wftask.args_non_parallel is not None
@@ -60,7 +61,7 @@ async def create_workflowtask(
60
61
  "is `parallel`."
61
62
  ),
62
63
  )
63
- elif task.type == "non_parallel":
64
+ elif task.type == TaskType.NON_PARALLEL:
64
65
  if (
65
66
  wftask.meta_parallel is not None
66
67
  or wftask.args_parallel is not None
@@ -143,7 +144,7 @@ async def update_workflowtask(
143
144
  wftask_type_filters=workflow_task_update.type_filters,
144
145
  )
145
146
 
146
- if db_wf_task.task_type == "parallel" and (
147
+ if db_wf_task.task_type == TaskType.PARALLEL and (
147
148
  workflow_task_update.args_non_parallel is not None
148
149
  or workflow_task_update.meta_non_parallel is not None
149
150
  ):
@@ -156,8 +157,8 @@ async def update_workflowtask(
156
157
  ),
157
158
  )
158
159
  elif db_wf_task.task_type in [
159
- "non_parallel",
160
- "converter_non_parallel",
160
+ TaskType.NON_PARALLEL,
161
+ TaskType.CONVERTER_NON_PARALLEL,
161
162
  ] and (
162
163
  workflow_task_update.args_parallel is not None
163
164
  or workflow_task_update.meta_parallel is not None
@@ -1,19 +1,34 @@
1
+ from enum import StrEnum
1
2
  from typing import Any
2
3
 
3
4
  from fractal_server.app.runner.task_files import TaskFiles
4
- from fractal_server.app.schemas.v2.task import TaskTypeType
5
+ from fractal_server.app.schemas.v2.task import TaskType
5
6
  from fractal_server.logger import set_logger
6
7
 
7
- TASK_TYPES_SUBMIT: list[TaskTypeType] = [
8
- "compound",
9
- "converter_compound",
10
- "non_parallel",
11
- "converter_non_parallel",
8
+
9
+ class SubmitTaskType(StrEnum):
10
+ COMPOUND = TaskType.COMPOUND
11
+ NON_PARALLEL = TaskType.NON_PARALLEL
12
+ CONVERTER_NON_PARALLEL = TaskType.CONVERTER_NON_PARALLEL
13
+ CONVERTER_COMPOUND = TaskType.CONVERTER_COMPOUND
14
+
15
+
16
+ class MultisubmitTaskType(StrEnum):
17
+ PARALLEL = TaskType.PARALLEL
18
+ COMPOUND = TaskType.COMPOUND
19
+ CONVERTER_COMPOUND = TaskType.CONVERTER_COMPOUND
20
+
21
+
22
+ TASK_TYPES_SUBMIT: list[TaskType] = [
23
+ TaskType.COMPOUND,
24
+ TaskType.CONVERTER_COMPOUND,
25
+ TaskType.NON_PARALLEL,
26
+ TaskType.CONVERTER_NON_PARALLEL,
12
27
  ]
13
- TASK_TYPES_MULTISUBMIT: list[TaskTypeType] = [
14
- "compound",
15
- "converter_compound",
16
- "parallel",
28
+ TASK_TYPES_MULTISUBMIT: list[TaskType] = [
29
+ TaskType.COMPOUND,
30
+ TaskType.CONVERTER_COMPOUND,
31
+ TaskType.PARALLEL,
17
32
  ]
18
33
 
19
34
  logger = set_logger(__name__)
@@ -32,7 +47,7 @@ class BaseRunner:
32
47
  task_name: str,
33
48
  parameters: dict[str, Any],
34
49
  history_unit_id: int,
35
- task_type: TaskTypeType,
50
+ task_type: TaskType,
36
51
  task_files: TaskFiles,
37
52
  config: Any,
38
53
  user_id: int,
@@ -64,7 +79,7 @@ class BaseRunner:
64
79
  list_parameters: list[dict[str, Any]],
65
80
  history_unit_ids: list[int],
66
81
  list_task_files: list[TaskFiles],
67
- task_type: TaskTypeType,
82
+ task_type: TaskType,
68
83
  config: Any,
69
84
  user_id: int,
70
85
  ) -> tuple[dict[int, Any], dict[int, BaseException]]:
@@ -90,7 +105,7 @@ class BaseRunner:
90
105
  def validate_submit_parameters(
91
106
  self,
92
107
  parameters: dict[str, Any],
93
- task_type: TaskTypeType,
108
+ task_type: TaskType,
94
109
  ) -> None:
95
110
  """
96
111
  Validate parameters for `submit` method
@@ -104,12 +119,18 @@ class BaseRunner:
104
119
  raise ValueError(f"Invalid {task_type=} for `submit`.")
105
120
  if not isinstance(parameters, dict):
106
121
  raise ValueError("`parameters` must be a dictionary.")
107
- if task_type in ["non_parallel", "compound"]:
122
+ if task_type in [
123
+ TaskType.NON_PARALLEL,
124
+ TaskType.COMPOUND,
125
+ ]:
108
126
  if "zarr_urls" not in parameters.keys():
109
127
  raise ValueError(
110
128
  f"No 'zarr_urls' key in in {list(parameters.keys())}"
111
129
  )
112
- elif task_type in ["converter_non_parallel", "converter_compound"]:
130
+ elif task_type in [
131
+ TaskType.CONVERTER_NON_PARALLEL,
132
+ TaskType.CONVERTER_COMPOUND,
133
+ ]:
113
134
  if "zarr_urls" in parameters.keys():
114
135
  raise ValueError(
115
136
  f"Forbidden 'zarr_urls' key in {list(parameters.keys())}"
@@ -119,7 +140,7 @@ class BaseRunner:
119
140
  def validate_multisubmit_parameters(
120
141
  self,
121
142
  *,
122
- task_type: TaskTypeType,
143
+ task_type: TaskType,
123
144
  list_parameters: list[dict[str, Any]],
124
145
  list_task_files: list[TaskFiles],
125
146
  history_unit_ids: list[int],
@@ -163,7 +184,7 @@ class BaseRunner:
163
184
  raise ValueError(
164
185
  f"No 'zarr_url' key in in {list(single_kwargs.keys())}"
165
186
  )
166
- if task_type == "parallel":
187
+ if task_type == TaskType.PARALLEL:
167
188
  zarr_urls = [kwargs["zarr_url"] for kwargs in list_parameters]
168
189
  if len(zarr_urls) != len(set(zarr_urls)):
169
190
  raise ValueError("Non-unique zarr_urls")