fractal-server 2.6.4__py3-none-any.whl → 2.7.0a0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (28)
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/app/models/v2/__init__.py +2 -0
  3. fractal_server/app/models/v2/task.py +27 -0
  4. fractal_server/app/routes/api/v2/__init__.py +4 -0
  5. fractal_server/app/routes/api/v2/_aux_functions.py +0 -61
  6. fractal_server/app/routes/api/v2/_aux_functions_tasks.py +209 -0
  7. fractal_server/app/routes/api/v2/submit.py +16 -3
  8. fractal_server/app/routes/api/v2/task.py +59 -72
  9. fractal_server/app/routes/api/v2/task_collection.py +20 -4
  10. fractal_server/app/routes/api/v2/task_collection_custom.py +44 -18
  11. fractal_server/app/routes/api/v2/task_group.py +130 -0
  12. fractal_server/app/routes/api/v2/workflow.py +24 -3
  13. fractal_server/app/routes/api/v2/workflowtask.py +4 -7
  14. fractal_server/app/routes/auth/_aux_auth.py +42 -0
  15. fractal_server/app/schemas/v2/__init__.py +5 -0
  16. fractal_server/app/schemas/v2/task.py +2 -1
  17. fractal_server/app/schemas/v2/task_group.py +23 -0
  18. fractal_server/app/schemas/v2/workflow.py +5 -0
  19. fractal_server/app/schemas/v2/workflowtask.py +4 -0
  20. fractal_server/migrations/versions/7cf1baae8fb4_task_group_v2.py +66 -0
  21. fractal_server/tasks/v2/background_operations.py +16 -35
  22. fractal_server/tasks/v2/background_operations_ssh.py +15 -2
  23. fractal_server/tasks/v2/database_operations.py +54 -0
  24. {fractal_server-2.6.4.dist-info → fractal_server-2.7.0a0.dist-info}/METADATA +1 -1
  25. {fractal_server-2.6.4.dist-info → fractal_server-2.7.0a0.dist-info}/RECORD +28 -23
  26. {fractal_server-2.6.4.dist-info → fractal_server-2.7.0a0.dist-info}/LICENSE +0 -0
  27. {fractal_server-2.6.4.dist-info → fractal_server-2.7.0a0.dist-info}/WHEEL +0 -0
  28. {fractal_server-2.6.4.dist-info → fractal_server-2.7.0a0.dist-info}/entry_points.txt +0 -0
fractal_server/app/routes/api/v2/task_collection_custom.py
@@ -1,31 +1,40 @@
  import shlex
  import subprocess  # nosec
  from pathlib import Path
+ from typing import Optional

  from fastapi import APIRouter
  from fastapi import Depends
  from fastapi import HTTPException
  from fastapi import status
+ from sqlalchemy.ext.asyncio import AsyncSession
  from sqlmodel import select

- from .....config import get_settings
- from .....logger import set_logger
- from .....syringe import Inject
- from ....db import DBSyncSession
- from ....db import get_sync_db
- from ....models.v1 import Task as TaskV1
- from ....models.v2 import TaskV2
- from ....schemas.v2 import TaskCollectCustomV2
- from ....schemas.v2 import TaskCreateV2
- from ....schemas.v2 import TaskReadV2
- from ...aux.validate_user_settings import verify_user_has_settings
+ from ._aux_functions_tasks import _get_valid_user_group_id
+ from fractal_server.app.db import DBSyncSession
+ from fractal_server.app.db import get_async_db
+ from fractal_server.app.db import get_sync_db
  from fractal_server.app.models import UserOAuth
+ from fractal_server.app.models.v1 import Task as TaskV1
+ from fractal_server.app.models.v2 import TaskV2
  from fractal_server.app.routes.auth import current_active_verified_user
+ from fractal_server.app.routes.aux.validate_user_settings import (
+     verify_user_has_settings,
+ )
+ from fractal_server.app.schemas.v2 import TaskCollectCustomV2
+ from fractal_server.app.schemas.v2 import TaskCreateV2
+ from fractal_server.app.schemas.v2 import TaskGroupCreateV2
+ from fractal_server.app.schemas.v2 import TaskReadV2
+ from fractal_server.config import get_settings
+ from fractal_server.logger import set_logger
  from fractal_server.string_tools import validate_cmd
- from fractal_server.tasks.v2.background_operations import _insert_tasks
+ from fractal_server.syringe import Inject
  from fractal_server.tasks.v2.background_operations import (
      _prepare_tasks_metadata,
  )
+ from fractal_server.tasks.v2.database_operations import (
+     create_db_task_group_and_tasks,
+ )

  router = APIRouter()

@@ -37,12 +46,25 @@ logger = set_logger(__name__)
  )
  async def collect_task_custom(
      task_collect: TaskCollectCustomV2,
+     private: bool = False,
+     user_group_id: Optional[int] = None,
      user: UserOAuth = Depends(current_active_verified_user),
-     db: DBSyncSession = Depends(get_sync_db),
+     db: AsyncSession = Depends(get_async_db),  # FIXME: using both sync/async
+     db_sync: DBSyncSession = Depends(
+         get_sync_db
+     ),  # FIXME: using both sync/async
  ) -> list[TaskReadV2]:

      settings = Inject(get_settings)

+     # Validate query parameters related to user-group ownership
+     user_group_id = await _get_valid_user_group_id(
+         user_group_id=user_group_id,
+         private=private,
+         user_id=user.id,
+         db=db,
+     )
+
      if settings.FRACTAL_RUNNER_BACKEND == "slurm_ssh":
          if task_collect.package_root is None:
              raise HTTPException(
@@ -140,7 +162,7 @@ async def collect_task_custom(
      # already guaranteed by a constraint in the table definition).
      sources = [task.source for task in task_list]
      stm = select(TaskV2).where(TaskV2.source.in_(sources))
-     res = db.execute(stm)
+     res = db_sync.execute(stm)
      overlapping_sources_v2 = res.scalars().all()
      if overlapping_sources_v2:
          overlapping_tasks_v2_source_and_id = [
@@ -152,7 +174,7 @@ async def collect_task_custom(
              detail="\n".join(overlapping_tasks_v2_source_and_id),
          )
      stm = select(TaskV1).where(TaskV1.source.in_(sources))
-     res = db.execute(stm)
+     res = db_sync.execute(stm)
      overlapping_sources_v1 = res.scalars().all()
      if overlapping_sources_v1:
          overlapping_tasks_v1_source_and_id = [
@@ -164,8 +186,12 @@ async def collect_task_custom(
              detail="\n".join(overlapping_tasks_v1_source_and_id),
          )

-     task_list_db: list[TaskV2] = _insert_tasks(
-         task_list=task_list, owner=owner, db=db
+     task_group = create_db_task_group_and_tasks(
+         task_list=task_list,
+         task_group_obj=TaskGroupCreateV2(),
+         user_id=user.id,
+         user_group_id=user_group_id,
+         db=db_sync,
      )

      logger.debug(
@@ -173,4 +199,4 @@ async def collect_task_custom(
          f"for package with {source=}"
      )

-     return task_list_db
+     return task_group.task_list
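Note: the new helper module fractal_server/app/routes/api/v2/_aux_functions_tasks.py (+209 lines) is not shown in this excerpt. Judging only from the call site above and from the _aux_auth.py helpers added further down, _get_valid_user_group_id plausibly resolves the private/user_group_id query parameters along the lines of the sketch below; the exact branching and error messages are assumptions, not the actual implementation.

# Illustrative sketch only; the real helper lives in the (not shown)
# fractal_server/app/routes/api/v2/_aux_functions_tasks.py module.
from typing import Optional

from fastapi import HTTPException
from fastapi import status

from fractal_server.app.db import AsyncSession
from fractal_server.app.routes.auth._aux_auth import (
    _get_default_user_group_id,
    _verify_user_belongs_to_group,
)


async def _get_valid_user_group_id(
    *,
    user_group_id: Optional[int] = None,
    private: bool,
    user_id: int,
    db: AsyncSession,
) -> Optional[int]:
    """Resolve and validate the private/user_group_id query parameters."""
    if private:
        if user_group_id is not None:
            # A private task group cannot also be linked to a user group.
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="Cannot set both private=True and user_group_id.",
            )
        return None
    elif user_group_id is None:
        # Fall back to the default user group (see _aux_auth.py below).
        return await _get_default_user_group_id(db)
    else:
        # Make sure the requesting user belongs to the requested group.
        await _verify_user_belongs_to_group(
            user_id=user_id, user_group_id=user_group_id, db=db
        )
        return user_group_id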
fractal_server/app/routes/api/v2/task_group.py
@@ -0,0 +1,130 @@
+ from fastapi import APIRouter
+ from fastapi import Depends
+ from fastapi import HTTPException
+ from fastapi import Response
+ from fastapi import status
+ from sqlmodel import or_
+ from sqlmodel import select
+
+ from ._aux_functions_tasks import _get_task_group_full_access
+ from ._aux_functions_tasks import _get_task_group_read_access
+ from fractal_server.app.db import AsyncSession
+ from fractal_server.app.db import get_async_db
+ from fractal_server.app.models import LinkUserGroup
+ from fractal_server.app.models import UserOAuth
+ from fractal_server.app.models.v2 import TaskGroupV2
+ from fractal_server.app.models.v2 import WorkflowTaskV2
+ from fractal_server.app.routes.auth import current_active_user
+ from fractal_server.app.routes.auth._aux_auth import (
+     _verify_user_belongs_to_group,
+ )
+ from fractal_server.app.schemas.v2 import TaskGroupReadV2
+ from fractal_server.app.schemas.v2 import TaskGroupUpdateV2
+ from fractal_server.logger import set_logger
+
+ router = APIRouter()
+
+ logger = set_logger(__name__)
+
+
+ @router.get("/", response_model=list[TaskGroupReadV2])
+ async def get_task_group_list(
+     user: UserOAuth = Depends(current_active_user),
+     db: AsyncSession = Depends(get_async_db),
+ ) -> list[TaskGroupReadV2]:
+     """
+     Get all accessible TaskGroups
+     """
+     stm = select(TaskGroupV2).where(
+         or_(
+             TaskGroupV2.user_id == user.id,
+             TaskGroupV2.user_group_id.in_(
+                 select(LinkUserGroup.group_id).where(
+                     LinkUserGroup.user_id == user.id
+                 )
+             ),
+         )
+     )
+     res = await db.execute(stm)
+     task_groups = res.scalars().all()
+
+     return task_groups
+
+
+ @router.get("/{task_group_id}/", response_model=TaskGroupReadV2)
+ async def get_task_group(
+     task_group_id: int,
+     user: UserOAuth = Depends(current_active_user),
+     db: AsyncSession = Depends(get_async_db),
+ ) -> TaskGroupReadV2:
+     """
+     Get single TaskGroup
+     """
+     task_group = await _get_task_group_read_access(
+         task_group_id=task_group_id,
+         user_id=user.id,
+         db=db,
+     )
+     return task_group
+
+
+ @router.delete("/{task_group_id}/", status_code=204)
+ async def delete_task_group(
+     task_group_id: int,
+     user: UserOAuth = Depends(current_active_user),
+     db: AsyncSession = Depends(get_async_db),
+ ):
+     """
+     Delete single TaskGroup
+     """
+
+     task_group = await _get_task_group_full_access(
+         task_group_id=task_group_id,
+         user_id=user.id,
+         db=db,
+     )
+
+     stm = select(WorkflowTaskV2).where(
+         WorkflowTaskV2.task_id.in_({task.id for task in task_group.task_list})
+     )
+     res = await db.execute(stm)
+     workflow_tasks = res.scalars().all()
+     if workflow_tasks != []:
+         raise HTTPException(
+             status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+             detail=f"TaskV2 {workflow_tasks[0].task_id} is still in use",
+         )
+
+     await db.delete(task_group)
+     await db.commit()
+
+     return Response(status_code=status.HTTP_204_NO_CONTENT)
+
+
+ @router.patch("/{task_group_id}/", response_model=TaskGroupReadV2)
+ async def patch_task_group(
+     task_group_id: int,
+     task_group_update: TaskGroupUpdateV2,
+     user: UserOAuth = Depends(current_active_user),
+     db: AsyncSession = Depends(get_async_db),
+ ) -> TaskGroupReadV2:
+     """
+     Patch single TaskGroup
+     """
+     task_group = await _get_task_group_full_access(
+         task_group_id=task_group_id,
+         user_id=user.id,
+         db=db,
+     )
+
+     for key, value in task_group_update.dict(exclude_unset=True).items():
+         if (key == "user_group_id") and (value is not None):
+             await _verify_user_belongs_to_group(
+                 user_id=user.id, user_group_id=value, db=db
+             )
+         setattr(task_group, key, value)
+
+     db.add(task_group)
+     await db.commit()
+     await db.refresh(task_group)
+     return task_group
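For orientation, a minimal client-side sketch of the new task-group endpoints follows. The mount prefix (/api/v2/task-group) is an assumption, since the corresponding include_router call in fractal_server/app/routes/api/v2/__init__.py (+4 lines) is not shown here, and authentication is reduced to a placeholder bearer token.

# Hypothetical client-side usage of the new endpoints (prefix assumed).
import httpx

BASE = "http://localhost:8000/api/v2/task-group"  # prefix is an assumption
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder credentials

with httpx.Client(headers=HEADERS) as client:
    # List every task group the user can read (owned, or shared via a group).
    task_groups = client.get(f"{BASE}/").json()

    if task_groups:
        tg_id = task_groups[0]["id"]
        # Read a single task group.
        client.get(f"{BASE}/{tg_id}/").raise_for_status()
        # Deactivate it, or re-assign it to another user group.
        client.patch(f"{BASE}/{tg_id}/", json={"active": False})
        # Delete it (422 if any WorkflowTaskV2 still uses one of its tasks).
        client.delete(f"{BASE}/{tg_id}/")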
fractal_server/app/routes/api/v2/workflow.py
@@ -19,6 +19,7 @@ from ....schemas.v2 import WorkflowCreateV2
  from ....schemas.v2 import WorkflowExportV2
  from ....schemas.v2 import WorkflowImportV2
  from ....schemas.v2 import WorkflowReadV2
+ from ....schemas.v2 import WorkflowReadV2WithWarnings
  from ....schemas.v2 import WorkflowTaskCreateV2
  from ....schemas.v2 import WorkflowUpdateV2
  from ._aux_functions import _check_workflow_exists
@@ -26,6 +27,7 @@ from ._aux_functions import _get_project_check_owner
  from ._aux_functions import _get_submitted_jobs_statement
  from ._aux_functions import _get_workflow_check_owner
  from ._aux_functions import _workflow_insert_task
+ from ._aux_functions_tasks import _get_task_group_read_access
  from fractal_server.app.models import UserOAuth
  from fractal_server.app.routes.auth import current_active_user

@@ -89,14 +91,14 @@ async def create_workflow(

  @router.get(
      "/project/{project_id}/workflow/{workflow_id}/",
-     response_model=WorkflowReadV2,
+     response_model=WorkflowReadV2WithWarnings,
  )
  async def read_workflow(
      project_id: int,
      workflow_id: int,
      user: UserOAuth = Depends(current_active_user),
      db: AsyncSession = Depends(get_async_db),
- ) -> Optional[WorkflowReadV2]:
+ ) -> Optional[WorkflowReadV2WithWarnings]:
      """
      Get info on an existing workflow
      """
@@ -108,7 +110,26 @@ async def read_workflow(
          db=db,
      )

-     return workflow
+     workflow_data = dict(
+         **workflow.model_dump(),
+         project=workflow.project,
+     )
+     task_list_with_warnings = []
+     for wftask in workflow.task_list:
+         wftask_data = dict(wftask.model_dump(), task=wftask.task)
+         try:
+             task_group = await _get_task_group_read_access(
+                 task_group_id=wftask.task.taskgroupv2_id,
+                 user_id=user.id,
+                 db=db,
+             )
+             if not task_group.active:
+                 wftask_data["warning"] = "Task is not active."
+         except HTTPException:
+             wftask_data["warning"] = "Current user has no access to this task."
+         task_list_with_warnings.append(wftask_data)
+     workflow_data["task_list"] = task_list_with_warnings
+     return workflow_data


  @router.patch(
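With the response model switched to WorkflowReadV2WithWarnings, each element of task_list in the GET response may now carry a warning string. A trimmed, illustrative payload (all values are made up):

# Illustrative, trimmed response body (values are made up):
{
    "id": 1,
    "name": "example-workflow",
    "task_list": [
        {"id": 10, "task": {"id": 7, "name": "task-a"}, "warning": None},
        {
            "id": 11,
            "task": {"id": 8, "name": "task-b"},
            "warning": "Task is not active.",
        },
    ],
}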
fractal_server/app/routes/api/v2/workflowtask.py
@@ -9,13 +9,13 @@ from fastapi import status

  from ....db import AsyncSession
  from ....db import get_async_db
- from ....models.v2 import TaskV2
  from ....schemas.v2 import WorkflowTaskCreateV2
  from ....schemas.v2 import WorkflowTaskReadV2
  from ....schemas.v2 import WorkflowTaskUpdateV2
  from ._aux_functions import _get_workflow_check_owner
  from ._aux_functions import _get_workflow_task_check_owner
  from ._aux_functions import _workflow_insert_task
+ from ._aux_functions_tasks import _get_task_read_access
  from fractal_server.app.models import UserOAuth
  from fractal_server.app.routes.auth import current_active_user

@@ -43,12 +43,9 @@ async def create_workflowtask(
          project_id=project_id, workflow_id=workflow_id, user_id=user.id, db=db
      )

-     task = await db.get(TaskV2, task_id)
-     if not task:
-         raise HTTPException(
-             status_code=status.HTTP_404_NOT_FOUND,
-             detail=f"TaskV2 {task_id} not found.",
-         )
+     task = await _get_task_read_access(
+         task_id=task_id, user_id=user.id, db=db, require_active=True
+     )

      if task.type == "parallel":
          if (
fractal_server/app/routes/auth/_aux_auth.py
@@ -8,6 +8,7 @@ from fractal_server.app.models.security import UserGroup
  from fractal_server.app.models.security import UserOAuth
  from fractal_server.app.schemas.user import UserRead
  from fractal_server.app.schemas.user_group import UserGroupRead
+ from fractal_server.app.security import FRACTAL_DEFAULT_GROUP_NAME


  async def _get_single_user_with_group_names(
@@ -111,3 +112,44 @@ async def _user_or_404(user_id: int, db: AsyncSession) -> UserOAuth:
          status_code=status.HTTP_404_NOT_FOUND, detail="User not found."
      )
      return user
+
+
+ async def _get_default_user_group_id(db: AsyncSession) -> int:
+     stm = select(UserGroup.id).where(
+         UserGroup.name == FRACTAL_DEFAULT_GROUP_NAME
+     )
+     res = await db.execute(stm)
+     user_group_id = res.scalars().one_or_none()
+     if user_group_id is None:
+         raise HTTPException(
+             status_code=status.HTTP_404_NOT_FOUND,
+             detail=f"User group '{FRACTAL_DEFAULT_GROUP_NAME}' not found.",
+         )
+     return user_group_id
+
+
+ async def _verify_user_belongs_to_group(
+     *, user_id: int, user_group_id: int, db: AsyncSession
+ ):
+     stm = (
+         select(LinkUserGroup)
+         .where(LinkUserGroup.user_id == user_id)
+         .where(LinkUserGroup.group_id == user_group_id)
+     )
+     res = await db.execute(stm)
+     link = res.scalars().one_or_none()
+     if link is None:
+         group = await db.get(UserGroup, user_group_id)
+         if group is None:
+             raise HTTPException(
+                 status_code=status.HTTP_404_NOT_FOUND,
+                 detail=f"UserGroup {user_group_id} not found",
+             )
+         else:
+             raise HTTPException(
+                 status_code=status.HTTP_403_FORBIDDEN,
+                 detail=(
+                     f"User {user_id} does not belong "
+                     f"to UserGroup {user_group_id}"
+                 ),
+             )
fractal_server/app/schemas/v2/__init__.py
@@ -26,14 +26,19 @@ from .task_collection import CollectionStateReadV2  # noqa F401
  from .task_collection import CollectionStatusV2  # noqa F401
  from .task_collection import TaskCollectCustomV2  # noqa F401
  from .task_collection import TaskCollectPipV2  # noqa F401
+ from .task_group import TaskGroupCreateV2  # noqa F401
+ from .task_group import TaskGroupReadV2  # noqa F401
+ from .task_group import TaskGroupUpdateV2  # noqa F401
  from .workflow import WorkflowCreateV2  # noqa F401
  from .workflow import WorkflowExportV2  # noqa F401
  from .workflow import WorkflowImportV2  # noqa F401
  from .workflow import WorkflowReadV2  # noqa F401
+ from .workflow import WorkflowReadV2WithWarnings  # noqa F401
  from .workflow import WorkflowUpdateV2  # noqa F401
  from .workflowtask import WorkflowTaskCreateV2  # noqa F401
  from .workflowtask import WorkflowTaskExportV2  # noqa F401
  from .workflowtask import WorkflowTaskImportV2  # noqa F401
  from .workflowtask import WorkflowTaskReadV2  # noqa F401
+ from .workflowtask import WorkflowTaskReadV2WithWarning  # noqa F401
  from .workflowtask import WorkflowTaskStatusTypeV2  # noqa F401
  from .workflowtask import WorkflowTaskUpdateV2  # noqa F401
fractal_server/app/schemas/v2/task.py
@@ -90,7 +90,6 @@ class TaskReadV2(BaseModel):
      name: str
      type: Literal["parallel", "non_parallel", "compound"]
      source: str
-     owner: Optional[str]
      version: Optional[str]

      command_non_parallel: Optional[str]
@@ -105,6 +104,8 @@ class TaskReadV2(BaseModel):
      input_types: dict[str, bool]
      output_types: dict[str, bool]

+     taskgroupv2_id: Optional[int]
+

  class TaskUpdateV2(BaseModel):

fractal_server/app/schemas/v2/task_group.py
@@ -0,0 +1,23 @@
+ from typing import Optional
+
+ from pydantic import BaseModel
+
+ from .task import TaskReadV2
+
+
+ class TaskGroupCreateV2(BaseModel):
+     active: bool = True
+
+
+ class TaskGroupReadV2(BaseModel):
+
+     id: int
+     user_id: int
+     user_group_id: Optional[int] = None
+     active: bool
+     task_list: list[TaskReadV2]
+
+
+ class TaskGroupUpdateV2(BaseModel):
+     user_group_id: Optional[int] = None
+     active: Optional[bool] = None
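A quick validation sketch for the new schemas, using the same pydantic-v1-style .dict() calls that the router above relies on:

# Minimal usage sketch of the new task-group schemas.
from fractal_server.app.schemas.v2 import TaskGroupCreateV2
from fractal_server.app.schemas.v2 import TaskGroupUpdateV2

# Creation payload: only "active" is settable, and it defaults to True.
create = TaskGroupCreateV2()
assert create.active is True

# PATCH payload: unset fields are dropped, which is what lets the router
# iterate over task_group_update.dict(exclude_unset=True).
update = TaskGroupUpdateV2(user_group_id=3)
assert update.dict(exclude_unset=True) == {"user_group_id": 3}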
fractal_server/app/schemas/v2/workflow.py
@@ -11,6 +11,7 @@ from .project import ProjectReadV2
  from .workflowtask import WorkflowTaskExportV2
  from .workflowtask import WorkflowTaskImportV2
  from .workflowtask import WorkflowTaskReadV2
+ from .workflowtask import WorkflowTaskReadV2WithWarning


  class WorkflowCreateV2(BaseModel, extra=Extra.forbid):
@@ -35,6 +36,10 @@ class WorkflowReadV2(BaseModel):
      )


+ class WorkflowReadV2WithWarnings(WorkflowReadV2):
+     task_list: list[WorkflowTaskReadV2WithWarning]
+
+
  class WorkflowUpdateV2(BaseModel):

      name: Optional[str]
fractal_server/app/schemas/v2/workflowtask.py
@@ -102,6 +102,10 @@ class WorkflowTaskReadV2(BaseModel):
      task: TaskReadV2


+ class WorkflowTaskReadV2WithWarning(WorkflowTaskReadV2):
+     warning: Optional[str] = None
+
+
  class WorkflowTaskUpdateV2(BaseModel):

      meta_non_parallel: Optional[dict[str, Any]]
fractal_server/migrations/versions/7cf1baae8fb4_task_group_v2.py
@@ -0,0 +1,66 @@
+ """task group v2
+
+ Revision ID: 7cf1baae8fb4
+ Revises: da2cb2ac4255
+ Create Date: 2024-10-01 12:31:46.792037
+
+ """
+ import sqlalchemy as sa
+ from alembic import op
+
+
+ # revision identifiers, used by Alembic.
+ revision = "7cf1baae8fb4"
+ down_revision = "da2cb2ac4255"
+ branch_labels = None
+ depends_on = None
+
+
+ def upgrade() -> None:
+     # ### commands auto generated by Alembic - please adjust! ###
+     op.create_table(
+         "taskgroupv2",
+         sa.Column("id", sa.Integer(), nullable=False),
+         sa.Column("user_id", sa.Integer(), nullable=False),
+         sa.Column("user_group_id", sa.Integer(), nullable=True),
+         sa.Column("active", sa.Boolean(), nullable=False),
+         sa.Column(
+             "timestamp_created", sa.DateTime(timezone=True), nullable=False
+         ),
+         sa.ForeignKeyConstraint(
+             ["user_group_id"],
+             ["usergroup.id"],
+             name=op.f("fk_taskgroupv2_user_group_id_usergroup"),
+         ),
+         sa.ForeignKeyConstraint(
+             ["user_id"],
+             ["user_oauth.id"],
+             name=op.f("fk_taskgroupv2_user_id_user_oauth"),
+         ),
+         sa.PrimaryKeyConstraint("id", name=op.f("pk_taskgroupv2")),
+     )
+     with op.batch_alter_table("taskv2", schema=None) as batch_op:
+         batch_op.add_column(
+             sa.Column("taskgroupv2_id", sa.Integer(), nullable=True)
+         )
+         batch_op.create_foreign_key(
+             batch_op.f("fk_taskv2_taskgroupv2_id_taskgroupv2"),
+             "taskgroupv2",
+             ["taskgroupv2_id"],
+             ["id"],
+         )
+
+     # ### end Alembic commands ###
+
+
+ def downgrade() -> None:
+     # ### commands auto generated by Alembic - please adjust! ###
+     with op.batch_alter_table("taskv2", schema=None) as batch_op:
+         batch_op.drop_constraint(
+             batch_op.f("fk_taskv2_taskgroupv2_id_taskgroupv2"),
+             type_="foreignkey",
+         )
+         batch_op.drop_column("taskgroupv2_id")
+
+     op.drop_table("taskgroupv2")
+     # ### end Alembic commands ###
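Outside fractal-server's own startup tooling, the new revision can be applied with the standard Alembic API; the config handling below is a generic sketch, not the package's actual migration entry point.

# Generic Alembic upgrade sketch (not fractal-server's own CLI entry point).
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # config-file path is an assumption
command.upgrade(cfg, "7cf1baae8fb4")  # or simply "head"

# Rolling back this revision drops taskgroupv2 and taskv2.taskgroupv2_id:
# command.downgrade(cfg, "da2cb2ac4255")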
fractal_server/tasks/v2/background_operations.py
@@ -17,11 +17,12 @@ from ..utils import get_collection_log
  from ..utils import get_collection_path
  from ..utils import get_log_path
  from ._TaskCollectPip import _TaskCollectPip
+ from .database_operations import create_db_task_group_and_tasks
  from fractal_server.app.db import get_sync_db
  from fractal_server.app.models.v2 import CollectionStateV2
- from fractal_server.app.models.v2 import TaskV2
  from fractal_server.app.schemas.v2 import CollectionStatusV2
  from fractal_server.app.schemas.v2 import TaskCreateV2
+ from fractal_server.app.schemas.v2 import TaskGroupCreateV2
  from fractal_server.app.schemas.v2 import TaskReadV2
  from fractal_server.app.schemas.v2.manifest import ManifestV2
  from fractal_server.logger import get_logger
@@ -30,38 +31,6 @@ from fractal_server.logger import set_logger
  from fractal_server.tasks.v2._venv_pip import _create_venv_install_package_pip


- def _get_task_type(task: TaskCreateV2) -> str:
-     if task.command_non_parallel is None:
-         return "parallel"
-     elif task.command_parallel is None:
-         return "non_parallel"
-     else:
-         return "compound"
-
-
- def _insert_tasks(
-     task_list: list[TaskCreateV2],
-     db: DBSyncSession,
-     owner: Optional[str] = None,
- ) -> list[TaskV2]:
-     """
-     Insert tasks into database
-     """
-
-     owner_dict = dict(owner=owner) if owner is not None else dict()
-
-     task_db_list = [
-         TaskV2(**t.dict(), **owner_dict, type=_get_task_type(t))
-         for t in task_list
-     ]
-     db.add_all(task_db_list)
-     db.commit()
-     for t in task_db_list:
-         db.refresh(t)
-     db.close()
-     return task_db_list
-
-
  def _set_collection_state_data_status(
      *,
      state_id: int,
@@ -232,9 +201,12 @@ def _check_task_files_exist(task_list: list[TaskCreateV2]) -> None:


  async def background_collect_pip(
+     *,
      state_id: int,
      venv_path: Path,
      task_pkg: _TaskCollectPip,
+     user_id: int,
+     user_group_id: Optional[int],
  ) -> None:
      """
      Setup venv, install package, collect tasks.
@@ -301,7 +273,15 @@ async def background_collect_pip(
                  python_bin=python_bin,
              )
              _check_task_files_exist(task_list=task_list)
-             tasks = _insert_tasks(task_list=task_list, db=db)
+
+             task_group = create_db_task_group_and_tasks(
+                 task_list=task_list,
+                 task_group_obj=TaskGroupCreateV2(),
+                 user_id=user_id,
+                 user_group_id=user_group_id,
+                 db=db,
+             )
+
              logger.debug("collecting - prepare tasks and update db " "- END")
              logger.debug("collecting - END")

@@ -310,7 +290,8 @@ async def background_collect_pip(
              collection_path = get_collection_path(venv_path)
              collection_state = db.get(CollectionStateV2, state_id)
              task_read_list = [
-                 TaskReadV2(**task.model_dump()).dict() for task in tasks
+                 TaskReadV2(**task.model_dump()).dict()
+                 for task in task_group.task_list
              ]
              collection_state.data["task_list"] = task_read_list
              collection_state.data["log"] = get_collection_log(venv_path)
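The new module fractal_server/tasks/v2/database_operations.py (+54 lines) is not part of this excerpt. Based only on its three call sites (custom collection above, pip collection here, and the SSH variant), create_db_task_group_and_tasks plausibly replaces the removed _insert_tasks along the following lines; everything beyond the visible keyword arguments is an assumption.

# Illustrative sketch of the helper that replaces _insert_tasks; the real
# implementation lives in the (not shown) database_operations.py module.
from typing import Optional

from fractal_server.app.db import DBSyncSession
from fractal_server.app.models.v2 import TaskGroupV2
from fractal_server.app.models.v2 import TaskV2
from fractal_server.app.schemas.v2 import TaskCreateV2
from fractal_server.app.schemas.v2 import TaskGroupCreateV2


def _get_task_type(task: TaskCreateV2) -> str:
    # Same classification as the function removed from background_operations.
    if task.command_non_parallel is None:
        return "parallel"
    elif task.command_parallel is None:
        return "non_parallel"
    else:
        return "compound"


def create_db_task_group_and_tasks(
    *,
    task_list: list[TaskCreateV2],
    task_group_obj: TaskGroupCreateV2,
    user_id: int,
    user_group_id: Optional[int],
    db: DBSyncSession,
) -> TaskGroupV2:
    """Create one TaskGroupV2 owning all collected tasks and return it."""
    task_group = TaskGroupV2(
        user_id=user_id,
        user_group_id=user_group_id,
        active=task_group_obj.active,
        task_list=[
            TaskV2(**t.dict(), type=_get_task_type(t)) for t in task_list
        ],
    )
    db.add(task_group)
    db.commit()
    db.refresh(task_group)
    return task_group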