fractal-server 1.4.9__py3-none-any.whl → 2.0.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/app/models/__init__.py +4 -7
  3. fractal_server/app/models/linkuserproject.py +9 -0
  4. fractal_server/app/models/security.py +6 -0
  5. fractal_server/app/models/state.py +1 -1
  6. fractal_server/app/models/v1/__init__.py +10 -0
  7. fractal_server/app/models/{dataset.py → v1/dataset.py} +5 -5
  8. fractal_server/app/models/{job.py → v1/job.py} +5 -5
  9. fractal_server/app/models/{project.py → v1/project.py} +5 -5
  10. fractal_server/app/models/{task.py → v1/task.py} +7 -2
  11. fractal_server/app/models/{workflow.py → v1/workflow.py} +5 -5
  12. fractal_server/app/models/v2/__init__.py +20 -0
  13. fractal_server/app/models/v2/dataset.py +55 -0
  14. fractal_server/app/models/v2/job.py +51 -0
  15. fractal_server/app/models/v2/project.py +31 -0
  16. fractal_server/app/models/v2/task.py +93 -0
  17. fractal_server/app/models/v2/workflow.py +43 -0
  18. fractal_server/app/models/v2/workflowtask.py +90 -0
  19. fractal_server/app/routes/{admin.py → admin/v1.py} +42 -42
  20. fractal_server/app/routes/admin/v2.py +275 -0
  21. fractal_server/app/routes/api/v1/__init__.py +7 -7
  22. fractal_server/app/routes/api/v1/_aux_functions.py +2 -2
  23. fractal_server/app/routes/api/v1/dataset.py +44 -37
  24. fractal_server/app/routes/api/v1/job.py +12 -12
  25. fractal_server/app/routes/api/v1/project.py +23 -21
  26. fractal_server/app/routes/api/v1/task.py +24 -14
  27. fractal_server/app/routes/api/v1/task_collection.py +16 -14
  28. fractal_server/app/routes/api/v1/workflow.py +24 -24
  29. fractal_server/app/routes/api/v1/workflowtask.py +10 -10
  30. fractal_server/app/routes/api/v2/__init__.py +28 -0
  31. fractal_server/app/routes/api/v2/_aux_functions.py +497 -0
  32. fractal_server/app/routes/api/v2/apply.py +220 -0
  33. fractal_server/app/routes/api/v2/dataset.py +310 -0
  34. fractal_server/app/routes/api/v2/images.py +212 -0
  35. fractal_server/app/routes/api/v2/job.py +200 -0
  36. fractal_server/app/routes/api/v2/project.py +205 -0
  37. fractal_server/app/routes/api/v2/task.py +222 -0
  38. fractal_server/app/routes/api/v2/task_collection.py +229 -0
  39. fractal_server/app/routes/api/v2/workflow.py +398 -0
  40. fractal_server/app/routes/api/v2/workflowtask.py +269 -0
  41. fractal_server/app/routes/aux/_job.py +1 -1
  42. fractal_server/app/runner/async_wrap.py +27 -0
  43. fractal_server/app/runner/exceptions.py +129 -0
  44. fractal_server/app/runner/executors/local/__init__.py +3 -0
  45. fractal_server/app/runner/{_local → executors/local}/executor.py +2 -2
  46. fractal_server/app/runner/executors/slurm/__init__.py +3 -0
  47. fractal_server/app/runner/{_slurm → executors/slurm}/_batching.py +1 -1
  48. fractal_server/app/runner/executors/slurm/_check_jobs_status.py +72 -0
  49. fractal_server/app/runner/{_slurm → executors/slurm}/_executor_wait_thread.py +3 -4
  50. fractal_server/app/runner/{_slurm → executors/slurm}/_slurm_config.py +3 -152
  51. fractal_server/app/runner/{_slurm → executors/slurm}/_subprocess_run_as_user.py +1 -1
  52. fractal_server/app/runner/{_slurm → executors/slurm}/executor.py +9 -9
  53. fractal_server/app/runner/filenames.py +6 -0
  54. fractal_server/app/runner/set_start_and_last_task_index.py +39 -0
  55. fractal_server/app/runner/task_files.py +105 -0
  56. fractal_server/app/runner/{__init__.py → v1/__init__.py} +36 -49
  57. fractal_server/app/runner/{_common.py → v1/_common.py} +13 -120
  58. fractal_server/app/runner/{_local → v1/_local}/__init__.py +6 -6
  59. fractal_server/app/runner/{_local → v1/_local}/_local_config.py +6 -7
  60. fractal_server/app/runner/{_local → v1/_local}/_submit_setup.py +1 -5
  61. fractal_server/app/runner/v1/_slurm/__init__.py +310 -0
  62. fractal_server/app/runner/{_slurm → v1/_slurm}/_submit_setup.py +3 -9
  63. fractal_server/app/runner/v1/_slurm/get_slurm_config.py +163 -0
  64. fractal_server/app/runner/v1/common.py +117 -0
  65. fractal_server/app/runner/{handle_failed_job.py → v1/handle_failed_job.py} +8 -8
  66. fractal_server/app/runner/v2/__init__.py +337 -0
  67. fractal_server/app/runner/v2/_local/__init__.py +169 -0
  68. fractal_server/app/runner/v2/_local/_local_config.py +118 -0
  69. fractal_server/app/runner/v2/_local/_submit_setup.py +52 -0
  70. fractal_server/app/runner/v2/_slurm/__init__.py +157 -0
  71. fractal_server/app/runner/v2/_slurm/_submit_setup.py +83 -0
  72. fractal_server/app/runner/v2/_slurm/get_slurm_config.py +179 -0
  73. fractal_server/app/runner/v2/components.py +5 -0
  74. fractal_server/app/runner/v2/deduplicate_list.py +24 -0
  75. fractal_server/app/runner/v2/handle_failed_job.py +156 -0
  76. fractal_server/app/runner/v2/merge_outputs.py +41 -0
  77. fractal_server/app/runner/v2/runner.py +264 -0
  78. fractal_server/app/runner/v2/runner_functions.py +339 -0
  79. fractal_server/app/runner/v2/runner_functions_low_level.py +134 -0
  80. fractal_server/app/runner/v2/task_interface.py +43 -0
  81. fractal_server/app/runner/v2/v1_compat.py +21 -0
  82. fractal_server/app/schemas/__init__.py +4 -42
  83. fractal_server/app/schemas/v1/__init__.py +42 -0
  84. fractal_server/app/schemas/{applyworkflow.py → v1/applyworkflow.py} +18 -18
  85. fractal_server/app/schemas/{dataset.py → v1/dataset.py} +30 -30
  86. fractal_server/app/schemas/{dumps.py → v1/dumps.py} +8 -8
  87. fractal_server/app/schemas/{manifest.py → v1/manifest.py} +5 -5
  88. fractal_server/app/schemas/{project.py → v1/project.py} +9 -9
  89. fractal_server/app/schemas/{task.py → v1/task.py} +12 -12
  90. fractal_server/app/schemas/{task_collection.py → v1/task_collection.py} +7 -7
  91. fractal_server/app/schemas/{workflow.py → v1/workflow.py} +38 -38
  92. fractal_server/app/schemas/v2/__init__.py +34 -0
  93. fractal_server/app/schemas/v2/dataset.py +88 -0
  94. fractal_server/app/schemas/v2/dumps.py +87 -0
  95. fractal_server/app/schemas/v2/job.py +113 -0
  96. fractal_server/app/schemas/v2/manifest.py +109 -0
  97. fractal_server/app/schemas/v2/project.py +36 -0
  98. fractal_server/app/schemas/v2/task.py +121 -0
  99. fractal_server/app/schemas/v2/task_collection.py +105 -0
  100. fractal_server/app/schemas/v2/workflow.py +78 -0
  101. fractal_server/app/schemas/v2/workflowtask.py +118 -0
  102. fractal_server/config.py +5 -10
  103. fractal_server/images/__init__.py +50 -0
  104. fractal_server/images/tools.py +86 -0
  105. fractal_server/main.py +11 -3
  106. fractal_server/migrations/versions/4b35c5cefbe3_tmp_is_v2_compatible.py +39 -0
  107. fractal_server/migrations/versions/56af171b0159_v2.py +217 -0
  108. fractal_server/migrations/versions/876f28db9d4e_tmp_split_task_and_wftask_meta.py +68 -0
  109. fractal_server/migrations/versions/974c802f0dd0_tmp_workflowtaskv2_type_in_db.py +37 -0
  110. fractal_server/migrations/versions/9cd305cd6023_tmp_workflowtaskv2.py +40 -0
  111. fractal_server/migrations/versions/a6231ed6273c_tmp_args_schemas_in_taskv2.py +42 -0
  112. fractal_server/migrations/versions/b9e9eed9d442_tmp_taskv2_type.py +37 -0
  113. fractal_server/migrations/versions/e3e639454d4b_tmp_make_task_meta_non_optional.py +50 -0
  114. fractal_server/tasks/__init__.py +0 -5
  115. fractal_server/tasks/endpoint_operations.py +13 -19
  116. fractal_server/tasks/utils.py +35 -0
  117. fractal_server/tasks/{_TaskCollectPip.py → v1/_TaskCollectPip.py} +3 -3
  118. fractal_server/tasks/{background_operations.py → v1/background_operations.py} +18 -50
  119. fractal_server/tasks/v1/get_collection_data.py +14 -0
  120. fractal_server/tasks/v2/_TaskCollectPip.py +103 -0
  121. fractal_server/tasks/v2/background_operations.py +382 -0
  122. fractal_server/tasks/v2/get_collection_data.py +14 -0
  123. {fractal_server-1.4.9.dist-info → fractal_server-2.0.0a0.dist-info}/METADATA +3 -4
  124. fractal_server-2.0.0a0.dist-info/RECORD +166 -0
  125. fractal_server/app/runner/_slurm/.gitignore +0 -2
  126. fractal_server/app/runner/_slurm/__init__.py +0 -150
  127. fractal_server/app/runner/common.py +0 -311
  128. fractal_server-1.4.9.dist-info/RECORD +0 -97
  129. /fractal_server/app/runner/{_slurm → executors/slurm}/remote.py +0 -0
  130. {fractal_server-1.4.9.dist-info → fractal_server-2.0.0a0.dist-info}/LICENSE +0 -0
  131. {fractal_server-1.4.9.dist-info → fractal_server-2.0.0a0.dist-info}/WHEEL +0 -0
  132. {fractal_server-1.4.9.dist-info → fractal_server-2.0.0a0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,200 @@
1
+ from pathlib import Path
2
+ from typing import Optional
3
+
4
+ from fastapi import APIRouter
5
+ from fastapi import Depends
6
+ from fastapi import Response
7
+ from fastapi import status
8
+ from fastapi.responses import StreamingResponse
9
+ from sqlmodel import select
10
+
11
+ from ....db import AsyncSession
12
+ from ....db import get_async_db
13
+ from ....models.v2 import JobV2
14
+ from ....models.v2 import ProjectV2
15
+ from ....runner.filenames import WORKFLOW_LOG_FILENAME # FIXME
16
+ from ....schemas.v2 import JobReadV2
17
+ from ....schemas.v2 import JobStatusTypeV2
18
+ from ....security import current_active_user
19
+ from ....security import User
20
+ from ...aux._job import _write_shutdown_file
21
+ from ...aux._job import _zip_folder_to_byte_stream
22
+ from ...aux._runner import _check_backend_is_slurm
23
+ from ._aux_functions import _get_job_check_owner
24
+ from ._aux_functions import _get_project_check_owner
25
+ from ._aux_functions import _get_workflow_check_owner
26
+
27
+ router = APIRouter()
28
+
29
+
30
@router.get("/job/", response_model=list[JobReadV2])
async def get_user_jobs(
    user: User = Depends(current_active_user),
    log: bool = True,
    db: AsyncSession = Depends(get_async_db),
) -> list[JobReadV2]:
    """
    List the jobs of every project the current user is linked to.

    Args:
        user: Authenticated user, injected by FastAPI.
        log: When false, the `log` attribute of each returned job is set to
            `None` before the response is built.
        db: Async database session, injected by FastAPI.

    Returns:
        The matching jobs, serialized through `JobReadV2`.
    """
    # Jobs are reached through their project: keep only those projects whose
    # `user_list` contains the current user.
    query = select(JobV2).join(ProjectV2)
    query = query.where(ProjectV2.user_list.any(User.id == user.id))
    result = await db.execute(query)
    jobs = result.scalars().all()
    await db.close()

    if log:
        return jobs

    # The caller opted out of logs: blank them on the already-fetched objects.
    for job in jobs:
        job.log = None
    return jobs
52
+
53
+
54
@router.get(
    "/project/{project_id}/workflow/{workflow_id}/job/",
    response_model=list[JobReadV2],
)
async def get_workflow_jobs(
    project_id: int,
    workflow_id: int,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> Optional[list[JobReadV2]]:
    """
    List the jobs whose `workflow_id` matches the given workflow.

    Args:
        project_id: ID of the project the workflow is expected to belong to.
        workflow_id: ID of the workflow whose jobs are requested.
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        The matching jobs, serialized through `JobReadV2`.
    """
    # Called only for its check; the return value is not needed here.
    # Presumably raises when the user cannot access the workflow -- see
    # `_get_workflow_check_owner`.
    await _get_workflow_check_owner(
        project_id=project_id,
        workflow_id=workflow_id,
        user_id=user.id,
        db=db,
    )
    result = await db.execute(
        select(JobV2).where(JobV2.workflow_id == workflow_id)
    )
    return result.scalars().all()
74
+
75
+
76
@router.get(
    "/project/{project_id}/job/{job_id}/",
    response_model=JobReadV2,
)
async def read_job(
    project_id: int,
    job_id: int,
    show_tmp_logs: bool = False,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> Optional[JobReadV2]:
    """
    Return a single job, optionally with its in-progress log.

    Args:
        project_id: ID of the project the job is expected to belong to.
        job_id: ID of the job to read.
        show_tmp_logs: When true and the job status is `SUBMITTED`, the
            workflow log file in the job working directory is read into
            `job.log`.
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        The job, serialized through `JobReadV2`.
    """
    lookup = await _get_job_check_owner(
        project_id=project_id,
        job_id=job_id,
        user_id=user.id,
        db=db,
    )
    job = lookup["job"]
    await db.close()

    # Temporary logs are only attached on request, and only while the job is
    # still in the submitted state.
    if not show_tmp_logs or job.status != JobStatusTypeV2.SUBMITTED:
        return job

    log_path = f"{job.working_dir}/{WORKFLOW_LOG_FILENAME}"
    try:
        with open(log_path, "r") as log_file:
            job.log = log_file.read()
    except FileNotFoundError:
        # Best effort: if the log file does not exist, leave `job.log` as is.
        pass

    return job
108
+
109
+
110
@router.get(
    "/project/{project_id}/job/{job_id}/download/",
    response_class=StreamingResponse,
)
async def download_job_logs(
    project_id: int,
    job_id: int,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> StreamingResponse:
    """
    Send the job working directory to the client as a zip attachment.

    Args:
        project_id: ID of the project the job is expected to belong to.
        job_id: ID of the job whose folder is downloaded.
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        A `StreamingResponse` carrying the archive bytes, with a
        `Content-Disposition` header naming the zip file.
    """
    lookup = await _get_job_check_owner(
        project_id=project_id,
        job_id=job_id,
        user_id=user.id,
        db=db,
    )
    working_dir = lookup["job"].working_dir

    # The archive is named after the last component of the working directory.
    archive_name = f"{Path(working_dir).name}_archive.zip"
    archive = _zip_folder_to_byte_stream(
        folder=working_dir, zip_filename=archive_name
    )

    # The whole archive is emitted as a single chunk.
    chunks = iter([archive.getvalue()])
    disposition = f"attachment;filename={archive_name}"
    return StreamingResponse(
        chunks,
        media_type="application/x-zip-compressed",
        headers={"Content-Disposition": disposition},
    )
142
+
143
+
144
@router.get(
    "/project/{project_id}/job/",
    response_model=list[JobReadV2],
)
async def get_job_list(
    project_id: int,
    user: User = Depends(current_active_user),
    log: bool = True,
    db: AsyncSession = Depends(get_async_db),
) -> Optional[list[JobReadV2]]:
    """
    List the jobs attached to one project.

    Args:
        project_id: ID of the project whose jobs are requested.
        user: Authenticated user, injected by FastAPI.
        log: When false, the `log` attribute of each returned job is set to
            `None` before the response is built.
        db: Async database session, injected by FastAPI.

    Returns:
        The matching jobs, serialized through `JobReadV2`.
    """
    project = await _get_project_check_owner(
        project_id=project_id, user_id=user.id, db=db
    )

    result = await db.execute(
        select(JobV2).where(JobV2.project_id == project.id)
    )
    jobs = result.scalars().all()
    await db.close()

    if log:
        return jobs

    # The caller opted out of logs: blank them on the already-fetched objects.
    for job in jobs:
        job.log = None
    return jobs
170
+
171
+
172
@router.get(
    "/project/{project_id}/job/{job_id}/stop/",
    status_code=202,
)
async def stop_job(
    project_id: int,
    job_id: int,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> Response:
    """
    Request that a running workflow job be stopped (SLURM backend only).

    Args:
        project_id: ID of the project the job is expected to belong to.
        job_id: ID of the job to stop.
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        An empty response with status 202.
    """
    # Stopping is only implemented for the SLURM backend; this check runs
    # before any database access.
    _check_backend_is_slurm()

    lookup = await _get_job_check_owner(
        project_id=project_id,
        job_id=job_id,
        user_id=user.id,
        db=db,
    )

    # The stop request is conveyed through a shutdown file written for the job.
    _write_shutdown_file(job=lookup["job"])

    return Response(status_code=status.HTTP_202_ACCEPTED)
@@ -0,0 +1,205 @@
1
+ from typing import Optional
2
+
3
+ from fastapi import APIRouter
4
+ from fastapi import Depends
5
+ from fastapi import HTTPException
6
+ from fastapi import Response
7
+ from fastapi import status
8
+ from sqlalchemy.exc import IntegrityError
9
+ from sqlmodel import select
10
+
11
+ from .....logger import close_logger
12
+ from .....logger import set_logger
13
+ from ....db import AsyncSession
14
+ from ....db import get_async_db
15
+ from ....models.v2 import DatasetV2
16
+ from ....models.v2 import JobV2
17
+ from ....models.v2 import LinkUserProjectV2
18
+ from ....models.v2 import ProjectV2
19
+ from ....models.v2 import WorkflowV2
20
+ from ....schemas.v2 import ProjectCreateV2
21
+ from ....schemas.v2 import ProjectReadV2
22
+ from ....schemas.v2 import ProjectUpdateV2
23
+ from ....security import current_active_user
24
+ from ....security import User
25
+ from ._aux_functions import _check_project_exists
26
+ from ._aux_functions import _get_project_check_owner
27
+ from ._aux_functions import _get_submitted_jobs_statement
28
+
29
+ router = APIRouter()
30
+
31
+
32
@router.get("/project/", response_model=list[ProjectReadV2])
async def get_list_project(
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> list[ProjectV2]:
    """
    List the projects the current user is linked to.

    Args:
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        The matching projects, serialized through `ProjectReadV2`.
    """
    # Membership is resolved through the user/project link table.
    query = select(ProjectV2).join(LinkUserProjectV2)
    query = query.where(LinkUserProjectV2.user_id == user.id)
    result = await db.execute(query)
    projects = result.scalars().all()
    await db.close()
    return projects
49
+
50
+
51
@router.post("/project/", response_model=ProjectReadV2, status_code=201)
async def create_project(
    project: ProjectCreateV2,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> Optional[ProjectReadV2]:
    """
    Create a new project and link it to the current user.

    Args:
        project: Attributes of the project to create.
        user: Authenticated user, injected by FastAPI; appended to the new
            project's `user_list`.
        db: Async database session, injected by FastAPI.

    Returns:
        The newly created project, serialized through `ProjectReadV2`.

    Raises:
        HTTPException: 422 if the database reports an `IntegrityError` when
            the new project is committed.
    """
    # Check that the user has no other project with the same name.
    await _check_project_exists(
        project_name=project.name, user_id=user.id, db=db
    )

    new_project = ProjectV2(**project.dict())
    new_project.user_list.append(user)

    try:
        db.add(new_project)
        await db.commit()
        await db.refresh(new_project)
        await db.close()
    except IntegrityError as err:
        # Undo the failed transaction, log the error through a short-lived
        # logger, and report it to the client.
        await db.rollback()
        error_message = str(err)
        error_logger = set_logger("create_project")
        error_logger.error(error_message)
        close_logger(error_logger)
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=error_message,
        )
    else:
        return new_project
84
+
85
+
86
@router.get("/project/{project_id}/", response_model=ProjectReadV2)
async def read_project(
    project_id: int,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> Optional[ProjectReadV2]:
    """
    Return a single project.

    Args:
        project_id: ID of the project to read.
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        The project, serialized through `ProjectReadV2`.
    """
    # The helper both fetches the project and performs the access check.
    found = await _get_project_check_owner(
        project_id=project_id,
        user_id=user.id,
        db=db,
    )
    await db.close()
    return found
100
+
101
+
102
@router.patch("/project/{project_id}/", response_model=ProjectReadV2)
async def update_project(
    project_id: int,
    project_update: ProjectUpdateV2,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
):
    """
    Apply a partial update to an existing project.

    Only attributes explicitly set in the request body are written.

    Args:
        project_id: ID of the project to update.
        project_update: Attributes to change.
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        The refreshed project, serialized through `ProjectReadV2`.
    """
    project = await _get_project_check_owner(
        project_id=project_id, user_id=user.id, db=db
    )

    # When a name is provided, check that the user has no other project
    # already using it.
    new_name = project_update.name
    if new_name is not None:
        await _check_project_exists(
            project_name=new_name, user_id=user.id, db=db
        )

    changes = project_update.dict(exclude_unset=True)
    for attribute, new_value in changes.items():
        setattr(project, attribute, new_value)

    await db.commit()
    await db.refresh(project)
    await db.close()
    return project
126
+
127
+
128
@router.delete("/project/{project_id}/", status_code=204)
async def delete_project(
    project_id: int,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> Response:
    """
    Delete a project together with its workflows and datasets.

    Jobs are not deleted: their `workflow_id`, `dataset_id` and `project_id`
    references are set to `None` before the referenced rows are removed.

    Args:
        project_id: ID of the project to delete.
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        An empty response with status 204.

    Raises:
        HTTPException: 422 if the submitted-jobs query returns any job for
            this project.
    """
    project = await _get_project_check_owner(
        project_id=project_id, user_id=user.id, db=db
    )

    # Fail if there exist jobs that are submitted and in relation with the
    # current project.
    stm = _get_submitted_jobs_statement().where(JobV2.project_id == project_id)
    res = await db.execute(stm)
    jobs = res.scalars().all()
    if jobs:
        # Strip the surrounding brackets from the list repr: "1, 2, 3".
        string_ids = str([job.id for job in jobs])[1:-1]
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                f"Cannot delete project {project.id} because it "
                f"is linked to active job(s) {string_ids}."
            ),
        )

    # Cascade operations

    # Workflows
    stm = select(WorkflowV2).where(WorkflowV2.project_id == project_id)
    res = await db.execute(stm)
    workflows = res.scalars().all()
    for wf in workflows:
        # Cascade operations: set foreign-keys to null for jobs which are in
        # relationship with the current workflow
        stm = select(JobV2).where(JobV2.workflow_id == wf.id)
        res = await db.execute(stm)
        jobs = res.scalars().all()
        for job in jobs:
            job.workflow_id = None
            await db.merge(job)
        # Delete workflow
        # NOTE(review): the commit is placed inside the loop (one commit per
        # workflow) -- confirm against the original indentation.
        await db.delete(wf)
        await db.commit()

    # Dataset
    stm = select(DatasetV2).where(DatasetV2.project_id == project_id)
    res = await db.execute(stm)
    datasets = res.scalars().all()
    for ds in datasets:
        # Cascade operations: set foreign-keys to null for jobs which are in
        # relationship with the current dataset
        stm = select(JobV2).where(JobV2.dataset_id == ds.id)
        res = await db.execute(stm)
        jobs = res.scalars().all()
        for job in jobs:
            job.dataset_id = None
            await db.merge(job)
        # Delete dataset
        # NOTE(review): the commit is placed inside the loop (one commit per
        # dataset) -- confirm against the original indentation.
        await db.delete(ds)
        await db.commit()

    # Job
    # Detach the remaining jobs from the project before the project row is
    # deleted.
    stm = select(JobV2).where(JobV2.project_id == project_id)
    res = await db.execute(stm)
    jobs = res.scalars().all()
    for job in jobs:
        job.project_id = None
        await db.merge(job)

    await db.commit()

    # Finally remove the project itself.
    await db.delete(project)
    await db.commit()

    return Response(status_code=status.HTTP_204_NO_CONTENT)
@@ -0,0 +1,222 @@
1
+ from copy import deepcopy # noqa
2
+ from typing import Optional
3
+
4
+ from fastapi import APIRouter
5
+ from fastapi import Depends
6
+ from fastapi import HTTPException
7
+ from fastapi import Response
8
+ from fastapi import status
9
+ from sqlmodel import select
10
+
11
+ from .....logger import set_logger
12
+ from ....db import AsyncSession
13
+ from ....db import get_async_db
14
+ from ....models.v1 import Task as TaskV1
15
+ from ....models.v2 import TaskV2
16
+ from ....models.v2 import WorkflowTaskV2
17
+ from ....schemas.v2 import TaskCreateV2
18
+ from ....schemas.v2 import TaskReadV2
19
+ from ....schemas.v2 import TaskUpdateV2
20
+ from ....security import current_active_user
21
+ from ....security import current_active_verified_user
22
+ from ....security import User
23
+ from ._aux_functions import _get_task_check_owner
24
+
25
+ router = APIRouter()
26
+
27
+ logger = set_logger(__name__)
28
+
29
+
30
@router.get("/", response_model=list[TaskReadV2])
async def get_list_task(
    args_schema_parallel: bool = True,
    args_schema_non_parallel: bool = True,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> list[TaskReadV2]:
    """
    List all `TaskV2` rows.

    Args:
        args_schema_parallel: When `False`, `args_schema_parallel` is set to
            `None` on every returned task.
        args_schema_non_parallel: When `False`, `args_schema_non_parallel` is
            set to `None` on every returned task.
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        The tasks, serialized through `TaskReadV2`.
    """
    result = await db.execute(select(TaskV2))
    tasks = result.scalars().all()
    await db.close()

    # Collect the schema attributes the caller asked to leave out.
    blanked_fields = []
    if args_schema_parallel is False:
        blanked_fields.append("args_schema_parallel")
    if args_schema_non_parallel is False:
        blanked_fields.append("args_schema_non_parallel")

    for field_name in blanked_fields:
        for task in tasks:
            setattr(task, field_name, None)

    return tasks
52
+
53
+
54
@router.get("/{task_id}/", response_model=TaskReadV2)
async def get_task(
    task_id: int,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> TaskReadV2:
    """
    Return a single `TaskV2` by primary key.

    Args:
        task_id: ID of the task to read.
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        The task, serialized through `TaskReadV2`.

    Raises:
        HTTPException: 404 if no task has the given ID.
    """
    found = await db.get(TaskV2, task_id)
    await db.close()
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND, detail="TaskV2 not found"
    )
70
+
71
+
72
@router.patch("/{task_id}/", response_model=TaskReadV2)
async def patch_task(
    task_id: int,
    task_update: TaskUpdateV2,
    user: User = Depends(current_active_verified_user),
    db: AsyncSession = Depends(get_async_db),
) -> Optional[TaskReadV2]:
    """
    Edit a specific task (restricted to superusers and task owner)

    Only attributes explicitly set in the request body are written.

    Args:
        task_id: ID of the task to edit.
        task_update: Attributes to change.
        user: Authenticated and verified user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        The refreshed task, serialized through `TaskReadV2`.

    Raises:
        HTTPException: 422 if the update touches `command_parallel` on a
            `non_parallel` task, or `command_non_parallel` on a `parallel`
            task.
    """
    # Retrieve the task; the helper also performs the access check.
    db_task = await _get_task_check_owner(task_id=task_id, user=user, db=db)
    changes = task_update.dict(exclude_unset=True)

    # A task type that lacks one of the two commands must not have that
    # command set through a patch. Pairs are (task type, forbidden key, error).
    forbidden_rules = (
        (
            "non_parallel",
            "command_parallel",
            "Cannot set an unset `command_parallel`.",
        ),
        (
            "parallel",
            "command_non_parallel",
            "Cannot set an unset `command_non_parallel`.",
        ),
    )
    for task_type, forbidden_key, message in forbidden_rules:
        if db_task.type == task_type and forbidden_key in changes:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=message,
            )

    for attribute, new_value in changes.items():
        setattr(db_task, attribute, new_value)

    await db.commit()
    await db.refresh(db_task)
    await db.close()
    return db_task
106
+
107
+
108
@router.post(
    "/", response_model=TaskReadV2, status_code=status.HTTP_201_CREATED
)
async def create_task(
    task: TaskCreateV2,
    user: User = Depends(current_active_verified_user),
    db: AsyncSession = Depends(get_async_db),
) -> Optional[TaskReadV2]:
    """
    Create a new task

    The task type is derived from which commands are provided: `parallel`
    when `command_non_parallel` is `None`, `non_parallel` when
    `command_parallel` is `None`, `compound` otherwise. The task source is
    prefixed with the owner name before being stored.

    Args:
        task: Attributes of the task to create. Its `source` attribute is
            overwritten in place with the owner-prefixed value.
        user: Authenticated and verified user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        The newly created task, serialized through `TaskReadV2`.

    Raises:
        HTTPException: 422 if attributes for the missing command are set,
            if the user has neither `username` nor `slurm_user`, or if the
            prefixed source is already used by a `TaskV2` or a v1 `Task`.
    """

    if task.command_non_parallel is None:
        task_type = "parallel"
    elif task.command_parallel is None:
        task_type = "non_parallel"
    else:
        task_type = "compound"

    # Reject schema/meta attributes that refer to a command the task does not
    # have. The messages name both attributes that are actually checked.
    if task_type == "parallel" and (
        task.args_schema_non_parallel is not None
        or task.meta_non_parallel is not None
    ):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                "Cannot set `TaskV2.args_schema_non_parallel` or "
                "`TaskV2.meta_non_parallel` if TaskV2 is parallel"
            ),
        )
    elif task_type == "non_parallel" and (
        task.args_schema_parallel is not None or task.meta_parallel is not None
    ):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                "Cannot set `TaskV2.args_schema_parallel` or "
                "`TaskV2.meta_parallel` if TaskV2 is non_parallel"
            ),
        )

    # Set task.owner attribute: `username` takes precedence over `slurm_user`.
    if user.username:
        owner = user.username
    elif user.slurm_user:
        owner = user.slurm_user
    else:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                "Cannot add a new task because current user does not "
                "have `username` or `slurm_user` attributes."
            ),
        )

    # Prepend owner to task.source
    task.source = f"{owner}:{task.source}"

    # Verify that source is not already in use (note: this check is only useful
    # to provide a user-friendly error message, but `task.source` uniqueness is
    # already guaranteed by a constraint in the table definition).
    stm = select(TaskV2).where(TaskV2.source == task.source)
    res = await db.execute(stm)
    if res.scalars().all():
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Source '{task.source}' already used by some TaskV2",
        )
    # The same source must not be used by a v1 task either.
    stm = select(TaskV1).where(TaskV1.source == task.source)
    res = await db.execute(stm)
    if res.scalars().all():
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Source '{task.source}' already used by some TaskV1",
        )

    # Add task
    db_task = TaskV2(**task.dict(), owner=owner, type=task_type)
    db.add(db_task)
    await db.commit()
    await db.refresh(db_task)
    await db.close()
    return db_task
190
+
191
+
192
@router.delete("/{task_id}/", status_code=204)
async def delete_task(
    task_id: int,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> Response:
    """
    Delete a task

    Deletion is refused while any `WorkflowTaskV2` still references the task.

    Args:
        task_id: ID of the task to delete.
        user: Authenticated user, injected by FastAPI.
        db: Async database session, injected by FastAPI.

    Returns:
        An empty response with status 204.

    Raises:
        HTTPException: 422 if the task is referenced by at least one
            `WorkflowTaskV2`.
    """

    # Retrieve the task; the helper also performs the access check.
    db_task = await _get_task_check_owner(task_id=task_id, user=user, db=db)

    # Check that the TaskV2 is not in relationship with some WorkflowTaskV2
    stm = select(WorkflowTaskV2).filter(WorkflowTaskV2.task_id == task_id)
    res = await db.execute(stm)
    workflowtask_list = res.scalars().all()
    if workflowtask_list:
        workflow_ids = [x.workflow_id for x in workflowtask_list]
        # `detail` is a plain string, as in the other endpoints. A trailing
        # comma inside the parentheses would turn it into a one-element tuple.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                f"Cannot remove TaskV2 {task_id} because it is currently "
                "imported in WorkflowsV2 "
                f"{workflow_ids}. "
                "If you want to remove this task, then you should first remove"
                " the workflows."
            ),
        )

    await db.delete(db_task)
    await db.commit()
    return Response(status_code=status.HTTP_204_NO_CONTENT)