fractal-server 2.6.4__py3-none-any.whl → 2.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/__main__.py +1 -1
  3. fractal_server/app/models/linkusergroup.py +11 -0
  4. fractal_server/app/models/v2/__init__.py +2 -0
  5. fractal_server/app/models/v2/collection_state.py +1 -0
  6. fractal_server/app/models/v2/task.py +67 -2
  7. fractal_server/app/routes/admin/v2/__init__.py +16 -0
  8. fractal_server/app/routes/admin/{v2.py → v2/job.py} +20 -191
  9. fractal_server/app/routes/admin/v2/project.py +43 -0
  10. fractal_server/app/routes/admin/v2/task.py +133 -0
  11. fractal_server/app/routes/admin/v2/task_group.py +162 -0
  12. fractal_server/app/routes/api/v1/task_collection.py +4 -4
  13. fractal_server/app/routes/api/v2/__init__.py +8 -0
  14. fractal_server/app/routes/api/v2/_aux_functions.py +1 -68
  15. fractal_server/app/routes/api/v2/_aux_functions_tasks.py +343 -0
  16. fractal_server/app/routes/api/v2/submit.py +16 -35
  17. fractal_server/app/routes/api/v2/task.py +85 -110
  18. fractal_server/app/routes/api/v2/task_collection.py +184 -196
  19. fractal_server/app/routes/api/v2/task_collection_custom.py +70 -64
  20. fractal_server/app/routes/api/v2/task_group.py +173 -0
  21. fractal_server/app/routes/api/v2/workflow.py +39 -102
  22. fractal_server/app/routes/api/v2/workflow_import.py +360 -0
  23. fractal_server/app/routes/api/v2/workflowtask.py +4 -8
  24. fractal_server/app/routes/auth/_aux_auth.py +86 -40
  25. fractal_server/app/routes/auth/current_user.py +5 -5
  26. fractal_server/app/routes/auth/group.py +73 -23
  27. fractal_server/app/routes/auth/router.py +0 -2
  28. fractal_server/app/routes/auth/users.py +8 -7
  29. fractal_server/app/runner/executors/slurm/ssh/executor.py +82 -63
  30. fractal_server/app/runner/v2/__init__.py +13 -7
  31. fractal_server/app/runner/v2/task_interface.py +4 -9
  32. fractal_server/app/schemas/user.py +1 -2
  33. fractal_server/app/schemas/v2/__init__.py +7 -0
  34. fractal_server/app/schemas/v2/dataset.py +2 -7
  35. fractal_server/app/schemas/v2/dumps.py +1 -2
  36. fractal_server/app/schemas/v2/job.py +1 -1
  37. fractal_server/app/schemas/v2/manifest.py +25 -1
  38. fractal_server/app/schemas/v2/project.py +1 -1
  39. fractal_server/app/schemas/v2/task.py +95 -36
  40. fractal_server/app/schemas/v2/task_collection.py +8 -6
  41. fractal_server/app/schemas/v2/task_group.py +85 -0
  42. fractal_server/app/schemas/v2/workflow.py +7 -2
  43. fractal_server/app/schemas/v2/workflowtask.py +9 -6
  44. fractal_server/app/security/__init__.py +8 -1
  45. fractal_server/config.py +8 -28
  46. fractal_server/data_migrations/2_7_0.py +323 -0
  47. fractal_server/images/models.py +2 -4
  48. fractal_server/main.py +1 -1
  49. fractal_server/migrations/versions/034a469ec2eb_task_groups.py +184 -0
  50. fractal_server/ssh/_fabric.py +186 -73
  51. fractal_server/string_tools.py +6 -2
  52. fractal_server/tasks/utils.py +19 -5
  53. fractal_server/tasks/v1/_TaskCollectPip.py +1 -1
  54. fractal_server/tasks/v1/background_operations.py +5 -5
  55. fractal_server/tasks/v1/get_collection_data.py +2 -2
  56. fractal_server/tasks/v2/_venv_pip.py +67 -70
  57. fractal_server/tasks/v2/background_operations.py +180 -69
  58. fractal_server/tasks/v2/background_operations_ssh.py +57 -70
  59. fractal_server/tasks/v2/database_operations.py +44 -0
  60. fractal_server/tasks/v2/endpoint_operations.py +104 -116
  61. fractal_server/tasks/v2/templates/_1_create_venv.sh +9 -5
  62. fractal_server/tasks/v2/templates/{_2_upgrade_pip.sh → _2_preliminary_pip_operations.sh} +1 -0
  63. fractal_server/tasks/v2/utils.py +5 -0
  64. fractal_server/utils.py +3 -2
  65. {fractal_server-2.6.4.dist-info → fractal_server-2.7.0.dist-info}/METADATA +3 -7
  66. {fractal_server-2.6.4.dist-info → fractal_server-2.7.0.dist-info}/RECORD +69 -60
  67. fractal_server/app/routes/auth/group_names.py +0 -34
  68. fractal_server/tasks/v2/_TaskCollectPip.py +0 -132
  69. {fractal_server-2.6.4.dist-info → fractal_server-2.7.0.dist-info}/LICENSE +0 -0
  70. {fractal_server-2.6.4.dist-info → fractal_server-2.7.0.dist-info}/WHEEL +0 -0
  71. {fractal_server-2.6.4.dist-info → fractal_server-2.7.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,360 @@
1
+ from typing import Optional
2
+
3
+ from fastapi import APIRouter
4
+ from fastapi import Depends
5
+ from fastapi import HTTPException
6
+ from fastapi import status
7
+ from sqlmodel import or_
8
+ from sqlmodel import select
9
+
10
+ from ....db import AsyncSession
11
+ from ....db import get_async_db
12
+ from ....models.v2 import TaskV2
13
+ from ....models.v2 import WorkflowV2
14
+ from ....schemas.v2 import TaskImportV2Legacy
15
+ from ....schemas.v2 import WorkflowImportV2
16
+ from ....schemas.v2 import WorkflowReadV2WithWarnings
17
+ from ....schemas.v2 import WorkflowTaskCreateV2
18
+ from ._aux_functions import _check_workflow_exists
19
+ from ._aux_functions import _get_project_check_owner
20
+ from ._aux_functions import _workflow_insert_task
21
+ from ._aux_functions_tasks import _add_warnings_to_workflow_tasks
22
+ from fractal_server.app.models import LinkUserGroup
23
+ from fractal_server.app.models import UserOAuth
24
+ from fractal_server.app.models.v2.task import TaskGroupV2
25
+ from fractal_server.app.routes.auth import current_active_user
26
+ from fractal_server.app.routes.auth._aux_auth import _get_default_usergroup_id
27
+ from fractal_server.app.schemas.v2.task import TaskImportV2
28
+ from fractal_server.logger import set_logger
29
+
30
+ router = APIRouter()
31
+
32
+
33
+ logger = set_logger(__name__)
34
+
35
+
36
async def _get_user_accessible_taskgroups(
    *,
    user_id: int,
    db: AsyncSession,
) -> list[TaskGroupV2]:
    """
    Collect the task groups that a given user can access.

    A task group is selected when its `user_id` is the given user, or when
    its `user_group_id` is one of the groups linked to that user through
    `LinkUserGroup`.

    Args:
        user_id: ID of the user whose access is evaluated.
        db: Asynchronous db session.

    Return:
        The matching task groups (possibly empty).
    """
    # Sub-select with the IDs of all user groups linked to this user
    linked_group_ids = select(LinkUserGroup.group_id).where(
        LinkUserGroup.user_id == user_id
    )
    owned_by_user = TaskGroupV2.user_id == user_id
    owned_by_linked_group = TaskGroupV2.user_group_id.in_(linked_group_ids)

    query = select(TaskGroupV2).where(
        or_(owned_by_user, owned_by_linked_group)
    )
    result = await db.execute(query)
    task_groups = result.scalars().all()

    num_found = len(task_groups)
    logger.info(f"Found {num_found} accessible task groups for {user_id=}.")
    return task_groups
61
+
62
+
63
async def _get_task_by_source(
    source: str,
    task_groups_list: list[TaskGroupV2],
) -> Optional[int]:
    """
    Find task with a given source.

    Task groups are scanned in the order given, and tasks in the order of
    each group's `task_list`; the first task whose `source` matches wins.

    Args:
        source: Value of the `source` attribute to look for.
        task_groups_list: Task groups whose tasks are searched.

    Return:
        `id` of the matching task, or `None`.
    """
    for task_group in task_groups_list:
        for task in task_group.task_list:
            if task.source == source:
                return task.id
    return None
90
+
91
+
92
async def _disambiguate_task_groups(
    *,
    matching_task_groups: list[TaskGroupV2],
    user_id: int,
    db: AsyncSession,
    default_group_id: int,
) -> Optional[TaskGroupV2]:
    """
    Disambiguate task groups based on ownership information.

    Candidates are ranked as follows:

    1. task groups whose `user_id` is the current user;
    2. task groups whose `user_group_id` is the default user group;
    3. task groups owned by another user group, choosing the group with the
       oldest `LinkUserGroup.timestamp_created` for this user.

    Args:
        matching_task_groups: Candidate task groups to choose from.
        user_id: ID of current user.
        db: Asynchronous db session
        default_group_id: ID of default user group.

    Return:
        The selected task group, or `None` if no candidate qualifies.
    """
    # Highest priority: task groups created by user
    for task_group in matching_task_groups:
        if task_group.user_id == user_id:
            logger.info(
                "[_disambiguate_task_groups] "
                f"Found task group {task_group.id} with {user_id=}, return."
            )
            return task_group
    logger.info(
        "[_disambiguate_task_groups] "
        f"No task group found with {user_id=}, continue."
    )

    # Medium priority: task groups owned by default user group
    for task_group in matching_task_groups:
        if task_group.user_group_id == default_group_id:
            logger.info(
                "[_disambiguate_task_groups] "
                f"Found task group {task_group.id} with user_group_id="
                f"{default_group_id}, return."
            )
            return task_group
    logger.info(
        "[_disambiguate_task_groups] "
        "No task group found with user_group_id="
        f"{default_group_id}, continue."
    )

    # Lowest priority: task groups owned by other groups, sorted
    # according to age of the user/usergroup link
    logger.info(
        "[_disambiguate_task_groups] "
        "Now sorting remaining task groups by oldest-user-link."
    )
    user_group_ids = [
        task_group.user_group_id for task_group in matching_task_groups
    ]
    stm = (
        select(LinkUserGroup.group_id)
        .where(LinkUserGroup.user_id == user_id)
        .where(LinkUserGroup.group_id.in_(user_group_ids))
        .order_by(LinkUserGroup.timestamp_created.asc())
    )
    res = await db.execute(stm)
    oldest_user_group_id = res.scalars().first()
    logger.info(
        "[_disambiguate_task_groups] "
        f"Result of sorting: {oldest_user_group_id=}."
    )
    if oldest_user_group_id is None:
        # The user is linked to none of the candidate groups. Without this
        # guard, the scan below would compare against `None` and could pick
        # a task group that has no user group at all.
        return None

    for task_group in matching_task_groups:
        if task_group.user_group_id == oldest_user_group_id:
            return task_group
    return None
160
+
161
+
162
async def _get_task_by_taskimport(
    *,
    task_import: TaskImportV2,
    task_groups_list: list[TaskGroupV2],
    user_id: int,
    default_group_id: int,
    db: AsyncSession,
) -> Optional[int]:
    """
    Find a task based on `task_import`.

    The lookup proceeds in steps:

    1. keep the task groups with the requested `pkg_name` that contain a
       task with the requested `name`;
    2. pick the target version: the requested one, or the largest `version`
       among the candidates when none is requested;
    3. keep the candidates with that version;
    4. if several candidates are left, disambiguate them by ownership.

    Args:
        task_import: Info on task to be imported.
        task_groups_list: Current list of valid task groups.
        user_id: ID of current user.
        default_group_id: ID of default user group.
        db: Asynchronous db session

    Return:
        `id` of the matching task, or `None`.
    """

    logger.info(f"[_get_task_by_taskimport] START, {task_import=}")

    # Filter by `pkg_name` and by presence of a task with given `name`.
    matching_task_groups = [
        task_group
        for task_group in task_groups_list
        if task_group.pkg_name == task_import.pkg_name
        and any(
            task.name == task_import.name for task in task_group.task_list
        )
    ]
    if not matching_task_groups:
        logger.info(
            "[_get_task_by_taskimport] "
            f"No task group with {task_import.pkg_name=} "
            f"and a task with {task_import.name=}."
        )
        return None

    # Determine target `version`
    # Note that task_import.version cannot be "", due to a validator
    if task_import.version is None:
        logger.info(
            "[_get_task_by_taskimport] "
            "No version requested, looking for latest."
        )
        # NOTE(review): `max` compares the `version` values as they are; if
        # they are plain strings, "0.9.0" ranks above "0.10.0" - confirm
        # whether a version-aware comparison is wanted here.
        latest_task_group = max(
            matching_task_groups, key=lambda tg: tg.version or ""
        )
        version = latest_task_group.version
        logger.info(
            f"[_get_task_by_taskimport] Latest version set to {version}."
        )
    else:
        version = task_import.version

    # Filter task groups by version. This must act on the candidates that
    # already match `pkg_name`/`name`; filtering the full input list would
    # bring back task groups of unrelated packages.
    final_matching_task_groups = [
        task_group
        for task_group in matching_task_groups
        if task_group.version == version
    ]

    if not final_matching_task_groups:
        logger.info(
            "[_get_task_by_taskimport] "
            "No task group left after filtering by version."
        )
        return None
    elif len(final_matching_task_groups) == 1:
        final_task_group = final_matching_task_groups[0]
        logger.info(
            "[_get_task_by_taskimport] "
            "Found a single task group, after filtering by version."
        )
    else:
        logger.info(
            "[_get_task_by_taskimport] "
            "Found many task groups, after filtering by version."
        )
        # Only the version-filtered candidates take part in disambiguation,
        # so that the selected group always has the target version.
        final_task_group = await _disambiguate_task_groups(
            matching_task_groups=final_matching_task_groups,
            user_id=user_id,
            db=db,
            default_group_id=default_group_id,
        )
        if final_task_group is None:
            logger.info(
                "[_get_task_by_taskimport] Disambiguation returned None."
            )
            return None

    # Find task with given name
    task_id = next(
        (
            task.id
            for task in final_task_group.task_list
            if task.name == task_import.name
        ),
        None,
    )

    logger.info(f"[_get_task_by_taskimport] END, {task_import=}, {task_id=}.")

    return task_id
268
+
269
+
270
@router.post(
    "/project/{project_id}/workflow/import/",
    response_model=WorkflowReadV2WithWarnings,
    status_code=status.HTTP_201_CREATED,
)
async def import_workflow(
    project_id: int,
    workflow_import: WorkflowImportV2,
    user: UserOAuth = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> WorkflowReadV2WithWarnings:
    """
    Import an existing workflow into a project and create required objects.

    Every task referenced by `workflow_import.task_list` is first resolved
    to the ID of a task within the task groups accessible to the current
    user. The workflow row is created only after all tasks were resolved,
    and the workflow tasks are then inserted in the imported order.

    Raises:
        HTTPException: With status 422 if an imported task matches no
            accessible task.
    """

    # Preliminary checks (presumably project ownership and workflow-name
    # uniqueness - see the two helpers for their exact contract)
    await _get_project_check_owner(
        project_id=project_id, user_id=user.id, db=db
    )
    await _check_workflow_exists(
        name=workflow_import.name, project_id=project_id, db=db
    )

    accessible_task_groups = await _get_user_accessible_taskgroups(
        user_id=user.id, db=db
    )
    default_group_id = await _get_default_usergroup_id(db)

    # Pairs of (workflow task to create, ID of the task it points to)
    resolved_wf_tasks = []
    for wf_task in workflow_import.task_list:
        task_import = wf_task.task
        if isinstance(task_import, TaskImportV2Legacy):
            # Legacy import: the task is identified by its `source`
            task_id = await _get_task_by_source(
                source=task_import.source,
                task_groups_list=accessible_task_groups,
            )
        else:
            task_id = await _get_task_by_taskimport(
                task_import=task_import,
                user_id=user.id,
                default_group_id=default_group_id,
                task_groups_list=accessible_task_groups,
                db=db,
            )
        if task_id is None:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Could not find a task matching with {wf_task.task}.",
            )
        wf_task_kwargs = wf_task.dict(exclude_none=True, exclude={"task"})
        resolved_wf_tasks.append(
            (WorkflowTaskCreateV2(**wf_task_kwargs), task_id)
        )

    # Create new Workflow
    workflow_kwargs = workflow_import.dict(
        exclude_none=True, exclude={"task_list"}
    )
    db_workflow = WorkflowV2(project_id=project_id, **workflow_kwargs)
    db.add(db_workflow)
    await db.commit()
    await db.refresh(db_workflow)

    # Insert task into the workflow
    for new_wf_task, task_id in resolved_wf_tasks:
        await _workflow_insert_task(
            **new_wf_task.dict(),
            workflow_id=db_workflow.id,
            task_id=task_id,
            db=db,
        )

    # Add warnings for non-active tasks (or non-accessible tasks,
    # although that should never happen)
    wftask_list_with_warnings = await _add_warnings_to_workflow_tasks(
        wftask_list=db_workflow.task_list, user_id=user.id, db=db
    )
    return dict(
        **db_workflow.model_dump(),
        project=db_workflow.project,
        task_list=wftask_list_with_warnings,
    )
@@ -9,13 +9,13 @@ from fastapi import status
9
9
 
10
10
  from ....db import AsyncSession
11
11
  from ....db import get_async_db
12
- from ....models.v2 import TaskV2
13
12
  from ....schemas.v2 import WorkflowTaskCreateV2
14
13
  from ....schemas.v2 import WorkflowTaskReadV2
15
14
  from ....schemas.v2 import WorkflowTaskUpdateV2
16
15
  from ._aux_functions import _get_workflow_check_owner
17
16
  from ._aux_functions import _get_workflow_task_check_owner
18
17
  from ._aux_functions import _workflow_insert_task
18
+ from ._aux_functions_tasks import _get_task_read_access
19
19
  from fractal_server.app.models import UserOAuth
20
20
  from fractal_server.app.routes.auth import current_active_user
21
21
 
@@ -43,12 +43,9 @@ async def create_workflowtask(
43
43
  project_id=project_id, workflow_id=workflow_id, user_id=user.id, db=db
44
44
  )
45
45
 
46
- task = await db.get(TaskV2, task_id)
47
- if not task:
48
- raise HTTPException(
49
- status_code=status.HTTP_404_NOT_FOUND,
50
- detail=f"TaskV2 {task_id} not found.",
51
- )
46
+ task = await _get_task_read_access(
47
+ task_id=task_id, user_id=user.id, db=db, require_active=True
48
+ )
52
49
 
53
50
  if task.type == "parallel":
54
51
  if (
@@ -80,7 +77,6 @@ async def create_workflowtask(
80
77
  workflow_task = await _workflow_insert_task(
81
78
  workflow_id=workflow.id,
82
79
  task_id=task_id,
83
- order=new_task.order,
84
80
  meta_non_parallel=new_task.meta_non_parallel,
85
81
  meta_parallel=new_task.meta_parallel,
86
82
  args_non_parallel=new_task.args_non_parallel,
@@ -1,6 +1,7 @@
1
1
  from fastapi import HTTPException
2
2
  from fastapi import status
3
3
  from sqlalchemy.ext.asyncio import AsyncSession
4
+ from sqlmodel import asc
4
5
  from sqlmodel import select
5
6
 
6
7
  from fractal_server.app.models.linkusergroup import LinkUserGroup
@@ -8,63 +9,64 @@ from fractal_server.app.models.security import UserGroup
8
9
  from fractal_server.app.models.security import UserOAuth
9
10
  from fractal_server.app.schemas.user import UserRead
10
11
  from fractal_server.app.schemas.user_group import UserGroupRead
12
+ from fractal_server.app.security import FRACTAL_DEFAULT_GROUP_NAME
13
+ from fractal_server.logger import set_logger
11
14
 
15
+ logger = set_logger(__name__)
12
16
 
13
- async def _get_single_user_with_group_names(
17
+
18
+ async def _get_single_user_with_groups(
14
19
  user: UserOAuth,
15
20
  db: AsyncSession,
16
21
  ) -> UserRead:
17
22
  """
18
- Enrich a user object by filling its `group_names` attribute.
23
+ Enrich a user object by filling its `group_ids_names` attribute.
19
24
 
20
25
  Arguments:
21
26
  user: The current `UserOAuth` object
22
27
  db: Async db session
23
28
 
24
29
  Returns:
25
- A `UserRead` object with `group_names` set
30
+ A `UserRead` object with `group_ids_names` dict
26
31
  """
27
32
  stm_groups = (
28
33
  select(UserGroup)
29
34
  .join(LinkUserGroup)
30
- .where(LinkUserGroup.user_id == UserOAuth.id)
35
+ .where(LinkUserGroup.user_id == user.id)
36
+ .order_by(asc(LinkUserGroup.timestamp_created))
31
37
  )
32
38
  res = await db.execute(stm_groups)
33
39
  groups = res.scalars().unique().all()
34
- group_names = [group.name for group in groups]
35
- return UserRead(
36
- **user.model_dump(),
37
- group_names=group_names,
38
- oauth_accounts=user.oauth_accounts,
40
+ group_ids_names = [(group.id, group.name) for group in groups]
41
+
42
+ # Check that Fractal Default Group is the first of the list. If not, fix.
43
+ index = next(
44
+ (
45
+ i
46
+ for i, group_tuple in enumerate(group_ids_names)
47
+ if group_tuple[1] == FRACTAL_DEFAULT_GROUP_NAME
48
+ ),
49
+ None,
39
50
  )
51
+ if index is None:
52
+ logger.warning(
53
+ f"User {user.id} not in "
54
+ f"default UserGroup '{FRACTAL_DEFAULT_GROUP_NAME}'"
55
+ )
56
+ elif index != 0:
57
+ default_group = group_ids_names.pop(index)
58
+ group_ids_names.insert(0, default_group)
59
+ else:
60
+ pass
40
61
 
41
-
42
- async def _get_single_user_with_group_ids(
43
- user: UserOAuth,
44
- db: AsyncSession,
45
- ) -> UserRead:
46
- """
47
- Enrich a user object by filling its `group_ids` attribute.
48
-
49
- Arguments:
50
- user: The current `UserOAuth` object
51
- db: Async db session
52
-
53
- Returns:
54
- A `UserRead` object with `group_ids` set
55
- """
56
- stm_links = select(LinkUserGroup).where(LinkUserGroup.user_id == user.id)
57
- res = await db.execute(stm_links)
58
- links = res.scalars().unique().all()
59
- group_ids = [link.group_id for link in links]
60
62
  return UserRead(
61
63
  **user.model_dump(),
62
- group_ids=group_ids,
64
+ group_ids_names=group_ids_names,
63
65
  oauth_accounts=user.oauth_accounts,
64
66
  )
65
67
 
66
68
 
67
- async def _get_single_group_with_user_ids(
69
+ async def _get_single_usergroup_with_user_ids(
68
70
  group_id: int, db: AsyncSession
69
71
  ) -> UserGroupRead:
70
72
  """
@@ -78,15 +80,7 @@ async def _get_single_group_with_user_ids(
78
80
  `UserGroupRead` object, with `user_ids` attribute populated
79
81
  from database.
80
82
  """
81
- # Get the UserGroup object from the database
82
- stm_group = select(UserGroup).where(UserGroup.id == group_id)
83
- res = await db.execute(stm_group)
84
- group = res.scalars().one_or_none()
85
- if group is None:
86
- raise HTTPException(
87
- status_code=status.HTTP_404_NOT_FOUND,
88
- detail=f"Group {group_id} not found.",
89
- )
83
+ group = await _usergroup_or_404(group_id, db)
90
84
 
91
85
  # Get all user/group links
92
86
  stm_links = select(LinkUserGroup).where(LinkUserGroup.group_id == group_id)
@@ -108,6 +102,58 @@ async def _user_or_404(user_id: int, db: AsyncSession) -> UserOAuth:
108
102
  user = await db.get(UserOAuth, user_id, populate_existing=True)
109
103
  if user is None:
110
104
  raise HTTPException(
111
- status_code=status.HTTP_404_NOT_FOUND, detail="User not found."
105
+ status_code=status.HTTP_404_NOT_FOUND,
106
+ detail=f"User {user_id} not found.",
107
+ )
108
+ return user
109
+
110
+
111
async def _usergroup_or_404(usergroup_id: int, db: AsyncSession) -> UserGroup:
    """
    Fetch a `UserGroup` by primary key, or fail with a 404.

    Arguments:
        usergroup_id: ID of the user group to fetch
        db: Async db session

    Returns:
        The `UserGroup` object with the given ID.

    Raises:
        HTTPException: With status 404 if no such user group exists.
    """
    usergroup = await db.get(UserGroup, usergroup_id)
    if usergroup is not None:
        return usergroup
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"UserGroup {usergroup_id} not found.",
    )
119
+
120
+
121
async def _get_default_usergroup_id(db: AsyncSession) -> int:
    """
    Look up the ID of the user group named `FRACTAL_DEFAULT_GROUP_NAME`.

    Arguments:
        db: Async db session

    Returns:
        The ID of the default user group.

    Raises:
        HTTPException: With status 404 if no user group has that name.
    """
    query = select(UserGroup.id).where(
        UserGroup.name == FRACTAL_DEFAULT_GROUP_NAME
    )
    result = await db.execute(query)
    default_group_id = result.scalars().one_or_none()
    if default_group_id is not None:
        return default_group_id
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"User group '{FRACTAL_DEFAULT_GROUP_NAME}' not found.",
    )
133
+
134
+
135
async def _verify_user_belongs_to_group(
    *, user_id: int, user_group_id: int, db: AsyncSession
):
    """
    Check that a user/usergroup link exists, and raise otherwise.

    Arguments:
        user_id: ID of the user
        user_group_id: ID of the user group
        db: Async db session

    Raises:
        HTTPException: With status 404 if the user group does not exist, or
            with status 403 if it exists but the user is not linked to it.
    """
    link_query = (
        select(LinkUserGroup)
        .where(LinkUserGroup.user_id == user_id)
        .where(LinkUserGroup.group_id == user_group_id)
    )
    result = await db.execute(link_query)
    if result.scalars().one_or_none() is not None:
        # Link found, nothing else to verify
        return

    # No link: tell apart a missing group from a missing membership
    if await db.get(UserGroup, user_group_id) is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"UserGroup {user_group_id} not found",
        )
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=(
            f"User {user_id} does not belong "
            f"to UserGroup {user_group_id}"
        ),
    )
@@ -13,7 +13,7 @@ from ...schemas.user import UserRead
13
13
  from ...schemas.user import UserUpdate
14
14
  from ...schemas.user import UserUpdateStrict
15
15
  from ..aux.validate_user_settings import verify_user_has_settings
16
- from ._aux_auth import _get_single_user_with_group_names
16
+ from ._aux_auth import _get_single_user_with_groups
17
17
  from fractal_server.app.models import LinkUserGroup
18
18
  from fractal_server.app.models import UserGroup
19
19
  from fractal_server.app.models import UserOAuth
@@ -28,15 +28,15 @@ router_current_user = APIRouter()
28
28
 
29
29
  @router_current_user.get("/current-user/", response_model=UserRead)
30
30
  async def get_current_user(
31
- group_names: bool = False,
31
+ group_ids_names: bool = False,
32
32
  user: UserOAuth = Depends(current_active_user),
33
33
  db: AsyncSession = Depends(get_async_db),
34
34
  ):
35
35
  """
36
36
  Return current user
37
37
  """
38
- if group_names is True:
39
- user_with_groups = await _get_single_user_with_group_names(user, db)
38
+ if group_ids_names is True:
39
+ user_with_groups = await _get_single_user_with_groups(user, db)
40
40
  return user_with_groups
41
41
  else:
42
42
  return user
@@ -65,7 +65,7 @@ async def patch_current_user(
65
65
  patched_user = await db.get(
66
66
  UserOAuth, validated_user.id, populate_existing=True
67
67
  )
68
- patched_user_with_groups = await _get_single_user_with_group_names(
68
+ patched_user_with_groups = await _get_single_user_with_groups(
69
69
  patched_user, db
70
70
  )
71
71
  return patched_user_with_groups