fractal-server 2.6.3__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/__main__.py +1 -1
  3. fractal_server/app/models/linkusergroup.py +11 -0
  4. fractal_server/app/models/v2/__init__.py +2 -0
  5. fractal_server/app/models/v2/collection_state.py +1 -0
  6. fractal_server/app/models/v2/task.py +67 -2
  7. fractal_server/app/routes/admin/v2/__init__.py +16 -0
  8. fractal_server/app/routes/admin/{v2.py → v2/job.py} +20 -191
  9. fractal_server/app/routes/admin/v2/project.py +43 -0
  10. fractal_server/app/routes/admin/v2/task.py +133 -0
  11. fractal_server/app/routes/admin/v2/task_group.py +162 -0
  12. fractal_server/app/routes/api/v1/task_collection.py +4 -4
  13. fractal_server/app/routes/api/v2/__init__.py +8 -0
  14. fractal_server/app/routes/api/v2/_aux_functions.py +1 -68
  15. fractal_server/app/routes/api/v2/_aux_functions_tasks.py +343 -0
  16. fractal_server/app/routes/api/v2/submit.py +16 -35
  17. fractal_server/app/routes/api/v2/task.py +85 -110
  18. fractal_server/app/routes/api/v2/task_collection.py +184 -196
  19. fractal_server/app/routes/api/v2/task_collection_custom.py +70 -64
  20. fractal_server/app/routes/api/v2/task_group.py +173 -0
  21. fractal_server/app/routes/api/v2/workflow.py +39 -102
  22. fractal_server/app/routes/api/v2/workflow_import.py +360 -0
  23. fractal_server/app/routes/api/v2/workflowtask.py +4 -8
  24. fractal_server/app/routes/auth/_aux_auth.py +86 -40
  25. fractal_server/app/routes/auth/current_user.py +5 -5
  26. fractal_server/app/routes/auth/group.py +73 -23
  27. fractal_server/app/routes/auth/router.py +0 -2
  28. fractal_server/app/routes/auth/users.py +8 -7
  29. fractal_server/app/runner/executors/slurm/ssh/executor.py +82 -63
  30. fractal_server/app/runner/v2/__init__.py +13 -7
  31. fractal_server/app/runner/v2/task_interface.py +4 -9
  32. fractal_server/app/schemas/user.py +1 -2
  33. fractal_server/app/schemas/v2/__init__.py +7 -0
  34. fractal_server/app/schemas/v2/dataset.py +2 -7
  35. fractal_server/app/schemas/v2/dumps.py +1 -2
  36. fractal_server/app/schemas/v2/job.py +1 -1
  37. fractal_server/app/schemas/v2/manifest.py +25 -1
  38. fractal_server/app/schemas/v2/project.py +1 -1
  39. fractal_server/app/schemas/v2/task.py +95 -36
  40. fractal_server/app/schemas/v2/task_collection.py +8 -6
  41. fractal_server/app/schemas/v2/task_group.py +85 -0
  42. fractal_server/app/schemas/v2/workflow.py +7 -2
  43. fractal_server/app/schemas/v2/workflowtask.py +9 -6
  44. fractal_server/app/security/__init__.py +8 -1
  45. fractal_server/config.py +8 -28
  46. fractal_server/data_migrations/2_7_0.py +323 -0
  47. fractal_server/images/models.py +2 -4
  48. fractal_server/main.py +1 -1
  49. fractal_server/migrations/env.py +4 -1
  50. fractal_server/migrations/versions/034a469ec2eb_task_groups.py +184 -0
  51. fractal_server/ssh/_fabric.py +186 -73
  52. fractal_server/string_tools.py +6 -2
  53. fractal_server/tasks/utils.py +19 -5
  54. fractal_server/tasks/v1/_TaskCollectPip.py +1 -1
  55. fractal_server/tasks/v1/background_operations.py +5 -5
  56. fractal_server/tasks/v1/get_collection_data.py +2 -2
  57. fractal_server/tasks/v2/_venv_pip.py +67 -70
  58. fractal_server/tasks/v2/background_operations.py +180 -69
  59. fractal_server/tasks/v2/background_operations_ssh.py +57 -70
  60. fractal_server/tasks/v2/database_operations.py +44 -0
  61. fractal_server/tasks/v2/endpoint_operations.py +104 -116
  62. fractal_server/tasks/v2/templates/_1_create_venv.sh +9 -5
  63. fractal_server/tasks/v2/templates/{_2_upgrade_pip.sh → _2_preliminary_pip_operations.sh} +1 -0
  64. fractal_server/tasks/v2/utils.py +5 -0
  65. fractal_server/utils.py +3 -2
  66. {fractal_server-2.6.3.dist-info → fractal_server-2.7.0.dist-info}/METADATA +3 -7
  67. {fractal_server-2.6.3.dist-info → fractal_server-2.7.0.dist-info}/RECORD +70 -61
  68. fractal_server/app/routes/auth/group_names.py +0 -34
  69. fractal_server/tasks/v2/_TaskCollectPip.py +0 -132
  70. {fractal_server-2.6.3.dist-info → fractal_server-2.7.0.dist-info}/LICENSE +0 -0
  71. {fractal_server-2.6.3.dist-info → fractal_server-2.7.0.dist-info}/WHEEL +0 -0
  72. {fractal_server-2.6.3.dist-info → fractal_server-2.7.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,162 @@
1
+ from typing import Optional
2
+
3
+ from fastapi import APIRouter
4
+ from fastapi import Depends
5
+ from fastapi import HTTPException
6
+ from fastapi import Response
7
+ from fastapi import status
8
+ from sqlalchemy.sql.operators import is_
9
+ from sqlalchemy.sql.operators import is_not
10
+ from sqlmodel import select
11
+
12
+ from fractal_server.app.db import AsyncSession
13
+ from fractal_server.app.db import get_async_db
14
+ from fractal_server.app.models import UserOAuth
15
+ from fractal_server.app.models.v2 import CollectionStateV2
16
+ from fractal_server.app.models.v2 import TaskGroupV2
17
+ from fractal_server.app.models.v2 import WorkflowTaskV2
18
+ from fractal_server.app.routes.auth import current_active_superuser
19
+ from fractal_server.app.routes.auth._aux_auth import (
20
+ _verify_user_belongs_to_group,
21
+ )
22
+ from fractal_server.app.schemas.v2 import TaskGroupReadV2
23
+ from fractal_server.app.schemas.v2 import TaskGroupUpdateV2
24
+ from fractal_server.app.schemas.v2 import TaskGroupV2OriginEnum
25
+ from fractal_server.logger import set_logger
26
+
27
+ router = APIRouter()
28
+
29
+ logger = set_logger(__name__)
30
+
31
+
32
@router.get("/{task_group_id}/", response_model=TaskGroupReadV2)
async def query_task_group(
    task_group_id: int,
    user: UserOAuth = Depends(current_active_superuser),
    db: AsyncSession = Depends(get_async_db),
) -> TaskGroupReadV2:
    """
    Fetch a single task group by primary key.

    Arguments:
        task_group_id: ID of the requested task group.
        user: Resolved through the `current_active_superuser` dependency;
            not used in the body.
        db: An asynchronous db session.

    Raises:
        HTTPException(status_code=404_NOT_FOUND):
            If no task group has the given ID.
    """
    found = await db.get(TaskGroupV2, task_group_id)
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"TaskGroup {task_group_id} not found",
    )
46
+
47
+
48
@router.get("/", response_model=list[TaskGroupReadV2])
async def query_task_group_list(
    user_id: Optional[int] = None,
    user_group_id: Optional[int] = None,
    private: Optional[bool] = None,
    active: Optional[bool] = None,
    pkg_name: Optional[str] = None,
    origin: Optional[TaskGroupV2OriginEnum] = None,
    user: UserOAuth = Depends(current_active_superuser),
    db: AsyncSession = Depends(get_async_db),
) -> list[TaskGroupReadV2]:
    """
    List task groups, optionally narrowed by query parameters.

    Every filter that is not `None` adds one `WHERE` condition; filters
    left at `None` are ignored.

    Arguments:
        user_id: Keep only task groups with this `user_id`.
        user_group_id: Keep only task groups with this `user_group_id`.
        private: If `True`, keep task groups whose `user_group_id` is null;
            if `False`, keep those whose `user_group_id` is not null.
        active: Keep only task groups whose `active` flag has this value.
        pkg_name: Keep task groups whose `pkg_name` contains this string
            (case-insensitive match through `icontains`).
        origin: Keep only task groups with this `origin`.
        user: Resolved through the `current_active_superuser` dependency;
            not used in the body.
        db: An asynchronous db session.

    Raises:
        HTTPException(status_code=422_UNPROCESSABLE_ENTITY):
            If `user_group_id` is set together with `private=True`.
    """
    # A private task group has no user group, so these two are incompatible.
    if private is True and user_group_id is not None:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Cannot set `user_group_id` with {private=}",
        )

    # Collect the requested conditions first, then apply them in order.
    conditions = []
    if user_id is not None:
        conditions.append(TaskGroupV2.user_id == user_id)
    if user_group_id is not None:
        conditions.append(TaskGroupV2.user_group_id == user_group_id)
    if private is not None:
        null_check = is_ if private is True else is_not
        conditions.append(null_check(TaskGroupV2.user_group_id, None))
    if active is not None:
        conditions.append(is_(TaskGroupV2.active, active is True))
    if origin is not None:
        conditions.append(TaskGroupV2.origin == origin)
    if pkg_name is not None:
        conditions.append(TaskGroupV2.pkg_name.icontains(pkg_name))

    query = select(TaskGroupV2)
    for condition in conditions:
        query = query.where(condition)

    result = await db.execute(query)
    return result.scalars().all()
89
+
90
+
91
@router.patch("/{task_group_id}/", response_model=TaskGroupReadV2)
async def patch_task_group(
    task_group_id: int,
    task_group_update: TaskGroupUpdateV2,
    user: UserOAuth = Depends(current_active_superuser),
    db: AsyncSession = Depends(get_async_db),
) -> TaskGroupReadV2:
    """
    Apply a partial update to a task group and return the updated object.

    Only attributes explicitly set in the request body are written
    (`exclude_unset=True`).

    Arguments:
        task_group_id: ID of the task group to update.
        task_group_update: The attributes to change.
        user: The current user, resolved through the
            `current_active_superuser` dependency.
        db: An asynchronous db session.

    Returns:
        The single updated task group (not a list).

    Raises:
        HTTPException(status_code=404_NOT_FOUND):
            If no task group has the given ID.
    """
    task_group = await db.get(TaskGroupV2, task_group_id)
    if task_group is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"TaskGroupV2 {task_group_id} not found",
        )

    for key, value in task_group_update.dict(exclude_unset=True).items():
        if (key == "user_group_id") and (value is not None):
            # NOTE(review): this checks membership of the *calling* user
            # (`user.id`), not of the task-group owner - confirm that this
            # is the intended check for an admin endpoint.
            await _verify_user_belongs_to_group(
                user_id=user.id, user_group_id=value, db=db
            )
        setattr(task_group, key, value)

    db.add(task_group)
    await db.commit()
    # Reload the object so that the response reflects the committed state.
    await db.refresh(task_group)
    return task_group
116
+
117
+
118
@router.delete("/{task_group_id}/", status_code=204)
async def delete_task_group(
    task_group_id: int,
    user: UserOAuth = Depends(current_active_superuser),
    db: AsyncSession = Depends(get_async_db),
):
    """
    Delete a task group, unless one of its tasks is used in a workflow.

    Before the deletion, the `taskgroupv2_id` of every `CollectionStateV2`
    pointing to this task group is set to `None`.

    Arguments:
        task_group_id: ID of the task group to delete.
        user: Resolved through the `current_active_superuser` dependency;
            not used in the body.
        db: An asynchronous db session.

    Raises:
        HTTPException(status_code=404_NOT_FOUND):
            If no task group has the given ID.
        HTTPException(status_code=422_UNPROCESSABLE_ENTITY):
            If a `WorkflowTaskV2` references one of the group's tasks.
    """
    task_group = await db.get(TaskGroupV2, task_group_id)
    if task_group is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"TaskGroupV2 {task_group_id} not found",
        )

    # Refuse to delete if any task of this group is part of a workflow.
    task_ids = {task.id for task in task_group.task_list}
    usage_query = select(WorkflowTaskV2).where(
        WorkflowTaskV2.task_id.in_(task_ids)
    )
    usage_result = await db.execute(usage_query)
    blocking_wftasks = usage_result.scalars().all()
    if blocking_wftasks:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"TaskV2 {blocking_wftasks[0].task_id} is still in use",
        )

    # Manual cascade: detach the collection states linked to this task
    # group by setting their foreign key to null.
    logger.debug("Start of cascade operations on CollectionStateV2.")
    states_query = select(CollectionStateV2).where(
        CollectionStateV2.taskgroupv2_id == task_group_id
    )
    states_result = await db.execute(states_query)
    for state in states_result.scalars().all():
        logger.debug(
            f"Setting CollectionStateV2[{state.id}].taskgroupv2_id "
            "to None."
        )
        state.taskgroupv2_id = None
        db.add(state)
    logger.debug("End of cascade operations on CollectionStateV2.")

    await db.delete(task_group)
    await db.commit()

    return Response(status_code=status.HTTP_204_NO_CONTENT)
@@ -26,8 +26,8 @@ from ._aux_functions import _raise_if_v1_is_read_only
26
26
  from fractal_server.app.models import UserOAuth
27
27
  from fractal_server.app.routes.auth import current_active_user
28
28
  from fractal_server.app.routes.auth import current_active_verified_user
29
- from fractal_server.string_tools import slugify_task_name_for_source
30
- from fractal_server.tasks.utils import get_collection_log
29
+ from fractal_server.string_tools import slugify_task_name_for_source_v1
30
+ from fractal_server.tasks.utils import get_collection_log_v1
31
31
  from fractal_server.tasks.v1._TaskCollectPip import _TaskCollectPip
32
32
  from fractal_server.tasks.v1.background_operations import (
33
33
  background_collect_pip,
@@ -160,7 +160,7 @@ async def collect_tasks_pip(
160
160
 
161
161
  # Check that tasks are not already in the DB
162
162
  for new_task in task_pkg.package_manifest.task_list:
163
- new_task_name_slug = slugify_task_name_for_source(new_task.name)
163
+ new_task_name_slug = slugify_task_name_for_source_v1(new_task.name)
164
164
  new_task_source = f"{task_pkg.package_source}:{new_task_name_slug}"
165
165
  stm = select(Task).where(Task.source == new_task_source)
166
166
  res = await db.execute(stm)
@@ -232,7 +232,7 @@ async def check_collection_status(
232
232
  # In some cases (i.e. a successful or ongoing task collection), data.log is
233
233
  # not set; if so, we collect the current logs
234
234
  if verbose and not data.log:
235
- data.log = get_collection_log(data.venv_path)
235
+ data.log = get_collection_log_v1(data.venv_path)
236
236
  state.data = data.sanitised_dict()
237
237
  close_logger(logger)
238
238
  await db.close()
@@ -12,7 +12,9 @@ from .submit import router as submit_job_router_v2
12
12
  from .task import router as task_router_v2
13
13
  from .task_collection import router as task_collection_router_v2
14
14
  from .task_collection_custom import router as task_collection_router_v2_custom
15
+ from .task_group import router as task_group_router_v2
15
16
  from .workflow import router as workflow_router_v2
17
+ from .workflow_import import router as workflow_import_router_v2
16
18
  from .workflowtask import router as workflowtask_router_v2
17
19
  from fractal_server.config import get_settings
18
20
  from fractal_server.syringe import Inject
@@ -37,6 +39,12 @@ router_api_v2.include_router(
37
39
  tags=["V2 Task Collection"],
38
40
  )
39
41
  router_api_v2.include_router(task_router_v2, prefix="/task", tags=["V2 Task"])
42
+ router_api_v2.include_router(
43
+ task_group_router_v2, prefix="/task-group", tags=["V2 TaskGroup"]
44
+ )
40
45
  router_api_v2.include_router(workflow_router_v2, tags=["V2 Workflow"])
46
+ router_api_v2.include_router(
47
+ workflow_import_router_v2, tags=["V2 Workflow Import"]
48
+ )
41
49
  router_api_v2.include_router(workflowtask_router_v2, tags=["V2 WorkflowTask"])
42
50
  router_api_v2.include_router(status_router_v2, tags=["V2 Status"])
@@ -21,8 +21,6 @@ from ....models.v2 import TaskV2
21
21
  from ....models.v2 import WorkflowTaskV2
22
22
  from ....models.v2 import WorkflowV2
23
23
  from ....schemas.v2 import JobStatusTypeV2
24
- from ...aux.validate_user_settings import verify_user_has_settings
25
- from fractal_server.app.models import UserOAuth
26
24
  from fractal_server.images import Filters
27
25
 
28
26
 
@@ -320,65 +318,6 @@ async def _get_job_check_owner(
320
318
  return dict(job=job, project=project)
321
319
 
322
320
 
323
- async def _get_task_check_owner(
324
- *,
325
- task_id: int,
326
- user: UserOAuth,
327
- db: AsyncSession,
328
- ) -> TaskV2:
329
- """
330
- Get a task, after access control.
331
-
332
- This check constitutes a preliminary version of access control:
333
- if the current user is not a superuser and differs from the task owner
334
- (including when `owner is None`), we raise an 403 HTTP Exception.
335
-
336
- Args:
337
- task_id:
338
- user:
339
- db:
340
-
341
- Returns:
342
- The task object.
343
-
344
- Raises:
345
- HTTPException(status_code=404_NOT_FOUND):
346
- If the task does not exist
347
- HTTPException(status_code=403_FORBIDDEN):
348
- If the user does not have rights to edit this task.
349
- """
350
- task = await db.get(TaskV2, task_id)
351
- if not task:
352
- raise HTTPException(
353
- status_code=status.HTTP_404_NOT_FOUND,
354
- detail=f"TaskV2 {task_id} not found.",
355
- )
356
-
357
- if not user.is_superuser:
358
- if task.owner is None:
359
- raise HTTPException(
360
- status_code=status.HTTP_403_FORBIDDEN,
361
- detail=(
362
- "Only a superuser can modify a TaskV2 with `owner=None`."
363
- ),
364
- )
365
- else:
366
- if user.username:
367
- owner = user.username
368
- else:
369
- verify_user_has_settings(user)
370
- owner = user.settings.slurm_user
371
- if owner != task.owner:
372
- raise HTTPException(
373
- status_code=status.HTTP_403_FORBIDDEN,
374
- detail=(
375
- f"Current user ({owner}) cannot modify TaskV2 "
376
- f"{task.id} with different owner ({task.owner})."
377
- ),
378
- )
379
- return task
380
-
381
-
382
321
  def _get_submitted_jobs_statement() -> SelectOfScalar:
383
322
  """
384
323
  Returns:
@@ -393,7 +332,6 @@ async def _workflow_insert_task(
393
332
  *,
394
333
  workflow_id: int,
395
334
  task_id: int,
396
- order: Optional[int] = None,
397
335
  meta_parallel: Optional[dict[str, Any]] = None,
398
336
  meta_non_parallel: Optional[dict[str, Any]] = None,
399
337
  args_non_parallel: Optional[dict[str, Any]] = None,
@@ -408,7 +346,6 @@ async def _workflow_insert_task(
408
346
  workflow_id:
409
347
  task_id:
410
348
 
411
- order:
412
349
  meta_parallel:
413
350
  meta_non_parallel:
414
351
  args_non_parallel:
@@ -420,9 +357,6 @@ async def _workflow_insert_task(
420
357
  if db_workflow is None:
421
358
  raise ValueError(f"Workflow {workflow_id} does not exist")
422
359
 
423
- if order is None:
424
- order = len(db_workflow.task_list)
425
-
426
360
  # Get task from db
427
361
  db_task = await db.get(TaskV2, task_id)
428
362
  if db_task is None:
@@ -458,8 +392,7 @@ async def _workflow_insert_task(
458
392
  meta_non_parallel=final_meta_non_parallel,
459
393
  **input_filters_kwarg,
460
394
  )
461
- db_workflow.task_list.insert(order, wf_task)
462
- db_workflow.task_list.reorder() # type: ignore
395
+ db_workflow.task_list.append(wf_task)
463
396
  flag_modified(db_workflow, "task_list")
464
397
  await db.commit()
465
398
 
@@ -0,0 +1,343 @@
1
+ """
2
+ Auxiliary functions to get task and task-group object from the database or
3
+ perform simple checks
4
+ """
5
+ from typing import Any
6
+ from typing import Optional
7
+
8
+ from fastapi import HTTPException
9
+ from fastapi import status
10
+ from sqlmodel import select
11
+
12
+ from fractal_server.app.db import AsyncSession
13
+ from fractal_server.app.models import LinkUserGroup
14
+ from fractal_server.app.models import UserGroup
15
+ from fractal_server.app.models import UserOAuth
16
+ from fractal_server.app.models.v2 import CollectionStateV2
17
+ from fractal_server.app.models.v2 import TaskGroupV2
18
+ from fractal_server.app.models.v2 import TaskV2
19
+ from fractal_server.app.models.v2 import WorkflowTaskV2
20
+ from fractal_server.app.routes.auth._aux_auth import _get_default_usergroup_id
21
+ from fractal_server.app.routes.auth._aux_auth import (
22
+ _verify_user_belongs_to_group,
23
+ )
24
+ from fractal_server.logger import set_logger
25
+
26
+ logger = set_logger(__name__)
27
+
28
+
29
async def _get_task_group_or_404(
    *, task_group_id: int, db: AsyncSession
) -> TaskGroupV2:
    """
    Look up a task group by ID, failing with a 404 if it does not exist.

    Arguments:
        task_group_id: The TaskGroupV2 id
        db: An asynchronous db session

    Raises:
        HTTPException(status_code=404_NOT_FOUND):
            If no task group has the given ID.
    """
    found = await db.get(TaskGroupV2, task_group_id)
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"TaskGroupV2 {task_group_id} not found",
    )
46
+
47
+
48
async def _get_task_group_read_access(
    *,
    task_group_id: int,
    user_id: int,
    db: AsyncSession,
) -> TaskGroupV2:
    """
    Get a task group or raise a 403 if user has no read access.

    Read access is granted when the user owns the task group, or when the
    task group has a user group and a `LinkUserGroup` row links the user
    to that group.

    Arguments:
        task_group_id: ID of the required task group.
        user_id: ID of the current user.
        db: An asynchronous db session.

    Raises:
        HTTPException(status_code=404_NOT_FOUND):
            If the task group does not exist (raised by
            `_get_task_group_or_404`).
        HTTPException(status_code=403_FORBIDDEN):
            If the user has no read access to the task group.
    """
    task_group = await _get_task_group_or_404(
        task_group_id=task_group_id, db=db
    )

    # Prepare exception to be used below. `detail` must be a plain string:
    # a trailing comma inside the parentheses would turn it into a tuple.
    forbidden_exception = HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=(
            "Current user has no read access to TaskGroupV2 "
            f"{task_group_id}."
        ),
    )

    # The owner can always read the task group.
    if task_group.user_id == user_id:
        return task_group

    # A task group without a user group is only readable by its owner.
    if task_group.user_group_id is None:
        raise forbidden_exception

    # Otherwise the user must be linked to the task group's user group.
    stm = (
        select(LinkUserGroup)
        .where(LinkUserGroup.group_id == task_group.user_group_id)
        .where(LinkUserGroup.user_id == user_id)
    )
    res = await db.execute(stm)
    link = res.scalar_one_or_none()
    if link is None:
        raise forbidden_exception
    return task_group
91
+
92
+
93
async def _get_task_group_full_access(
    *,
    task_group_id: int,
    user_id: int,
    db: AsyncSession,
) -> TaskGroupV2:
    """
    Get a task group or raise a 403 if user has no full access.

    Full access is granted only when the task group's `user_id` matches
    the current user.

    Arguments:
        task_group_id: ID of the required task group.
        user_id: ID of the current user.
        db: An asynchronous db session

    Raises:
        HTTPException(status_code=404_NOT_FOUND):
            If the task group does not exist (raised by
            `_get_task_group_or_404`).
        HTTPException(status_code=403_FORBIDDEN):
            If the user is not the owner of the task group.
    """
    task_group = await _get_task_group_or_404(
        task_group_id=task_group_id, db=db
    )

    if task_group.user_id == user_id:
        return task_group

    # `detail` must be a plain string: a trailing comma inside the
    # parentheses would turn it into a tuple.
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=(
            "Current user has no full access to "
            f"TaskGroupV2 {task_group_id}."
        ),
    )
121
+
122
+
123
async def _get_task_or_404(*, task_id: int, db: AsyncSession) -> TaskV2:
    """
    Look up a task by ID, failing with a 404 if it does not exist.

    Arguments:
        task_id: ID of the required task.
        db: An asynchronous db session

    Raises:
        HTTPException(status_code=404_NOT_FOUND):
            If no task has the given ID.
    """
    found = await db.get(TaskV2, task_id)
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"TaskV2 {task_id} not found",
    )
138
+
139
+
140
async def _get_task_full_access(
    *,
    task_id: int,
    user_id: int,
    db: AsyncSession,
) -> TaskV2:
    """
    Get a task, after checking full access to its task group.

    The task is fetched with `_get_task_or_404`, then
    `_get_task_group_full_access` is called on the task's
    `taskgroupv2_id`; any exception from either call propagates.

    Arguments:
        task_id: ID of the required task.
        user_id: ID of the current user.
        db: An asynchronous db session.
    """
    task = await _get_task_or_404(task_id=task_id, db=db)
    # Called only for its access check; the returned group is not needed.
    await _get_task_group_full_access(
        task_group_id=task.taskgroupv2_id,
        user_id=user_id,
        db=db,
    )
    return task
159
+
160
+
161
async def _get_task_read_access(
    *,
    task_id: int,
    user_id: int,
    db: AsyncSession,
    require_active: bool = False,
) -> TaskV2:
    """
    Get a task, after checking read access to its task group.

    The task is fetched with `_get_task_or_404`, then
    `_get_task_group_read_access` is called on the task's
    `taskgroupv2_id`; any exception from either call propagates.

    Arguments:
        task_id: ID of the required task.
        user_id: ID of the current user.
        db: An asynchronous db session.
        require_active: If set, fail when the task group is not `active`

    Raises:
        HTTPException(status_code=422_UNPROCESSABLE_ENTITY):
            If `require_active` is set and the task group is not active.
    """
    task = await _get_task_or_404(task_id=task_id, db=db)
    group = await _get_task_group_read_access(
        task_group_id=task.taskgroupv2_id,
        user_id=user_id,
        db=db,
    )
    if require_active and not group.active:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Error: task {task_id} ({task.name}) is not active.",
        )
    return task
188
+
189
+
190
async def _get_valid_user_group_id(
    *,
    user_group_id: Optional[int] = None,
    private: bool,
    user_id: int,
    db: AsyncSession,
) -> Optional[int]:
    """
    Resolve the user-group ID to use for endpoints that create some task(s).

    Arguments:
        user_group_id: Requested user-group ID, if any.
        private: Whether the new task(s) should have no user group.
        user_id: ID of the current user
        db: An asynchronous db session.

    Returns:
        `None` if `private` is `True`; the result of
        `_get_default_usergroup_id` if no `user_group_id` was given;
        otherwise `user_group_id`, after `_verify_user_belongs_to_group`
        has been called for the current user.

    Raises:
        HTTPException(status_code=422_UNPROCESSABLE_ENTITY):
            If `user_group_id` is set together with `private=True`.
    """
    if private is True:
        # A private task has no user group, so an explicit one is rejected.
        if user_group_id is not None:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Cannot set both {user_group_id=} and {private=}",
            )
        return None

    if user_group_id is None:
        return await _get_default_usergroup_id(db=db)

    await _verify_user_belongs_to_group(
        user_id=user_id, user_group_id=user_group_id, db=db
    )
    return user_group_id
220
+
221
+
222
async def _get_collection_status_message(
    task_group: TaskGroupV2, db: AsyncSession
) -> str:
    """
    Build a message about the collection state(s) linked to a task group.

    Arguments:
        task_group: The task group whose collection states are looked up.
        db: An asynchronous db session.

    Returns:
        An empty string if no `CollectionStateV2` points to the task group,
        a description of its status if there is exactly one, and a warning
        listing all their IDs if there are several.
    """
    query = select(CollectionStateV2).where(
        CollectionStateV2.taskgroupv2_id == task_group.id
    )
    result = await db.execute(query)
    states = result.scalars().all()

    if not states:
        return ""

    if len(states) == 1:
        only_state = states[0]
        return (
            f"\nThere exists a task-collection state (ID={only_state.id}) for "
            f"such task group (ID={task_group.id}), with status "
            f"'{only_state.data.get('status')}'."
        )

    # More than one state for the same task group is unexpected.
    return (
        "Expected one CollectionStateV2 associated to TaskGroup "
        f"{task_group.id}, found {len(states)} "
        f"(IDs: {[state.id for state in states]}).\n"
        "Warning: this should have not happened, please contact an admin."
    )
247
+
248
+
249
async def _verify_non_duplication_user_constraint(
    db: AsyncSession,
    user_id: int,
    pkg_name: str,
    version: Optional[str],
):
    """
    Fail if the user already owns a task group with this name and version.

    Arguments:
        db: An asynchronous db session.
        user_id: ID of the user that would own the new task group.
        pkg_name: Package name of the new task group.
        version: Version of the new task group.

    Raises:
        HTTPException(status_code=422_UNPROCESSABLE_ENTITY):
            If one or more task groups with the same `user_id`, `pkg_name`
            and `version` exist.
    """
    query = (
        select(TaskGroupV2)
        .where(TaskGroupV2.user_id == user_id)
        .where(TaskGroupV2.pkg_name == pkg_name)
        .where(TaskGroupV2.version == version)
    )
    result = await db.execute(query)
    matches = result.scalars().all()
    if not matches:
        return

    user = await db.get(UserOAuth, user_id)

    # Several matches mean the constraint was already violated earlier.
    if len(matches) > 1:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                "Invalid state:\n"
                f"User '{user.email}' already owns {len(matches)} task "
                f"groups with name='{pkg_name}' and {version=} "
                f"(IDs: {[group.id for group in matches]}).\n"
                "This should have not happened: please contact an admin."
            ),
        )

    state_msg = await _get_collection_status_message(matches[0], db)
    raise HTTPException(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        detail=(
            f"User '{user.email}' already owns a task group "
            f"with name='{pkg_name}' and {version=}.{state_msg}"
        ),
    )
284
+
285
+
286
async def _verify_non_duplication_group_constraint(
    db: AsyncSession,
    user_group_id: Optional[int],
    pkg_name: str,
    version: Optional[str],
):
    """
    Fail if the user group already has a task group with this name/version.

    Nothing is checked when `user_group_id` is `None`.

    Arguments:
        db: An asynchronous db session.
        user_group_id: ID of the user group of the new task group, if any.
        pkg_name: Package name of the new task group.
        version: Version of the new task group.

    Raises:
        HTTPException(status_code=422_UNPROCESSABLE_ENTITY):
            If one or more task groups with the same `user_group_id`,
            `pkg_name` and `version` exist.
    """
    if user_group_id is None:
        return

    query = (
        select(TaskGroupV2)
        .where(TaskGroupV2.user_group_id == user_group_id)
        .where(TaskGroupV2.pkg_name == pkg_name)
        .where(TaskGroupV2.version == version)
    )
    result = await db.execute(query)
    matches = result.scalars().all()
    if not matches:
        return

    user_group = await db.get(UserGroup, user_group_id)

    # Several matches mean the constraint was already violated earlier.
    if len(matches) > 1:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                "Invalid state:\n"
                f"UserGroup '{user_group.name}' already owns "
                f"{len(matches)} task groups with name='{pkg_name}' and "
                f"{version=} (IDs: {[group.id for group in matches]}).\n"
                "This should have not happened: please contact an admin."
            ),
        )

    state_msg = await _get_collection_status_message(matches[0], db)
    raise HTTPException(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        detail=(
            f"UserGroup {user_group.name} already owns a task group "
            f"with {pkg_name=} and {version=}.{state_msg}"
        ),
    )
324
+
325
+
326
async def _add_warnings_to_workflow_tasks(
    wftask_list: list[WorkflowTaskV2], user_id: int, db: AsyncSession
) -> list[dict[str, Any]]:
    """
    Dump workflow tasks to dictionaries, flagging problematic tasks.

    Each dictionary holds the dumped workflow task plus its `task`. A
    `"warning"` key is added when `_get_task_group_read_access` raises an
    `HTTPException` for the task's group, or when that group is not
    active.

    Arguments:
        wftask_list: The workflow tasks to process.
        user_id: ID of the current user.
        db: An asynchronous db session.

    Returns:
        One dictionary per workflow task, in the input order.
    """
    annotated = []
    for wftask in wftask_list:
        entry = {**wftask.model_dump(), "task": wftask.task}
        try:
            group = await _get_task_group_read_access(
                task_group_id=wftask.task.taskgroupv2_id,
                user_id=user_id,
                db=db,
            )
        except HTTPException:
            entry["warning"] = "Current user has no access to this task."
        else:
            if not group.active:
                entry["warning"] = "Task is not active."
        annotated.append(entry)
    return annotated