fractal-server 2.3.11__py3-none-any.whl → 2.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fractal_server/__init__.py +1 -1
- fractal_server/__main__.py +25 -2
- fractal_server/app/models/__init__.py +11 -5
- fractal_server/app/models/linkusergroup.py +11 -0
- fractal_server/app/models/security.py +24 -3
- fractal_server/app/models/v1/project.py +1 -1
- fractal_server/app/models/v2/project.py +3 -3
- fractal_server/app/routes/admin/v1.py +14 -14
- fractal_server/app/routes/admin/v2.py +12 -12
- fractal_server/app/routes/api/__init__.py +2 -2
- fractal_server/app/routes/api/v1/_aux_functions.py +2 -2
- fractal_server/app/routes/api/v1/dataset.py +17 -15
- fractal_server/app/routes/api/v1/job.py +11 -9
- fractal_server/app/routes/api/v1/project.py +9 -9
- fractal_server/app/routes/api/v1/task.py +8 -8
- fractal_server/app/routes/api/v1/task_collection.py +5 -5
- fractal_server/app/routes/api/v1/workflow.py +13 -11
- fractal_server/app/routes/api/v1/workflowtask.py +6 -6
- fractal_server/app/routes/api/v2/_aux_functions.py +2 -2
- fractal_server/app/routes/api/v2/dataset.py +11 -11
- fractal_server/app/routes/api/v2/images.py +6 -6
- fractal_server/app/routes/api/v2/job.py +9 -9
- fractal_server/app/routes/api/v2/project.py +7 -7
- fractal_server/app/routes/api/v2/status.py +3 -3
- fractal_server/app/routes/api/v2/submit.py +3 -3
- fractal_server/app/routes/api/v2/task.py +8 -8
- fractal_server/app/routes/api/v2/task_collection.py +5 -5
- fractal_server/app/routes/api/v2/task_collection_custom.py +3 -3
- fractal_server/app/routes/api/v2/task_legacy.py +9 -9
- fractal_server/app/routes/api/v2/workflow.py +11 -11
- fractal_server/app/routes/api/v2/workflowtask.py +6 -6
- fractal_server/app/routes/auth/__init__.py +55 -0
- fractal_server/app/routes/auth/_aux_auth.py +107 -0
- fractal_server/app/routes/auth/current_user.py +60 -0
- fractal_server/app/routes/auth/group.py +176 -0
- fractal_server/app/routes/auth/group_names.py +34 -0
- fractal_server/app/routes/auth/login.py +25 -0
- fractal_server/app/routes/auth/oauth.py +63 -0
- fractal_server/app/routes/auth/register.py +23 -0
- fractal_server/app/routes/auth/router.py +19 -0
- fractal_server/app/routes/auth/users.py +192 -0
- fractal_server/app/schemas/user.py +11 -0
- fractal_server/app/schemas/user_group.py +65 -0
- fractal_server/app/security/__init__.py +72 -75
- fractal_server/data_migrations/2_4_0.py +61 -0
- fractal_server/main.py +1 -9
- fractal_server/migrations/versions/091b01f51f88_add_usergroup_and_linkusergroup_table.py +53 -0
- {fractal_server-2.3.11.dist-info → fractal_server-2.4.0.dist-info}/METADATA +1 -1
- {fractal_server-2.3.11.dist-info → fractal_server-2.4.0.dist-info}/RECORD +52 -39
- fractal_server/app/routes/auth.py +0 -165
- {fractal_server-2.3.11.dist-info → fractal_server-2.4.0.dist-info}/LICENSE +0 -0
- {fractal_server-2.3.11.dist-info → fractal_server-2.4.0.dist-info}/WHEEL +0 -0
- {fractal_server-2.3.11.dist-info → fractal_server-2.4.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,34 @@
|
|
1
|
+
"""
|
2
|
+
Definition `/auth/group-names/` endpoints
|
3
|
+
"""
|
4
|
+
from fastapi import APIRouter
|
5
|
+
from fastapi import Depends
|
6
|
+
from fastapi import status
|
7
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
8
|
+
from sqlmodel import select
|
9
|
+
|
10
|
+
from . import current_active_user
|
11
|
+
from ...db import get_async_db
|
12
|
+
from fractal_server.app.models import UserGroup
|
13
|
+
from fractal_server.app.models import UserOAuth
|
14
|
+
|
15
|
+
router_group_names = APIRouter()


@router_group_names.get(
    "/group-names/", response_model=list[str], status_code=status.HTTP_200_OK
)
async def get_list_user_group_names(
    user: UserOAuth = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> list[str]:
    """
    Return the available group names.

    This endpoint is not restricted to superusers: any user accepted by the
    `current_active_user` dependency can call it.

    Args:
        user: The authenticated user (only used as an access guard).
        db: Async database session.

    Returns:
        The `name` of every `UserGroup` row.
    """
    # Only the `name` column is needed, so select it directly instead of
    # loading full `UserGroup` ORM objects and discarding everything else.
    stm_group_names = select(UserGroup.name)
    res = await db.execute(stm_group_names)
    group_names = list(res.scalars().all())
    return group_names
|
@@ -0,0 +1,25 @@
|
|
1
|
+
"""
|
2
|
+
Definition of `/auth/{login,logout}/`, `/auth/token/{login/logout}` routes.
|
3
|
+
"""
|
4
|
+
from fastapi import APIRouter
|
5
|
+
|
6
|
+
from . import cookie_backend
|
7
|
+
from . import fastapi_users
|
8
|
+
from . import token_backend
|
9
|
+
|
10
|
+
router_login = APIRouter()


# fastapi-users authentication routes: those for `token_backend` are mounted
# under `/token`, those for `cookie_backend` without any extra prefix.
for _backend, _prefix in ((token_backend, "/token"), (cookie_backend, "")):
    router_login.include_router(
        fastapi_users.get_auth_router(_backend),
        prefix=_prefix,
    )


# Append a trailing slash to every route path that does not have one yet
for _route in router_login.routes:
    if not _route.path.endswith("/"):
        _route.path += "/"
|
@@ -0,0 +1,63 @@
|
|
1
|
+
from fastapi import APIRouter
|
2
|
+
|
3
|
+
from . import cookie_backend
|
4
|
+
from . import fastapi_users
|
5
|
+
from ....config import get_settings
|
6
|
+
from ....syringe import Inject
|
7
|
+
|
8
|
+
router_oauth = APIRouter()


# OAUTH CLIENTS

# NOTE: settings.OAUTH_CLIENTS are collected by
# Settings.collect_oauth_clients(). If no specific client is specified in the
# environment variables (e.g. by setting OAUTH_FOO_CLIENT_ID and
# OAUTH_FOO_CLIENT_SECRET), this list is empty

# FIXME: Dependency injection should be wrapped within a function call to make
# it truly lazy. This function could then be called on startup of the FastAPI
# app (cf. fractal_server.main)
settings = Inject(get_settings)


def _build_oauth_client(client_config, client_name: str):
    """
    Build the `httpx_oauth` client for one configured OAuth provider.

    `"google"` and `"github"` map to their dedicated clients; any other
    (lowercased) name is treated as a generic OpenID client built from
    `OIDC_CONFIGURATION_ENDPOINT`.
    """
    # Imports are deferred so that only the client classes that are actually
    # configured get imported.
    if client_name == "google":
        from httpx_oauth.clients.google import GoogleOAuth2

        return GoogleOAuth2(
            client_config.CLIENT_ID, client_config.CLIENT_SECRET
        )
    if client_name == "github":
        from httpx_oauth.clients.github import GitHubOAuth2

        return GitHubOAuth2(
            client_config.CLIENT_ID, client_config.CLIENT_SECRET
        )

    from httpx_oauth.clients.openid import OpenID

    return OpenID(
        client_config.CLIENT_ID,
        client_config.CLIENT_SECRET,
        client_config.OIDC_CONFIGURATION_ENDPOINT,
    )


# One fastapi-users OAuth sub-router per configured client, mounted at
# `/{client_name}`
for client_config in settings.OAUTH_CLIENTS_CONFIG:
    client_name = client_config.CLIENT_NAME.lower()
    router_oauth.include_router(
        fastapi_users.get_oauth_router(
            _build_oauth_client(client_config, client_name),
            cookie_backend,
            settings.JWT_SECRET_KEY,
            is_verified_by_default=False,
            associate_by_email=True,
            redirect_url=client_config.REDIRECT_URL,
        ),
        prefix=f"/{client_name}",
    )


# Add trailing slash to all routes' paths
for _route in router_oauth.routes:
    if not _route.path.endswith("/"):
        _route.path += "/"
|
@@ -0,0 +1,23 @@
|
|
1
|
+
"""
|
2
|
+
Definition of `/auth/register/` routes.
|
3
|
+
"""
|
4
|
+
from fastapi import APIRouter
|
5
|
+
from fastapi import Depends
|
6
|
+
|
7
|
+
from . import current_active_superuser
|
8
|
+
from . import fastapi_users
|
9
|
+
from ...schemas.user import UserCreate
|
10
|
+
from ...schemas.user import UserRead
|
11
|
+
|
12
|
+
router_register = APIRouter()

# fastapi-users registration routes; every route of this sub-router is
# guarded by the `current_active_superuser` dependency.
router_register.include_router(
    fastapi_users.get_register_router(UserRead, UserCreate),
    dependencies=[Depends(current_active_superuser)],
)


# Append "/" to any route path that does not already end with it
for _route in router_register.routes:
    if not _route.path.endswith("/"):
        _route.path += "/"
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from fastapi import APIRouter
|
2
|
+
|
3
|
+
from .current_user import router_current_user
|
4
|
+
from .group import router_group
|
5
|
+
from .group_names import router_group_names
|
6
|
+
from .login import router_login
|
7
|
+
from .oauth import router_oauth
|
8
|
+
from .register import router_register
|
9
|
+
from .users import router_users
|
10
|
+
|
11
|
+
router_auth = APIRouter()

# Auth sub-routers, listed in inclusion order. The order is kept fixed
# because routes are matched in the order in which they were added.
_AUTH_SUB_ROUTERS = (
    router_register,
    router_current_user,
    router_login,
    router_group_names,
    router_users,
    router_group,
    router_oauth,
)
for _sub_router in _AUTH_SUB_ROUTERS:
    router_auth.include_router(_sub_router)
|
@@ -0,0 +1,192 @@
|
|
1
|
+
"""
|
2
|
+
Definition of `/auth/users/` routes
|
3
|
+
"""
|
4
|
+
from fastapi import APIRouter
|
5
|
+
from fastapi import Depends
|
6
|
+
from fastapi import HTTPException
|
7
|
+
from fastapi import status
|
8
|
+
from fastapi_users import exceptions
|
9
|
+
from fastapi_users import schemas
|
10
|
+
from fastapi_users.router.common import ErrorCode
|
11
|
+
from sqlalchemy.exc import IntegrityError
|
12
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
13
|
+
from sqlmodel import col
|
14
|
+
from sqlmodel import func
|
15
|
+
from sqlmodel import select
|
16
|
+
|
17
|
+
from . import current_active_superuser
|
18
|
+
from ...db import get_async_db
|
19
|
+
from ...schemas.user import UserRead
|
20
|
+
from ...schemas.user import UserUpdate
|
21
|
+
from ...schemas.user import UserUpdateWithNewGroupIds
|
22
|
+
from ._aux_auth import _get_single_user_with_group_ids
|
23
|
+
from fractal_server.app.models import LinkUserGroup
|
24
|
+
from fractal_server.app.models import UserGroup
|
25
|
+
from fractal_server.app.models import UserOAuth
|
26
|
+
from fractal_server.app.routes.auth._aux_auth import _user_or_404
|
27
|
+
from fractal_server.app.security import get_user_manager
|
28
|
+
from fractal_server.app.security import UserManager
|
29
|
+
from fractal_server.logger import set_logger
|
30
|
+
|
31
|
+
router_users = APIRouter()


logger = set_logger(__name__)


@router_users.get("/users/{user_id}/", response_model=UserRead)
async def get_user(
    user_id: int,
    group_ids: bool = True,
    superuser: UserOAuth = Depends(current_active_superuser),
    db: AsyncSession = Depends(get_async_db),
) -> UserRead:
    """
    Return a single user (guarded by `current_active_superuser`).

    Args:
        user_id: ID of the requested user.
        group_ids: If true (default), return the object built by
            `_get_single_user_with_group_ids`; otherwise return the plain
            user object.
        superuser: Authenticated superuser (access guard only).
        db: Async database session.
    """
    target_user = await _user_or_404(user_id, db)
    if not group_ids:
        return target_user
    return await _get_single_user_with_group_ids(target_user, db)
|
50
|
+
|
51
|
+
|
52
|
+
@router_users.patch("/users/{user_id}/", response_model=UserRead)
async def patch_user(
    user_id: int,
    user_update: UserUpdateWithNewGroupIds,
    current_superuser: UserOAuth = Depends(current_active_superuser),
    user_manager: UserManager = Depends(get_user_manager),
    db: AsyncSession = Depends(get_async_db),
):
    """
    Custom version of the PATCH-user route from `fastapi-users`.

    In order to keep the fastapi-users logic in place (which is convenient to
    update user attributes), we split the endpoint into two branches. We either
    go through the fastapi-users-based attribute-update branch, or through the
    branch where we establish new user/group relationships.

    Note that we prevent making both changes at the same time, since it would
    be more complex to guarantee that endpoint error would leave the database
    in the same state as before the API call.

    Args:
        user_id: ID of the user to be modified.
        user_update: Either user-attribute changes or `new_group_ids` (IDs of
            groups to link to the user), but not both.
        current_superuser: Authenticated superuser (access guard only).
        user_manager: fastapi-users manager, used for attribute updates.
        db: Async database session.

    Returns:
        The (possibly updated) user, enriched with its `group_ids` by
        `_get_single_user_with_group_ids`.

    Raises:
        HTTPException: 422 if both attributes and groups are being edited, or
            if some user/group link cannot be created; 404 if some of the
            requested groups do not exist; 400 if the new password is invalid
            or if the new email already belongs to another user. A missing
            user is handled by `_user_or_404`.
    """

    # We prevent simultaneous editing of both user attributes and user/group
    # associations
    user_update_dict_without_groups = user_update.dict(
        exclude_unset=True, exclude={"new_group_ids"}
    )
    edit_attributes = user_update_dict_without_groups != {}
    edit_groups = user_update.new_group_ids is not None
    if edit_attributes and edit_groups:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                "Cannot modify both user attributes and group membership. "
                "Please make two independent PATCH calls"
            ),
        )

    # Check that user exists
    user_to_patch = await _user_or_404(user_id, db)

    if edit_groups:
        # Establish new user/group relationships

        # Check that all required groups exist
        # Note: The reason for introducing `col` is as in
        # https://sqlmodel.tiangolo.com/tutorial/where/#type-annotations-and-errors,
        stm = select(func.count()).where(
            col(UserGroup.id).in_(user_update.new_group_ids)
        )
        res = await db.execute(stm)
        number_matching_groups = res.scalar()
        if number_matching_groups != len(user_update.new_group_ids):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=(
                    "Not all requested groups (IDs: "
                    f"{user_update.new_group_ids}) exist."
                ),
            )

        for new_group_id in user_update.new_group_ids:
            link = LinkUserGroup(user_id=user_id, group_id=new_group_id)
            db.add(link)

        try:
            await db.commit()
        except IntegrityError as e:
            # Discard the failed transaction (and the pending links) so the
            # session is left in a clean state
            await db.rollback()
            error_msg = (
                f"Cannot link groups with IDs {user_update.new_group_ids} "
                f"to user {user_id}. "
                "Likely reason: one of these links already exists.\n"
                f"Original error: {str(e)}"
            )
            logger.info(error_msg)
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=error_msg,
            )

        patched_user = user_to_patch

    elif edit_attributes:
        # Modify user attributes
        try:
            user_update_without_groups = UserUpdate(
                **user_update_dict_without_groups
            )
            user = await user_manager.update(
                user_update_without_groups,
                user_to_patch,
                safe=False,
                request=None,
            )
            patched_user = schemas.model_validate(UserOAuth, user)
        except exceptions.InvalidPasswordException as e:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "code": ErrorCode.UPDATE_USER_INVALID_PASSWORD,
                    "reason": e.reason,
                },
            )
        except exceptions.UserAlreadyExists:
            # Same mapping as the upstream fastapi-users PATCH-user route:
            # the new email is already used by another user
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ErrorCode.UPDATE_USER_EMAIL_ALREADY_EXISTS,
            )
    else:
        # Nothing to do, just continue
        patched_user = user_to_patch

    # Enrich user object with `group_ids` attribute
    patched_user_with_group_ids = await _get_single_user_with_group_ids(
        patched_user, db
    )

    return patched_user_with_group_ids
|
164
|
+
|
165
|
+
|
166
|
+
@router_users.get("/users/", response_model=list[UserRead])
async def list_users(
    user: UserOAuth = Depends(current_active_superuser),
    db: AsyncSession = Depends(get_async_db),
):
    """
    Return list of all users, each enriched with its `group_ids`.

    Args:
        user: Authenticated superuser (access guard only).
        db: Async database session.

    Returns:
        One dict per user (the user's `model_dump()` plus `group_ids`).
    """
    from collections import defaultdict

    stm = select(UserOAuth)
    res = await db.execute(stm)
    user_list = res.scalars().unique().all()

    # Get all user/group links
    stm_all_links = select(LinkUserGroup)
    res = await db.execute(stm_all_links)
    links = res.scalars().all()

    # Index links by user ID in a single pass, so that building each user's
    # `group_ids` is a dict lookup instead of a scan over all links (see also
    # https://github.com/fractal-analytics-platform/fractal-server/issues/1742)
    group_ids_by_user_id = defaultdict(list)
    for link in links:
        group_ids_by_user_id[link.user_id].append(link.group_id)

    # NOTE: the loop variable is not named `user`, to avoid shadowing the
    # `user` dependency argument
    return [
        dict(
            user_item.model_dump(),
            group_ids=group_ids_by_user_id.get(user_item.id, []),
        )
        for user_item in user_list
    ]
|
@@ -16,6 +16,7 @@ __all__ = (
|
|
16
16
|
"UserRead",
|
17
17
|
"UserUpdate",
|
18
18
|
"UserCreate",
|
19
|
+
"UserUpdateWithNewGroupIds",
|
19
20
|
)
|
20
21
|
|
21
22
|
|
@@ -34,6 +35,8 @@ class UserRead(schemas.BaseUser[int]):
|
|
34
35
|
cache_dir: Optional[str]
|
35
36
|
username: Optional[str]
|
36
37
|
slurm_accounts: list[str]
|
38
|
+
group_names: Optional[list[str]] = None
|
39
|
+
group_ids: Optional[list[int]] = None
|
37
40
|
|
38
41
|
|
39
42
|
class UserUpdate(schemas.BaseUserUpdate):
|
@@ -100,6 +103,14 @@ class UserUpdateStrict(BaseModel, extra=Extra.forbid):
|
|
100
103
|
)
|
101
104
|
|
102
105
|
|
106
|
+
class UserUpdateWithNewGroupIds(UserUpdate):
    """
    `UserUpdate` schema extended with an optional `new_group_ids` list.

    Attributes:
        new_group_ids: IDs of groups to be linked to the user; when `None`
            (the default), group membership is not being edited.
    """

    new_group_ids: Optional[list[int]] = None

    # Apply the shared `val_unique_list` validator to `new_group_ids`
    _val_unique = validator("new_group_ids", allow_reuse=True)(
        val_unique_list("new_group_ids")
    )
|
112
|
+
|
113
|
+
|
103
114
|
class UserCreate(schemas.BaseUserCreate):
|
104
115
|
"""
|
105
116
|
Schema for `User` creation.
|
@@ -0,0 +1,65 @@
|
|
1
|
+
from datetime import datetime
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
from pydantic import BaseModel
|
5
|
+
from pydantic import Extra
|
6
|
+
from pydantic import Field
|
7
|
+
from pydantic import validator
|
8
|
+
|
9
|
+
from ._validators import val_unique_list
|
10
|
+
|
11
|
+
|
12
|
+
# Public schemas exported by this module
__all__ = ("UserGroupRead", "UserGroupUpdate", "UserGroupCreate")
|
17
|
+
|
18
|
+
|
19
|
+
class UserGroupRead(BaseModel):
    """
    Read schema for `UserGroup`.

    NOTE: `user_ids` is not a column of the `UserGroup` table; endpoints that
    need it compute it dynamically.

    Attributes:
        id: Group ID.
        name: Group name.
        timestamp_created: Creation timestamp.
        user_ids: IDs of the users of this group (optional).
    """

    id: int
    name: str
    timestamp_created: datetime
    user_ids: Optional[list[int]] = Field(default=None)
|
37
|
+
|
38
|
+
|
39
|
+
class UserGroupCreate(BaseModel, extra=Extra.forbid):
    """
    Creation schema for `UserGroup` (extra fields are forbidden).

    Attributes:
        name: Name of the new group.
    """

    name: str = Field(...)
|
48
|
+
|
49
|
+
|
50
|
+
class UserGroupUpdate(BaseModel, extra=Extra.forbid):
    """
    Schema for `UserGroup` update

    NOTE: `new_user_ids` does not correspond to a column of the `UserGroup`
    table, but it is rather used to create new `LinkUserGroup` rows.

    Attributes:
        new_user_ids: IDs of users to be added to the group (checked by the
            `val_unique_list` validator); defaults to an empty list.
    """

    new_user_ids: list[int] = Field(default_factory=list)

    _val_unique = validator("new_user_ids", allow_reuse=True)(
        val_unique_list("new_user_ids")
    )
|
@@ -29,41 +29,38 @@ All routes are registerd under the `auth/` prefix.
|
|
29
29
|
import contextlib
|
30
30
|
from typing import Any
|
31
31
|
from typing import AsyncGenerator
|
32
|
-
from typing import Dict
|
33
32
|
from typing import Generic
|
34
33
|
from typing import Optional
|
35
34
|
from typing import Type
|
36
35
|
|
37
36
|
from fastapi import Depends
|
37
|
+
from fastapi import Request
|
38
38
|
from fastapi_users import BaseUserManager
|
39
|
-
from fastapi_users import FastAPIUsers
|
40
39
|
from fastapi_users import IntegerIDMixin
|
41
|
-
from fastapi_users.authentication import AuthenticationBackend
|
42
|
-
from fastapi_users.authentication import BearerTransport
|
43
|
-
from fastapi_users.authentication import CookieTransport
|
44
|
-
from fastapi_users.authentication import JWTStrategy
|
45
40
|
from fastapi_users.db.base import BaseUserDatabase
|
46
41
|
from fastapi_users.exceptions import InvalidPasswordException
|
47
42
|
from fastapi_users.exceptions import UserAlreadyExists
|
48
43
|
from fastapi_users.models import ID
|
49
44
|
from fastapi_users.models import OAP
|
50
45
|
from fastapi_users.models import UP
|
51
|
-
from sqlalchemy.exc import IntegrityError
|
52
46
|
from sqlalchemy.ext.asyncio import AsyncSession
|
53
47
|
from sqlalchemy.orm import selectinload
|
54
48
|
from sqlmodel import func
|
55
49
|
from sqlmodel import select
|
56
50
|
|
57
|
-
from ...config import get_settings
|
58
|
-
from ...syringe import Inject
|
59
51
|
from ..db import get_async_db
|
60
|
-
from fractal_server.app.
|
61
|
-
from fractal_server.app.models
|
52
|
+
from fractal_server.app.db import get_sync_db
|
53
|
+
from fractal_server.app.models import LinkUserGroup
|
54
|
+
from fractal_server.app.models import OAuthAccount
|
55
|
+
from fractal_server.app.models import UserGroup
|
56
|
+
from fractal_server.app.models import UserOAuth
|
62
57
|
from fractal_server.app.schemas.user import UserCreate
|
63
58
|
from fractal_server.logger import get_logger
|
64
59
|
|
65
60
|
logger = get_logger(__name__)
|
66
61
|
|
62
|
+
# Name of the default user group: `_create_first_group` creates it, and
# `UserManager.on_after_register` links each newly-registered user to it.
FRACTAL_DEFAULT_GROUP_NAME = "All"
|
63
|
+
|
67
64
|
|
68
65
|
class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
69
66
|
"""
|
@@ -125,7 +122,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
|
125
122
|
return user
|
126
123
|
return None
|
127
124
|
|
128
|
-
async def create(self, create_dict:
|
125
|
+
async def create(self, create_dict: dict[str, Any]) -> UP:
|
129
126
|
"""Create a user."""
|
130
127
|
user = self.user_model(**create_dict)
|
131
128
|
self.session.add(user)
|
@@ -133,7 +130,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
|
133
130
|
await self.session.refresh(user)
|
134
131
|
return user
|
135
132
|
|
136
|
-
async def update(self, user: UP, update_dict:
|
133
|
+
async def update(self, user: UP, update_dict: dict[str, Any]) -> UP:
|
137
134
|
for key, value in update_dict.items():
|
138
135
|
setattr(user, key, value)
|
139
136
|
self.session.add(user)
|
@@ -146,7 +143,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
|
146
143
|
await self.session.commit()
|
147
144
|
|
148
145
|
async def add_oauth_account(
|
149
|
-
self, user: UP, create_dict:
|
146
|
+
self, user: UP, create_dict: dict[str, Any]
|
150
147
|
) -> UP: # noqa
|
151
148
|
if self.oauth_account_model is None:
|
152
149
|
raise NotImplementedError()
|
@@ -160,7 +157,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
|
160
157
|
return user
|
161
158
|
|
162
159
|
async def update_oauth_account(
|
163
|
-
self, user: UP, oauth_account: OAP, update_dict:
|
160
|
+
self, user: UP, oauth_account: OAP, update_dict: dict[str, Any]
|
164
161
|
) -> UP:
|
165
162
|
if self.oauth_account_model is None:
|
166
163
|
raise NotImplementedError()
|
@@ -176,13 +173,14 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
|
176
173
|
async def get_user_db(
|
177
174
|
session: AsyncSession = Depends(get_async_db),
|
178
175
|
) -> AsyncGenerator[SQLModelUserDatabaseAsync, None]:
|
179
|
-
yield SQLModelUserDatabaseAsync(session,
|
176
|
+
yield SQLModelUserDatabaseAsync(session, UserOAuth, OAuthAccount)
|
180
177
|
|
181
178
|
|
182
|
-
class UserManager(IntegerIDMixin, BaseUserManager[
|
183
|
-
async def validate_password(self, password: str, user:
|
179
|
+
class UserManager(IntegerIDMixin, BaseUserManager[UserOAuth, int]):
|
180
|
+
async def validate_password(self, password: str, user: UserOAuth) -> None:
|
184
181
|
# check password length
|
185
|
-
min_length
|
182
|
+
min_length = 4
|
183
|
+
max_length = 100
|
186
184
|
if len(password) < min_length:
|
187
185
|
raise InvalidPasswordException(
|
188
186
|
f"The password is too short (minimum length: {min_length})."
|
@@ -192,6 +190,38 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
|
192
190
|
f"The password is too long (maximum length: {min_length})."
|
193
191
|
)
|
194
192
|
|
193
|
+
async def on_after_register(
    self, user: UserOAuth, request: Optional[Request] = None
):
    """
    Override of the `BaseUserManager.on_after_register` hook.

    Link the freshly registered `user` to the group named
    `FRACTAL_DEFAULT_GROUP_NAME`, if it exists; otherwise only log an error.
    """
    logger.info(
        f"New-user registration completed ({user.id=}, {user.email=})."
    )
    async for db in get_async_db():
        # Look up the default group by name
        res = await db.execute(
            select(UserGroup).where(
                UserGroup.name == FRACTAL_DEFAULT_GROUP_NAME
            )
        )
        default_group = res.scalar_one_or_none()
        if default_group is None:
            logger.error(
                f"No group found with name {FRACTAL_DEFAULT_GROUP_NAME}"
            )
            continue

        logger.warning(
            f"START adding {user.email} user to group "
            f"{default_group.id=}."
        )
        db.add(LinkUserGroup(user_id=user.id, group_id=default_group.id))
        await db.commit()
        logger.warning(
            f"END adding {user.email} user to group "
            f"{default_group.id=}."
        )
|
224
|
+
|
195
225
|
|
196
226
|
async def get_user_manager(
|
197
227
|
user_db: SQLModelUserDatabaseAsync = Depends(get_user_db),
|
@@ -199,53 +229,6 @@ async def get_user_manager(
|
|
199
229
|
yield UserManager(user_db)
|
200
230
|
|
201
231
|
|
202
|
-
bearer_transport = BearerTransport(tokenUrl="/auth/token/login")
|
203
|
-
cookie_transport = CookieTransport(cookie_samesite="none")
|
204
|
-
|
205
|
-
|
206
|
-
def get_jwt_strategy() -> JWTStrategy:
|
207
|
-
settings = Inject(get_settings)
|
208
|
-
return JWTStrategy(
|
209
|
-
secret=settings.JWT_SECRET_KEY, # type: ignore
|
210
|
-
lifetime_seconds=settings.JWT_EXPIRE_SECONDS,
|
211
|
-
)
|
212
|
-
|
213
|
-
|
214
|
-
def get_jwt_cookie_strategy() -> JWTStrategy:
|
215
|
-
settings = Inject(get_settings)
|
216
|
-
return JWTStrategy(
|
217
|
-
secret=settings.JWT_SECRET_KEY, # type: ignore
|
218
|
-
lifetime_seconds=settings.COOKIE_EXPIRE_SECONDS,
|
219
|
-
)
|
220
|
-
|
221
|
-
|
222
|
-
token_backend = AuthenticationBackend(
|
223
|
-
name="bearer-jwt",
|
224
|
-
transport=bearer_transport,
|
225
|
-
get_strategy=get_jwt_strategy,
|
226
|
-
)
|
227
|
-
cookie_backend = AuthenticationBackend(
|
228
|
-
name="cookie-jwt",
|
229
|
-
transport=cookie_transport,
|
230
|
-
get_strategy=get_jwt_cookie_strategy,
|
231
|
-
)
|
232
|
-
|
233
|
-
|
234
|
-
fastapi_users = FastAPIUsers[User, int](
|
235
|
-
get_user_manager,
|
236
|
-
[token_backend, cookie_backend],
|
237
|
-
)
|
238
|
-
|
239
|
-
|
240
|
-
# Create dependencies for users
|
241
|
-
current_active_user = fastapi_users.current_user(active=True)
|
242
|
-
current_active_verified_user = fastapi_users.current_user(
|
243
|
-
active=True, verified=True
|
244
|
-
)
|
245
|
-
current_active_superuser = fastapi_users.current_user(
|
246
|
-
active=True, superuser=True
|
247
|
-
)
|
248
|
-
|
249
232
|
# Async context-manager wrappers around the `get_async_db`, `get_user_db` and
# `get_user_manager` generators, so they can be used with `async with`
# outside of FastAPI dependency injection.
# NOTE(review): presumably consumed by `_create_first_user`; confirm.
get_async_session_context = contextlib.asynccontextmanager(get_async_db)
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
|
@@ -286,9 +269,9 @@ async def _create_first_user(
|
|
286
269
|
|
287
270
|
if is_superuser is True:
|
288
271
|
# If a superuser already exists, exit
|
289
|
-
stm = select(
|
290
|
-
|
291
|
-
)
|
272
|
+
stm = select(UserOAuth).where( # noqa
|
273
|
+
UserOAuth.is_superuser == True # noqa
|
274
|
+
) # noqa
|
292
275
|
res = await session.execute(stm)
|
293
276
|
existing_superuser = res.scalars().first()
|
294
277
|
if existing_superuser is not None:
|
@@ -311,11 +294,25 @@ async def _create_first_user(
|
|
311
294
|
user = await user_manager.create(UserCreate(**kwargs))
|
312
295
|
logger.info(f"User {user.email} created")
|
313
296
|
|
314
|
-
except IntegrityError:
|
315
|
-
logger.warning(
|
316
|
-
f"Creation of user {email} failed with IntegrityError "
|
317
|
-
"(likely due to concurrent attempts from different workers)."
|
318
|
-
)
|
319
|
-
|
320
297
|
except UserAlreadyExists:
|
321
298
|
logger.warning(f"User {email} already exists")
|
299
|
+
|
300
|
+
|
301
|
+
def _create_first_group():
    """
    Create the default user group, unless it already exists.

    The default group is identified by its name (`FRACTAL_DEFAULT_GROUP_NAME`);
    any other existing group is irrelevant to this check.
    """
    logger.info(
        f"START _create_first_group, with name {FRACTAL_DEFAULT_GROUP_NAME}"
    )
    with next(get_sync_db()) as db:
        # Filter by name. Querying *all* groups and calling `one_or_none()`
        # would raise `MultipleResultsFound` as soon as a second group exists,
        # and would wrongly skip creation when only some other group exists.
        stm = select(UserGroup).where(
            UserGroup.name == FRACTAL_DEFAULT_GROUP_NAME
        )
        existing_default_group = db.execute(stm).scalars().first()
        if existing_default_group is None:
            first_group = UserGroup(name=FRACTAL_DEFAULT_GROUP_NAME)
            db.add(first_group)
            db.commit()
            logger.info(f"Created group {FRACTAL_DEFAULT_GROUP_NAME}")
        else:
            logger.info(
                f"Group {FRACTAL_DEFAULT_GROUP_NAME} already exists, skip."
            )
    logger.info(
        f"END _create_first_group, with name {FRACTAL_DEFAULT_GROUP_NAME}"
    )
|