fractal-server 2.6.3__py3-none-any.whl → 2.7.0__py3-none-any.whl
This diff shows the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their public registries.
- fractal_server/__init__.py +1 -1
- fractal_server/__main__.py +1 -1
- fractal_server/app/models/linkusergroup.py +11 -0
- fractal_server/app/models/v2/__init__.py +2 -0
- fractal_server/app/models/v2/collection_state.py +1 -0
- fractal_server/app/models/v2/task.py +67 -2
- fractal_server/app/routes/admin/v2/__init__.py +16 -0
- fractal_server/app/routes/admin/{v2.py → v2/job.py} +20 -191
- fractal_server/app/routes/admin/v2/project.py +43 -0
- fractal_server/app/routes/admin/v2/task.py +133 -0
- fractal_server/app/routes/admin/v2/task_group.py +162 -0
- fractal_server/app/routes/api/v1/task_collection.py +4 -4
- fractal_server/app/routes/api/v2/__init__.py +8 -0
- fractal_server/app/routes/api/v2/_aux_functions.py +1 -68
- fractal_server/app/routes/api/v2/_aux_functions_tasks.py +343 -0
- fractal_server/app/routes/api/v2/submit.py +16 -35
- fractal_server/app/routes/api/v2/task.py +85 -110
- fractal_server/app/routes/api/v2/task_collection.py +184 -196
- fractal_server/app/routes/api/v2/task_collection_custom.py +70 -64
- fractal_server/app/routes/api/v2/task_group.py +173 -0
- fractal_server/app/routes/api/v2/workflow.py +39 -102
- fractal_server/app/routes/api/v2/workflow_import.py +360 -0
- fractal_server/app/routes/api/v2/workflowtask.py +4 -8
- fractal_server/app/routes/auth/_aux_auth.py +86 -40
- fractal_server/app/routes/auth/current_user.py +5 -5
- fractal_server/app/routes/auth/group.py +73 -23
- fractal_server/app/routes/auth/router.py +0 -2
- fractal_server/app/routes/auth/users.py +8 -7
- fractal_server/app/runner/executors/slurm/ssh/executor.py +82 -63
- fractal_server/app/runner/v2/__init__.py +13 -7
- fractal_server/app/runner/v2/task_interface.py +4 -9
- fractal_server/app/schemas/user.py +1 -2
- fractal_server/app/schemas/v2/__init__.py +7 -0
- fractal_server/app/schemas/v2/dataset.py +2 -7
- fractal_server/app/schemas/v2/dumps.py +1 -2
- fractal_server/app/schemas/v2/job.py +1 -1
- fractal_server/app/schemas/v2/manifest.py +25 -1
- fractal_server/app/schemas/v2/project.py +1 -1
- fractal_server/app/schemas/v2/task.py +95 -36
- fractal_server/app/schemas/v2/task_collection.py +8 -6
- fractal_server/app/schemas/v2/task_group.py +85 -0
- fractal_server/app/schemas/v2/workflow.py +7 -2
- fractal_server/app/schemas/v2/workflowtask.py +9 -6
- fractal_server/app/security/__init__.py +8 -1
- fractal_server/config.py +8 -28
- fractal_server/data_migrations/2_7_0.py +323 -0
- fractal_server/images/models.py +2 -4
- fractal_server/main.py +1 -1
- fractal_server/migrations/env.py +4 -1
- fractal_server/migrations/versions/034a469ec2eb_task_groups.py +184 -0
- fractal_server/ssh/_fabric.py +186 -73
- fractal_server/string_tools.py +6 -2
- fractal_server/tasks/utils.py +19 -5
- fractal_server/tasks/v1/_TaskCollectPip.py +1 -1
- fractal_server/tasks/v1/background_operations.py +5 -5
- fractal_server/tasks/v1/get_collection_data.py +2 -2
- fractal_server/tasks/v2/_venv_pip.py +67 -70
- fractal_server/tasks/v2/background_operations.py +180 -69
- fractal_server/tasks/v2/background_operations_ssh.py +57 -70
- fractal_server/tasks/v2/database_operations.py +44 -0
- fractal_server/tasks/v2/endpoint_operations.py +104 -116
- fractal_server/tasks/v2/templates/_1_create_venv.sh +9 -5
- fractal_server/tasks/v2/templates/{_2_upgrade_pip.sh → _2_preliminary_pip_operations.sh} +1 -0
- fractal_server/tasks/v2/utils.py +5 -0
- fractal_server/utils.py +3 -2
- {fractal_server-2.6.3.dist-info → fractal_server-2.7.0.dist-info}/METADATA +3 -7
- {fractal_server-2.6.3.dist-info → fractal_server-2.7.0.dist-info}/RECORD +70 -61
- fractal_server/app/routes/auth/group_names.py +0 -34
- fractal_server/tasks/v2/_TaskCollectPip.py +0 -132
- {fractal_server-2.6.3.dist-info → fractal_server-2.7.0.dist-info}/LICENSE +0 -0
- {fractal_server-2.6.3.dist-info → fractal_server-2.7.0.dist-info}/WHEEL +0 -0
- {fractal_server-2.6.3.dist-info → fractal_server-2.7.0.dist-info}/entry_points.txt +0 -0
fractal_server/app/routes/api/v2/workflow_import.py (new file, +360 lines):

```diff
@@ -0,0 +1,360 @@
+from typing import Optional
+
+from fastapi import APIRouter
+from fastapi import Depends
+from fastapi import HTTPException
+from fastapi import status
+from sqlmodel import or_
+from sqlmodel import select
+
+from ....db import AsyncSession
+from ....db import get_async_db
+from ....models.v2 import TaskV2
+from ....models.v2 import WorkflowV2
+from ....schemas.v2 import TaskImportV2Legacy
+from ....schemas.v2 import WorkflowImportV2
+from ....schemas.v2 import WorkflowReadV2WithWarnings
+from ....schemas.v2 import WorkflowTaskCreateV2
+from ._aux_functions import _check_workflow_exists
+from ._aux_functions import _get_project_check_owner
+from ._aux_functions import _workflow_insert_task
+from ._aux_functions_tasks import _add_warnings_to_workflow_tasks
+from fractal_server.app.models import LinkUserGroup
+from fractal_server.app.models import UserOAuth
+from fractal_server.app.models.v2.task import TaskGroupV2
+from fractal_server.app.routes.auth import current_active_user
+from fractal_server.app.routes.auth._aux_auth import _get_default_usergroup_id
+from fractal_server.app.schemas.v2.task import TaskImportV2
+from fractal_server.logger import set_logger
+
+router = APIRouter()
+
+
+logger = set_logger(__name__)
+
+
+async def _get_user_accessible_taskgroups(
+    *,
+    user_id: int,
+    db: AsyncSession,
+) -> list[TaskGroupV2]:
+    """
+    Retrieve list of task groups that the user has access to.
+    """
+    stm = select(TaskGroupV2).where(
+        or_(
+            TaskGroupV2.user_id == user_id,
+            TaskGroupV2.user_group_id.in_(
+                select(LinkUserGroup.group_id).where(
+                    LinkUserGroup.user_id == user_id
+                )
+            ),
+        )
+    )
+    res = await db.execute(stm)
+    accessible_task_groups = res.scalars().all()
+    logger.info(
+        f"Found {len(accessible_task_groups)} accessible "
+        f"task groups for {user_id=}."
+    )
+    return accessible_task_groups
+
+
+async def _get_task_by_source(
+    source: str,
+    task_groups_list: list[TaskGroupV2],
+) -> Optional[int]:
+    """
+    Find task with a given source.
+
+    Args:
+        task_import: Info on task to be imported.
+        user_id: ID of current user.
+        default_group_id: ID of default user group.
+        task_group_list: Current list of valid task groups.
+        db: Asynchronous db session
+
+    Return:
+        `id` of the matching task, or `None`.
+    """
+    task_id = next(
+        iter(
+            task.id
+            for task_group in task_groups_list
+            for task in task_group.task_list
+            if task.source == source
+        ),
+        None,
+    )
+    return task_id
+
+
+async def _disambiguate_task_groups(
+    *,
+    matching_task_groups: list[TaskGroupV2],
+    user_id: int,
+    db: AsyncSession,
+    default_group_id: int,
+) -> Optional[TaskV2]:
+    """
+    Disambiguate task groups based on ownership information.
+    """
+    # Highest priority: task groups created by user
+    for task_group in matching_task_groups:
+        if task_group.user_id == user_id:
+            logger.info(
+                "[_disambiguate_task_groups] "
+                f"Found task group {task_group.id} with {user_id=}, return."
+            )
+            return task_group
+    logger.info(
+        "[_disambiguate_task_groups] "
+        f"No task group found with {user_id=}, continue."
+    )
+
+    # Medium priority: task groups owned by default user group
+    for task_group in matching_task_groups:
+        if task_group.user_group_id == default_group_id:
+            logger.info(
+                "[_disambiguate_task_groups] "
+                f"Found task group {task_group.id} with user_group_id="
+                f"{default_group_id}, return."
+            )
+            return task_group
+    logger.info(
+        "[_disambiguate_task_groups] "
+        "No task group found with user_group_id="
+        f"{default_group_id}, continue."
+    )
+
+    # Lowest priority: task groups owned by other groups, sorted
+    # according to age of the user/usergroup link
+    logger.info(
+        "[_disambiguate_task_groups] "
+        "Now sorting remaining task groups by oldest-user-link."
+    )
+    user_group_ids = [
+        task_group.user_group_id for task_group in matching_task_groups
+    ]
+    stm = (
+        select(LinkUserGroup.group_id)
+        .where(LinkUserGroup.user_id == user_id)
+        .where(LinkUserGroup.group_id.in_(user_group_ids))
+        .order_by(LinkUserGroup.timestamp_created.asc())
+    )
+    res = await db.execute(stm)
+    oldest_user_group_id = res.scalars().first()
+    logger.info(
+        "[_disambiguate_task_groups] "
+        f"Result of sorting: {oldest_user_group_id=}."
+    )
+    task_group = next(
+        iter(
+            task_group
+            for task_group in matching_task_groups
+            if task_group.user_group_id == oldest_user_group_id
+        ),
+        None,
+    )
+    return task_group
+
+
+async def _get_task_by_taskimport(
+    *,
+    task_import: TaskImportV2,
+    task_groups_list: list[TaskGroupV2],
+    user_id: int,
+    default_group_id: int,
+    db: AsyncSession,
+) -> Optional[int]:
+    """
+    Find a task based on `task_import`.
+
+    Args:
+        task_import: Info on task to be imported.
+        user_id: ID of current user.
+        default_group_id: ID of default user group.
+        task_group_list: Current list of valid task groups.
+        db: Asynchronous db session
+
+    Return:
+        `id` of the matching task, or `None`.
+    """
+
+    logger.info(f"[_get_task_by_taskimport] START, {task_import=}")
+
+    # Filter by `pkg_name` and by presence of a task with given `name`.
+    matching_task_groups = [
+        task_group
+        for task_group in task_groups_list
+        if (
+            task_group.pkg_name == task_import.pkg_name
+            and task_import.name
+            in [task.name for task in task_group.task_list]
+        )
+    ]
+    if len(matching_task_groups) < 1:
+        logger.info(
+            "[_get_task_by_taskimport] "
+            f"No task group with {task_import.pkg_name=} "
+            f"and a task with {task_import.name=}."
+        )
+        return None
+
+    # Determine target `version`
+    # Note that task_import.version cannot be "", due to a validator
+    if task_import.version is None:
+        logger.info(
+            "[_get_task_by_taskimport] "
+            "No version requested, looking for latest."
+        )
+        latest_task = max(
+            matching_task_groups, key=lambda tg: tg.version or ""
+        )
+        version = latest_task.version
+        logger.info(
+            f"[_get_task_by_taskimport] Latest version set to {version}."
+        )
+    else:
+        version = task_import.version
+
+    # Filter task groups by version
+    final_matching_task_groups = list(
+        filter(lambda tg: tg.version == version, task_groups_list)
+    )
+
+    if len(final_matching_task_groups) < 1:
+        logger.info(
+            "[_get_task_by_taskimport] "
+            "No task group left after filtering by version."
+        )
+        return None
+    elif len(final_matching_task_groups) == 1:
+        final_task_group = final_matching_task_groups[0]
+        logger.info(
+            "[_get_task_by_taskimport] "
+            "Found a single task group, after filtering by version."
+        )
+    else:
+        logger.info(
+            "[_get_task_by_taskimport] "
+            "Found many task groups, after filtering by version."
+        )
+        final_task_group = await _disambiguate_task_groups(
+            matching_task_groups=matching_task_groups,
+            user_id=user_id,
+            db=db,
+            default_group_id=default_group_id,
+        )
+        if final_task_group is None:
+            logger.info(
+                "[_get_task_by_taskimport] Disambiguation returned None."
+            )
+            return None
+
+    # Find task with given name
+    task_id = next(
+        iter(
+            task.id
+            for task in final_task_group.task_list
+            if task.name == task_import.name
+        ),
+        None,
+    )
+
+    logger.info(f"[_get_task_by_taskimport] END, {task_import=}, {task_id=}.")
+
+    return task_id
+
+
+@router.post(
+    "/project/{project_id}/workflow/import/",
+    response_model=WorkflowReadV2WithWarnings,
+    status_code=status.HTTP_201_CREATED,
+)
+async def import_workflow(
+    project_id: int,
+    workflow_import: WorkflowImportV2,
+    user: UserOAuth = Depends(current_active_user),
+    db: AsyncSession = Depends(get_async_db),
+) -> WorkflowReadV2WithWarnings:
+    """
+    Import an existing workflow into a project and create required objects.
+    """
+
+    # Preliminary checks
+    await _get_project_check_owner(
+        project_id=project_id,
+        user_id=user.id,
+        db=db,
+    )
+    await _check_workflow_exists(
+        name=workflow_import.name,
+        project_id=project_id,
+        db=db,
+    )
+
+    task_group_list = await _get_user_accessible_taskgroups(
+        user_id=user.id,
+        db=db,
+    )
+    default_group_id = await _get_default_usergroup_id(db)
+
+    list_wf_tasks = []
+    list_task_ids = []
+    for wf_task in workflow_import.task_list:
+        task_import = wf_task.task
+        if isinstance(task_import, TaskImportV2Legacy):
+            task_id = await _get_task_by_source(
+                source=task_import.source,
+                task_groups_list=task_group_list,
+            )
+        else:
+            task_id = await _get_task_by_taskimport(
+                task_import=task_import,
+                user_id=user.id,
+                default_group_id=default_group_id,
+                task_groups_list=task_group_list,
+                db=db,
+            )
+        if task_id is None:
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail=f"Could not find a task matching with {wf_task.task}.",
+            )
+        new_wf_task = WorkflowTaskCreateV2(
+            **wf_task.dict(exclude_none=True, exclude={"task"})
+        )
+        list_wf_tasks.append(new_wf_task)
+        list_task_ids.append(task_id)
+
+    # Create new Workflow
+    db_workflow = WorkflowV2(
+        project_id=project_id,
+        **workflow_import.dict(exclude_none=True, exclude={"task_list"}),
+    )
+    db.add(db_workflow)
+    await db.commit()
+    await db.refresh(db_workflow)
+
+    # Insert task into the workflow
+    for ind, new_wf_task in enumerate(list_wf_tasks):
+        await _workflow_insert_task(
+            **new_wf_task.dict(),
+            workflow_id=db_workflow.id,
+            task_id=list_task_ids[ind],
+            db=db,
+        )
+
+    # Add warnings for non-active tasks (or non-accessible tasks,
+    # although that should never happen)
+    wftask_list_with_warnings = await _add_warnings_to_workflow_tasks(
+        wftask_list=db_workflow.task_list, user_id=user.id, db=db
+    )
+    workflow_data = dict(
+        **db_workflow.model_dump(),
+        project=db_workflow.project,
+        task_list=wftask_list_with_warnings,
+    )
+
+    return workflow_data
```
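The new `import_workflow` endpoint resolves each entry of `task_list` either through the legacy `source` field (`TaskImportV2Legacy`) or through the `pkg_name`/`name`/optional `version` fields of `TaskImportV2`. A minimal client-side sketch, assuming the v2 API is mounted under `/api/v2` and using placeholder values for the base URL, project id, token, and task identifiers:

```python
import httpx

# Placeholder values, for illustration only.
BASE_URL = "http://localhost:8000/api/v2"
PROJECT_ID = 1
TOKEN = "<bearer-token>"

# One task referenced by (pkg_name, name) -- version omitted, so the latest
# accessible version is picked -- and one referenced by the legacy `source`.
payload = {
    "name": "Imported workflow",
    "task_list": [
        {"task": {"pkg_name": "my-task-package", "name": "My Task"}},
        {"task": {"source": "some-legacy-task-source"}},
    ],
}

response = httpx.post(
    f"{BASE_URL}/project/{PROJECT_ID}/workflow/import/",
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
)
response.raise_for_status()
print(response.json()["name"])
```

If any task cannot be matched against the task groups the user can access, the endpoint returns a 422 with the offending task reference in the detail message.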
fractal_server/app/routes/api/v2/workflowtask.py:

```diff
@@ -9,13 +9,13 @@ from fastapi import status
 
 from ....db import AsyncSession
 from ....db import get_async_db
-from ....models.v2 import TaskV2
 from ....schemas.v2 import WorkflowTaskCreateV2
 from ....schemas.v2 import WorkflowTaskReadV2
 from ....schemas.v2 import WorkflowTaskUpdateV2
 from ._aux_functions import _get_workflow_check_owner
 from ._aux_functions import _get_workflow_task_check_owner
 from ._aux_functions import _workflow_insert_task
+from ._aux_functions_tasks import _get_task_read_access
 from fractal_server.app.models import UserOAuth
 from fractal_server.app.routes.auth import current_active_user
 
@@ -43,12 +43,9 @@ async def create_workflowtask(
         project_id=project_id, workflow_id=workflow_id, user_id=user.id, db=db
     )
 
-    task = await
-
-
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail=f"TaskV2 {task_id} not found.",
-        )
+    task = await _get_task_read_access(
+        task_id=task_id, user_id=user.id, db=db, require_active=True
+    )
 
     if task.type == "parallel":
         if (
@@ -80,7 +77,6 @@ async def create_workflowtask(
     workflow_task = await _workflow_insert_task(
         workflow_id=workflow.id,
         task_id=task_id,
-        order=new_task.order,
         meta_non_parallel=new_task.meta_non_parallel,
         meta_parallel=new_task.meta_parallel,
         args_non_parallel=new_task.args_non_parallel,
```
fractal_server/app/routes/auth/_aux_auth.py:

```diff
@@ -1,6 +1,7 @@
 from fastapi import HTTPException
 from fastapi import status
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlmodel import asc
 from sqlmodel import select
 
 from fractal_server.app.models.linkusergroup import LinkUserGroup
@@ -8,63 +9,64 @@ from fractal_server.app.models.security import UserGroup
 from fractal_server.app.models.security import UserOAuth
 from fractal_server.app.schemas.user import UserRead
 from fractal_server.app.schemas.user_group import UserGroupRead
+from fractal_server.app.security import FRACTAL_DEFAULT_GROUP_NAME
+from fractal_server.logger import set_logger
 
+logger = set_logger(__name__)
 
-
+
+async def _get_single_user_with_groups(
     user: UserOAuth,
     db: AsyncSession,
 ) -> UserRead:
     """
-    Enrich a user object by filling its `
+    Enrich a user object by filling its `group_ids_names` attribute.
 
     Arguments:
         user: The current `UserOAuth` object
         db: Async db session
 
     Returns:
-        A `UserRead` object with `
+        A `UserRead` object with `group_ids_names` dict
     """
     stm_groups = (
         select(UserGroup)
         .join(LinkUserGroup)
-        .where(LinkUserGroup.user_id ==
+        .where(LinkUserGroup.user_id == user.id)
+        .order_by(asc(LinkUserGroup.timestamp_created))
     )
     res = await db.execute(stm_groups)
     groups = res.scalars().unique().all()
-
-
-
-
-
+    group_ids_names = [(group.id, group.name) for group in groups]
+
+    # Check that Fractal Default Group is the first of the list. If not, fix.
+    index = next(
+        (
+            i
+            for i, group_tuple in enumerate(group_ids_names)
+            if group_tuple[1] == FRACTAL_DEFAULT_GROUP_NAME
+        ),
+        None,
     )
+    if index is None:
+        logger.warning(
+            f"User {user.id} not in "
+            f"default UserGroup '{FRACTAL_DEFAULT_GROUP_NAME}'"
+        )
+    elif index != 0:
+        default_group = group_ids_names.pop(index)
+        group_ids_names.insert(0, default_group)
+    else:
+        pass
 
-
-async def _get_single_user_with_group_ids(
-    user: UserOAuth,
-    db: AsyncSession,
-) -> UserRead:
-    """
-    Enrich a user object by filling its `group_ids` attribute.
-
-    Arguments:
-        user: The current `UserOAuth` object
-        db: Async db session
-
-    Returns:
-        A `UserRead` object with `group_ids` set
-    """
-    stm_links = select(LinkUserGroup).where(LinkUserGroup.user_id == user.id)
-    res = await db.execute(stm_links)
-    links = res.scalars().unique().all()
-    group_ids = [link.group_id for link in links]
     return UserRead(
         **user.model_dump(),
-
+        group_ids_names=group_ids_names,
         oauth_accounts=user.oauth_accounts,
     )
 
 
-async def
+async def _get_single_usergroup_with_user_ids(
     group_id: int, db: AsyncSession
 ) -> UserGroupRead:
     """
@@ -78,15 +80,7 @@ async def _get_single_group_with_user_ids(
         `UserGroupRead` object, with `user_ids` attribute populated
         from database.
     """
-
-    stm_group = select(UserGroup).where(UserGroup.id == group_id)
-    res = await db.execute(stm_group)
-    group = res.scalars().one_or_none()
-    if group is None:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail=f"Group {group_id} not found.",
-        )
+    group = await _usergroup_or_404(group_id, db)
 
     # Get all user/group links
     stm_links = select(LinkUserGroup).where(LinkUserGroup.group_id == group_id)
@@ -108,6 +102,58 @@ async def _user_or_404(user_id: int, db: AsyncSession) -> UserOAuth:
     user = await db.get(UserOAuth, user_id, populate_existing=True)
     if user is None:
         raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"User {user_id} not found.",
+        )
+    return user
+
+
+async def _usergroup_or_404(usergroup_id: int, db: AsyncSession) -> UserGroup:
+    user = await db.get(UserGroup, usergroup_id)
+    if user is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"UserGroup {usergroup_id} not found.",
         )
     return user
+
+
+async def _get_default_usergroup_id(db: AsyncSession) -> int:
+    stm = select(UserGroup.id).where(
+        UserGroup.name == FRACTAL_DEFAULT_GROUP_NAME
+    )
+    res = await db.execute(stm)
+    user_group_id = res.scalars().one_or_none()
+    if user_group_id is None:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"User group '{FRACTAL_DEFAULT_GROUP_NAME}' not found.",
+        )
+    return user_group_id
+
+
+async def _verify_user_belongs_to_group(
+    *, user_id: int, user_group_id: int, db: AsyncSession
+):
+    stm = (
+        select(LinkUserGroup)
+        .where(LinkUserGroup.user_id == user_id)
+        .where(LinkUserGroup.group_id == user_group_id)
+    )
+    res = await db.execute(stm)
+    link = res.scalars().one_or_none()
+    if link is None:
+        group = await db.get(UserGroup, user_group_id)
+        if group is None:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail=f"UserGroup {user_group_id} not found",
+            )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail=(
+                    f"User {user_id} does not belong "
+                    f"to UserGroup {user_group_id}"
+                ),
+            )
```
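The added helpers (`_usergroup_or_404`, `_get_default_usergroup_id`, `_verify_user_belongs_to_group`) centralize the 404/403 handling that routes previously inlined, and `_get_default_usergroup_id` is the same helper the new workflow-import route relies on. A sketch of how a route could combine them, with a hypothetical endpoint path but the import locations shown in this diff:

```python
from fastapi import APIRouter
from fastapi import Depends

# Import paths as used elsewhere in this diff; the endpoint below is
# hypothetical and only illustrates how the helpers compose.
from fractal_server.app.db import AsyncSession
from fractal_server.app.db import get_async_db
from fractal_server.app.models import UserOAuth
from fractal_server.app.routes.auth import current_active_user
from fractal_server.app.routes.auth._aux_auth import _get_default_usergroup_id
from fractal_server.app.routes.auth._aux_auth import _verify_user_belongs_to_group

router = APIRouter()


@router.post("/example-share/")  # hypothetical route, for illustration only
async def share_with_default_group(
    user: UserOAuth = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
):
    # Raises 404 if the default group is missing
    default_group_id = await _get_default_usergroup_id(db)
    # Raises 404 if the group does not exist, 403 if the user is not a member
    await _verify_user_belongs_to_group(
        user_id=user.id, user_group_id=default_group_id, db=db
    )
    return {"user_group_id": default_group_id}
```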
fractal_server/app/routes/auth/current_user.py:

```diff
@@ -13,7 +13,7 @@ from ...schemas.user import UserRead
 from ...schemas.user import UserUpdate
 from ...schemas.user import UserUpdateStrict
 from ..aux.validate_user_settings import verify_user_has_settings
-from ._aux_auth import
+from ._aux_auth import _get_single_user_with_groups
 from fractal_server.app.models import LinkUserGroup
 from fractal_server.app.models import UserGroup
 from fractal_server.app.models import UserOAuth
@@ -28,15 +28,15 @@ router_current_user = APIRouter()
 
 @router_current_user.get("/current-user/", response_model=UserRead)
 async def get_current_user(
-
+    group_ids_names: bool = False,
     user: UserOAuth = Depends(current_active_user),
     db: AsyncSession = Depends(get_async_db),
 ):
     """
     Return current user
     """
-    if
-        user_with_groups = await
+    if group_ids_names is True:
+        user_with_groups = await _get_single_user_with_groups(user, db)
         return user_with_groups
     else:
         return user
@@ -65,7 +65,7 @@ async def patch_current_user(
     patched_user = await db.get(
         UserOAuth, validated_user.id, populate_existing=True
     )
-    patched_user_with_groups = await
+    patched_user_with_groups = await _get_single_user_with_groups(
         patched_user, db
     )
     return patched_user_with_groups
```
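The current-user endpoint now accepts a `group_ids_names` query parameter that enriches the response with `(group_id, group_name)` pairs, with the Fractal default group moved to the front of the list. A sketch of a client call, assuming the auth router is mounted under `/auth` and using a placeholder base URL and token:

```python
import httpx

# Placeholder base URL and token, for illustration only.
resp = httpx.get(
    "http://localhost:8000/auth/current-user/",
    params={"group_ids_names": True},
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
# Each entry is a (group id, group name) pair; the default group comes first
# when the current user belongs to it.
print(resp.json().get("group_ids_names"))
```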