fractal-server 2.3.10__py3-none-any.whl → 2.4.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fractal_server/__init__.py +1 -1
- fractal_server/__main__.py +25 -2
- fractal_server/app/models/__init__.py +11 -5
- fractal_server/app/models/linkusergroup.py +11 -0
- fractal_server/app/models/security.py +24 -3
- fractal_server/app/models/v1/project.py +1 -1
- fractal_server/app/models/v2/project.py +3 -3
- fractal_server/app/routes/admin/v1.py +14 -14
- fractal_server/app/routes/admin/v2.py +12 -12
- fractal_server/app/routes/api/__init__.py +2 -2
- fractal_server/app/routes/api/v1/_aux_functions.py +2 -2
- fractal_server/app/routes/api/v1/dataset.py +17 -15
- fractal_server/app/routes/api/v1/job.py +11 -9
- fractal_server/app/routes/api/v1/project.py +9 -9
- fractal_server/app/routes/api/v1/task.py +8 -8
- fractal_server/app/routes/api/v1/task_collection.py +5 -5
- fractal_server/app/routes/api/v1/workflow.py +13 -11
- fractal_server/app/routes/api/v1/workflowtask.py +6 -6
- fractal_server/app/routes/api/v2/_aux_functions.py +2 -2
- fractal_server/app/routes/api/v2/dataset.py +11 -11
- fractal_server/app/routes/api/v2/images.py +6 -6
- fractal_server/app/routes/api/v2/job.py +9 -9
- fractal_server/app/routes/api/v2/project.py +7 -7
- fractal_server/app/routes/api/v2/status.py +3 -3
- fractal_server/app/routes/api/v2/submit.py +3 -3
- fractal_server/app/routes/api/v2/task.py +8 -8
- fractal_server/app/routes/api/v2/task_collection.py +5 -5
- fractal_server/app/routes/api/v2/task_collection_custom.py +3 -3
- fractal_server/app/routes/api/v2/task_legacy.py +9 -9
- fractal_server/app/routes/api/v2/workflow.py +11 -11
- fractal_server/app/routes/api/v2/workflowtask.py +6 -6
- fractal_server/app/routes/auth/__init__.py +55 -0
- fractal_server/app/routes/auth/_aux_auth.py +107 -0
- fractal_server/app/routes/auth/current_user.py +60 -0
- fractal_server/app/routes/auth/group.py +173 -0
- fractal_server/app/routes/auth/group_names.py +34 -0
- fractal_server/app/routes/auth/login.py +25 -0
- fractal_server/app/routes/auth/oauth.py +63 -0
- fractal_server/app/routes/auth/register.py +23 -0
- fractal_server/app/routes/auth/router.py +19 -0
- fractal_server/app/routes/auth/users.py +103 -0
- fractal_server/app/runner/v2/__init__.py +1 -5
- fractal_server/app/runner/v2/_slurm_ssh/__init__.py +17 -0
- fractal_server/app/schemas/user.py +2 -0
- fractal_server/app/schemas/user_group.py +57 -0
- fractal_server/app/security/__init__.py +72 -68
- fractal_server/data_migrations/2_4_0.py +61 -0
- fractal_server/main.py +1 -9
- fractal_server/migrations/versions/091b01f51f88_add_usergroup_and_linkusergroup_table.py +53 -0
- {fractal_server-2.3.10.dist-info → fractal_server-2.4.0a0.dist-info}/METADATA +1 -1
- {fractal_server-2.3.10.dist-info → fractal_server-2.4.0a0.dist-info}/RECORD +54 -41
- fractal_server/app/routes/auth.py +0 -165
- {fractal_server-2.3.10.dist-info → fractal_server-2.4.0a0.dist-info}/LICENSE +0 -0
- {fractal_server-2.3.10.dist-info → fractal_server-2.4.0a0.dist-info}/WHEEL +0 -0
- {fractal_server-2.3.10.dist-info → fractal_server-2.4.0a0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,34 @@
|
|
1
|
+
"""
|
2
|
+
Definition of `/auth/group-names/` endpoints
|
3
|
+
"""
|
4
|
+
from fastapi import APIRouter
|
5
|
+
from fastapi import Depends
|
6
|
+
from fastapi import status
|
7
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
8
|
+
from sqlmodel import select
|
9
|
+
|
10
|
+
from . import current_active_user
|
11
|
+
from ...db import get_async_db
|
12
|
+
from fractal_server.app.models import UserGroup
|
13
|
+
from fractal_server.app.models import UserOAuth
|
14
|
+
|
15
|
+
router_group_names = APIRouter()


@router_group_names.get(
    "/group-names/", response_model=list[str], status_code=status.HTTP_200_OK
)
async def get_list_user_group_names(
    user: UserOAuth = Depends(current_active_user),
    db: AsyncSession = Depends(get_async_db),
) -> list[str]:
    """
    Return the names of all existing user groups.

    Any active user may call this endpoint; it does not require superuser
    privileges.
    """
    # `user` is not used in the body: requesting it enforces that the caller
    # is an authenticated, active user.
    result = await db.execute(select(UserGroup))
    return [user_group.name for user_group in result.scalars().all()]
|
@@ -0,0 +1,25 @@
|
|
1
|
+
"""
|
2
|
+
Definition of `/auth/{login,logout}/` and `/auth/token/{login,logout}/` routes.
|
3
|
+
"""
|
4
|
+
from fastapi import APIRouter
|
5
|
+
|
6
|
+
from . import cookie_backend
|
7
|
+
from . import fastapi_users
|
8
|
+
from . import token_backend
|
9
|
+
|
10
|
+
router_login = APIRouter()

# Login/logout routes for the bearer-token backend, mounted under `/token`
router_login.include_router(
    fastapi_users.get_auth_router(token_backend),
    prefix="/token",
)
# Login/logout routes for the cookie backend, mounted without a prefix
router_login.include_router(
    fastapi_users.get_auth_router(cookie_backend),
)


# Make sure that every route path ends with a trailing slash
for route in router_login.routes:
    if route.path.endswith("/"):
        continue
    route.path = f"{route.path}/"
|
@@ -0,0 +1,63 @@
|
|
1
|
+
from fastapi import APIRouter
|
2
|
+
|
3
|
+
from . import cookie_backend
|
4
|
+
from . import fastapi_users
|
5
|
+
from ....config import get_settings
|
6
|
+
from ....syringe import Inject
|
7
|
+
|
8
|
+
router_oauth = APIRouter()


# OAUTH CLIENTS

# NOTE: the loop below iterates over `settings.OAUTH_CLIENTS_CONFIG`. According
# to the original note, clients are collected by
# Settings.collect_oauth_clients() from environment variables (e.g.
# OAUTH_FOO_CLIENT_ID and OAUTH_FOO_CLIENT_SECRET); when none is configured,
# the list is empty and no OAuth route gets registered.

# FIXME: Dependency injection should be wrapped within a function call to make
# it truly lazy. This function could then be called on startup of the FastAPI
# app (cf. fractal_server.main)
settings = Inject(get_settings)


def _build_oauth_client(client_name: str, client_config):
    """
    Instantiate the `httpx_oauth` client for a given (lowercase) provider name.

    `google` and `github` use their dedicated clients; any other name is
    handled as a generic OpenID provider, configured through
    `client_config.OIDC_CONFIGURATION_ENDPOINT`. Client classes are imported
    lazily, only for the providers that are actually configured.
    """
    if client_name == "google":
        from httpx_oauth.clients.google import GoogleOAuth2

        return GoogleOAuth2(
            client_config.CLIENT_ID, client_config.CLIENT_SECRET
        )
    if client_name == "github":
        from httpx_oauth.clients.github import GitHubOAuth2

        return GitHubOAuth2(
            client_config.CLIENT_ID, client_config.CLIENT_SECRET
        )
    from httpx_oauth.clients.openid import OpenID

    return OpenID(
        client_config.CLIENT_ID,
        client_config.CLIENT_SECRET,
        client_config.OIDC_CONFIGURATION_ENDPOINT,
    )


for client_config in settings.OAUTH_CLIENTS_CONFIG:
    client_name = client_config.CLIENT_NAME.lower()
    client = _build_oauth_client(client_name, client_config)
    # One router per provider, mounted under `/{client_name}`
    router_oauth.include_router(
        fastapi_users.get_oauth_router(
            client,
            cookie_backend,
            settings.JWT_SECRET_KEY,
            is_verified_by_default=False,
            associate_by_email=True,
            redirect_url=client_config.REDIRECT_URL,
        ),
        prefix=f"/{client_name}",
    )


# Make sure that every route path ends with a trailing slash
for route in router_oauth.routes:
    if route.path.endswith("/"):
        continue
    route.path = f"{route.path}/"
|
@@ -0,0 +1,23 @@
|
|
1
|
+
"""
|
2
|
+
Definition of `/auth/register/` routes.
|
3
|
+
"""
|
4
|
+
from fastapi import APIRouter
|
5
|
+
from fastapi import Depends
|
6
|
+
|
7
|
+
from . import current_active_superuser
|
8
|
+
from . import fastapi_users
|
9
|
+
from ...schemas.user import UserCreate
|
10
|
+
from ...schemas.user import UserRead
|
11
|
+
|
12
|
+
router_register = APIRouter()

# Registration routes from fastapi-users, restricted to active superusers
router_register.include_router(
    fastapi_users.get_register_router(UserRead, UserCreate),
    dependencies=[Depends(current_active_superuser)],
)


# Make sure that every route path ends with a trailing slash
for route in router_register.routes:
    if route.path.endswith("/"):
        continue
    route.path = f"{route.path}/"
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from fastapi import APIRouter
|
2
|
+
|
3
|
+
from .current_user import router_current_user
|
4
|
+
from .group import router_group
|
5
|
+
from .group_names import router_group_names
|
6
|
+
from .login import router_login
|
7
|
+
from .oauth import router_oauth
|
8
|
+
from .register import router_register
|
9
|
+
from .users import router_users
|
10
|
+
|
11
|
+
router_auth = APIRouter()

# Sub-routers are included one after the other, in this fixed order, so the
# routes of `router_auth` are registered in the same order.
for _sub_router in (
    router_register,
    router_current_user,
    router_login,
    router_group_names,
    router_users,
    router_group,
    router_oauth,
):
    router_auth.include_router(_sub_router)
|
@@ -0,0 +1,103 @@
|
|
1
|
+
"""
|
2
|
+
Definition of `/auth/users/` routes
|
3
|
+
"""
|
4
|
+
from fastapi import APIRouter
|
5
|
+
from fastapi import Depends
|
6
|
+
from fastapi import HTTPException
|
7
|
+
from fastapi import status
|
8
|
+
from fastapi_users import exceptions
|
9
|
+
from fastapi_users import schemas
|
10
|
+
from fastapi_users.router.common import ErrorCode
|
11
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
12
|
+
from sqlmodel import select
|
13
|
+
|
14
|
+
from . import current_active_superuser
|
15
|
+
from ...db import get_async_db
|
16
|
+
from ...schemas.user import UserRead
|
17
|
+
from ...schemas.user import UserUpdate
|
18
|
+
from ._aux_auth import _get_single_user_with_group_ids
|
19
|
+
from fractal_server.app.models import LinkUserGroup
|
20
|
+
from fractal_server.app.models import UserOAuth
|
21
|
+
from fractal_server.app.routes.auth._aux_auth import _user_or_404
|
22
|
+
from fractal_server.app.security import get_user_manager
|
23
|
+
from fractal_server.app.security import UserManager
|
24
|
+
|
25
|
+
router_users = APIRouter()


@router_users.get("/users/{user_id}/", response_model=UserRead)
async def get_user(
    user_id: int,
    group_ids: bool = True,
    superuser: UserOAuth = Depends(current_active_superuser),
    db: AsyncSession = Depends(get_async_db),
) -> UserRead:
    """
    Return a single user; only available to active superusers.

    When `group_ids` is true (the default), the user is returned through
    `_get_single_user_with_group_ids`, i.e. together with the IDs of its
    groups. Lookup failures are presumably handled by `_user_or_404`.
    """
    target_user = await _user_or_404(user_id, db)
    if not group_ids:
        return target_user
    return await _get_single_user_with_group_ids(target_user, db)
|
41
|
+
|
42
|
+
|
43
|
+
@router_users.patch("/users/{user_id}/", response_model=UserRead)
async def patch_user(
    user_id: int,
    user_update: UserUpdate,
    current_superuser: UserOAuth = Depends(current_active_superuser),
    user_manager: UserManager = Depends(get_user_manager),
    db: AsyncSession = Depends(get_async_db),
):
    """
    Custom version of the PATCH-user route from `fastapi-users`.

    Only active superusers can call this endpoint. The update goes through
    `user_manager.update(..., safe=False)`, so privileged fields of
    `UserUpdate` are applied as well. The patched user is returned together
    with its group IDs.

    Raises:
        HTTPException(400): If the new password is rejected by the user
            manager (`ErrorCode.UPDATE_USER_INVALID_PASSWORD`), or if the new
            email already belongs to another user
            (`ErrorCode.UPDATE_USER_EMAIL_ALREADY_EXISTS`).
    """
    user_to_patch = await _user_or_404(user_id, db)

    try:
        user = await user_manager.update(
            user_update, user_to_patch, safe=False, request=None
        )
    except exceptions.InvalidPasswordException as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "code": ErrorCode.UPDATE_USER_INVALID_PASSWORD,
                "reason": e.reason,
            },
        )
    except exceptions.UserAlreadyExists:
        # Changing the email to one that is already in use must be a client
        # error (as in the upstream fastapi-users route), not a 500.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ErrorCode.UPDATE_USER_EMAIL_ALREADY_EXISTS,
        )

    patched_user = schemas.model_validate(UserOAuth, user)
    patched_user_with_group_ids = await _get_single_user_with_group_ids(
        patched_user, db
    )

    return patched_user_with_group_ids
|
76
|
+
|
77
|
+
|
78
|
+
@router_users.get("/users/", response_model=list[UserRead])
async def list_users(
    user: UserOAuth = Depends(current_active_superuser),
    db: AsyncSession = Depends(get_async_db),
):
    """
    Return list of all users, each one with the IDs of its groups.

    Only available to active superusers.
    """
    stm = select(UserOAuth)
    res = await db.execute(stm)
    user_list = res.scalars().unique().all()

    # Get all user/group links
    stm_all_links = select(LinkUserGroup)
    res = await db.execute(stm_all_links)
    links = res.scalars().all()

    # Index group IDs by user ID in a single pass over the links, instead of
    # scanning all links once per user. Link order is preserved per user.
    group_ids_by_user_id: dict[int, list[int]] = {}
    for link in links:
        group_ids_by_user_id.setdefault(link.user_id, []).append(
            link.group_id
        )

    return [
        dict(
            db_user.model_dump(),
            group_ids=group_ids_by_user_id.get(db_user.id, []),
        )
        for db_user in user_list
    ]
|
@@ -194,15 +194,11 @@ async def submit_workflow(
|
|
194
194
|
folder=str(WORKFLOW_DIR_REMOTE), user=slurm_user
|
195
195
|
)
|
196
196
|
elif FRACTAL_RUNNER_BACKEND == "slurm_ssh":
|
197
|
+
# Folder creation is deferred to _process_workflow
|
197
198
|
WORKFLOW_DIR_REMOTE = (
|
198
199
|
Path(settings.FRACTAL_SLURM_SSH_WORKING_BASE_DIR)
|
199
200
|
/ WORKFLOW_DIR_LOCAL.name
|
200
201
|
)
|
201
|
-
# FIXME SSH: move mkdir to executor, likely within handshake
|
202
|
-
fractal_ssh.mkdir(
|
203
|
-
folder=str(WORKFLOW_DIR_REMOTE),
|
204
|
-
)
|
205
|
-
logger.info(f"Created {str(WORKFLOW_DIR_REMOTE)} via SSH.")
|
206
202
|
else:
|
207
203
|
logger.error(
|
208
204
|
"Invalid FRACTAL_RUNNER_BACKEND="
|
@@ -25,10 +25,15 @@ from .....ssh._fabric import FractalSSH
|
|
25
25
|
from ....models.v2 import DatasetV2
|
26
26
|
from ....models.v2 import WorkflowV2
|
27
27
|
from ...async_wrap import async_wrap
|
28
|
+
from ...exceptions import JobExecutionError
|
28
29
|
from ...executors.slurm.ssh.executor import FractalSlurmSSHExecutor
|
29
30
|
from ...set_start_and_last_task_index import set_start_and_last_task_index
|
30
31
|
from ..runner import execute_tasks_v2
|
31
32
|
from ._submit_setup import _slurm_submit_setup
|
33
|
+
from fractal_server.logger import set_logger
|
34
|
+
|
35
|
+
|
36
|
+
logger = set_logger(__name__)
|
32
37
|
|
33
38
|
|
34
39
|
def _process_workflow(
|
@@ -60,6 +65,18 @@ def _process_workflow(
|
|
60
65
|
if isinstance(worker_init, str):
|
61
66
|
worker_init = worker_init.split("\n")
|
62
67
|
|
68
|
+
# Create main remote folder
|
69
|
+
try:
|
70
|
+
fractal_ssh.mkdir(folder=str(workflow_dir_remote))
|
71
|
+
logger.info(f"Created {str(workflow_dir_remote)} via SSH.")
|
72
|
+
except Exception as e:
|
73
|
+
error_msg = (
|
74
|
+
f"Could not create {str(workflow_dir_remote)} via SSH.\n"
|
75
|
+
f"Original error: {str(e)}."
|
76
|
+
)
|
77
|
+
logger.error(error_msg)
|
78
|
+
raise JobExecutionError(info=error_msg)
|
79
|
+
|
63
80
|
with FractalSlurmSSHExecutor(
|
64
81
|
fractal_ssh=fractal_ssh,
|
65
82
|
workflow_dir_local=workflow_dir_local,
|
@@ -0,0 +1,57 @@
|
|
1
|
+
from datetime import datetime
|
2
|
+
from typing import Optional
|
3
|
+
|
4
|
+
from pydantic import BaseModel
|
5
|
+
from pydantic import Extra
|
6
|
+
from pydantic import Field
|
7
|
+
|
8
|
+
__all__ = (
    "UserGroupRead",
    "UserGroupUpdate",
    "UserGroupCreate",
)


class UserGroupRead(BaseModel):
    """
    Read schema for a `UserGroup`.

    NOTE: `user_ids` is not a column of the `UserGroup` table; relevant
    endpoints compute it on the fly. It defaults to `None`.

    Attributes:
        id: Group ID.
        name: Group name.
        timestamp_created: Creation timestamp.
        user_ids: IDs of the users belonging to this group.
    """

    id: int
    name: str
    timestamp_created: datetime
    user_ids: Optional[list[int]] = Field(default=None)
|
33
|
+
|
34
|
+
|
35
|
+
class UserGroupCreate(BaseModel, extra=Extra.forbid):
    """
    Creation schema for a `UserGroup`.

    Extra fields are forbidden (`extra=Extra.forbid`).

    Attributes:
        name: Name of the new group (required).
    """

    name: str = Field(...)
|
44
|
+
|
45
|
+
|
46
|
+
class UserGroupUpdate(BaseModel, extra=Extra.forbid):
    """
    Update schema for a `UserGroup`.

    NOTE: `new_user_ids` does not correspond to a column of the `UserGroup`
    table, but it is rather used to create new `LinkUserGroup` rows (one per
    user to be added to the group).

    Extra fields are forbidden (`extra=Extra.forbid`).

    Attributes:
        new_user_ids: IDs of the users to be added to the group (empty by
            default).
    """

    new_user_ids: list[int] = Field(default_factory=list)
|
@@ -29,19 +29,14 @@ All routes are registerd under the `auth/` prefix.
|
|
29
29
|
import contextlib
|
30
30
|
from typing import Any
|
31
31
|
from typing import AsyncGenerator
|
32
|
-
from typing import Dict
|
33
32
|
from typing import Generic
|
34
33
|
from typing import Optional
|
35
34
|
from typing import Type
|
36
35
|
|
37
36
|
from fastapi import Depends
|
37
|
+
from fastapi import Request
|
38
38
|
from fastapi_users import BaseUserManager
|
39
|
-
from fastapi_users import FastAPIUsers
|
40
39
|
from fastapi_users import IntegerIDMixin
|
41
|
-
from fastapi_users.authentication import AuthenticationBackend
|
42
|
-
from fastapi_users.authentication import BearerTransport
|
43
|
-
from fastapi_users.authentication import CookieTransport
|
44
|
-
from fastapi_users.authentication import JWTStrategy
|
45
40
|
from fastapi_users.db.base import BaseUserDatabase
|
46
41
|
from fastapi_users.exceptions import InvalidPasswordException
|
47
42
|
from fastapi_users.exceptions import UserAlreadyExists
|
@@ -54,16 +49,19 @@ from sqlalchemy.orm import selectinload
|
|
54
49
|
from sqlmodel import func
|
55
50
|
from sqlmodel import select
|
56
51
|
|
57
|
-
from ...config import get_settings
|
58
|
-
from ...syringe import Inject
|
59
52
|
from ..db import get_async_db
|
60
|
-
from fractal_server.app.
|
61
|
-
from fractal_server.app.models
|
53
|
+
from fractal_server.app.db import get_sync_db
|
54
|
+
from fractal_server.app.models import LinkUserGroup
|
55
|
+
from fractal_server.app.models import OAuthAccount
|
56
|
+
from fractal_server.app.models import UserGroup
|
57
|
+
from fractal_server.app.models import UserOAuth
|
62
58
|
from fractal_server.app.schemas.user import UserCreate
|
63
59
|
from fractal_server.logger import get_logger
|
64
60
|
|
65
61
|
logger = get_logger(__name__)
|
66
62
|
|
63
|
+
FRACTAL_DEFAULT_GROUP_NAME = "All"
|
64
|
+
|
67
65
|
|
68
66
|
class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
69
67
|
"""
|
@@ -125,7 +123,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
|
125
123
|
return user
|
126
124
|
return None
|
127
125
|
|
128
|
-
async def create(self, create_dict:
|
126
|
+
async def create(self, create_dict: dict[str, Any]) -> UP:
|
129
127
|
"""Create a user."""
|
130
128
|
user = self.user_model(**create_dict)
|
131
129
|
self.session.add(user)
|
@@ -133,7 +131,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
|
133
131
|
await self.session.refresh(user)
|
134
132
|
return user
|
135
133
|
|
136
|
-
async def update(self, user: UP, update_dict:
|
134
|
+
async def update(self, user: UP, update_dict: dict[str, Any]) -> UP:
|
137
135
|
for key, value in update_dict.items():
|
138
136
|
setattr(user, key, value)
|
139
137
|
self.session.add(user)
|
@@ -146,7 +144,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
|
146
144
|
await self.session.commit()
|
147
145
|
|
148
146
|
async def add_oauth_account(
|
149
|
-
self, user: UP, create_dict:
|
147
|
+
self, user: UP, create_dict: dict[str, Any]
|
150
148
|
) -> UP: # noqa
|
151
149
|
if self.oauth_account_model is None:
|
152
150
|
raise NotImplementedError()
|
@@ -160,7 +158,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
|
160
158
|
return user
|
161
159
|
|
162
160
|
async def update_oauth_account(
|
163
|
-
self, user: UP, oauth_account: OAP, update_dict:
|
161
|
+
self, user: UP, oauth_account: OAP, update_dict: dict[str, Any]
|
164
162
|
) -> UP:
|
165
163
|
if self.oauth_account_model is None:
|
166
164
|
raise NotImplementedError()
|
@@ -176,13 +174,14 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
|
|
176
174
|
async def get_user_db(
|
177
175
|
session: AsyncSession = Depends(get_async_db),
|
178
176
|
) -> AsyncGenerator[SQLModelUserDatabaseAsync, None]:
|
179
|
-
yield SQLModelUserDatabaseAsync(session,
|
177
|
+
yield SQLModelUserDatabaseAsync(session, UserOAuth, OAuthAccount)
|
180
178
|
|
181
179
|
|
182
|
-
class UserManager(IntegerIDMixin, BaseUserManager[
|
183
|
-
async def validate_password(self, password: str, user:
|
180
|
+
class UserManager(IntegerIDMixin, BaseUserManager[UserOAuth, int]):
|
181
|
+
async def validate_password(self, password: str, user: UserOAuth) -> None:
|
184
182
|
# check password length
|
185
|
-
min_length
|
183
|
+
min_length = 4
|
184
|
+
max_length = 100
|
186
185
|
if len(password) < min_length:
|
187
186
|
raise InvalidPasswordException(
|
188
187
|
f"The password is too short (minimum length: {min_length})."
|
@@ -192,6 +191,38 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
|
192
191
|
f"The password is too long (maximum length: {min_length})."
|
193
192
|
)
|
194
193
|
|
194
|
+
async def on_after_register(
    self, user: UserOAuth, request: Optional[Request] = None
):
    """
    Post-registration hook: add the new user to the default group.

    The default group is looked up by `FRACTAL_DEFAULT_GROUP_NAME`. When it
    does not exist, an error is logged and no link is created.
    """
    logger.info(
        f"New-user registration completed ({user.id=}, {user.email=})."
    )
    async for db in get_async_db():
        # Look for the default group, by name
        query = select(UserGroup).where(
            UserGroup.name == FRACTAL_DEFAULT_GROUP_NAME
        )
        default_group = (await db.execute(query)).scalar_one_or_none()
        if default_group is None:
            logger.error(
                f"No group found with name {FRACTAL_DEFAULT_GROUP_NAME}"
            )
            continue
        logger.warning(
            f"START adding {user.email} user to group "
            f"{default_group.id=}."
        )
        db.add(LinkUserGroup(user_id=user.id, group_id=default_group.id))
        await db.commit()
        logger.warning(
            f"END adding {user.email} user to group "
            f"{default_group.id=}."
        )
|
225
|
+
|
195
226
|
|
196
227
|
async def get_user_manager(
|
197
228
|
user_db: SQLModelUserDatabaseAsync = Depends(get_user_db),
|
@@ -199,53 +230,6 @@ async def get_user_manager(
|
|
199
230
|
yield UserManager(user_db)
|
200
231
|
|
201
232
|
|
202
|
-
bearer_transport = BearerTransport(tokenUrl="/auth/token/login")
|
203
|
-
cookie_transport = CookieTransport(cookie_samesite="none")
|
204
|
-
|
205
|
-
|
206
|
-
def get_jwt_strategy() -> JWTStrategy:
|
207
|
-
settings = Inject(get_settings)
|
208
|
-
return JWTStrategy(
|
209
|
-
secret=settings.JWT_SECRET_KEY, # type: ignore
|
210
|
-
lifetime_seconds=settings.JWT_EXPIRE_SECONDS,
|
211
|
-
)
|
212
|
-
|
213
|
-
|
214
|
-
def get_jwt_cookie_strategy() -> JWTStrategy:
|
215
|
-
settings = Inject(get_settings)
|
216
|
-
return JWTStrategy(
|
217
|
-
secret=settings.JWT_SECRET_KEY, # type: ignore
|
218
|
-
lifetime_seconds=settings.COOKIE_EXPIRE_SECONDS,
|
219
|
-
)
|
220
|
-
|
221
|
-
|
222
|
-
token_backend = AuthenticationBackend(
|
223
|
-
name="bearer-jwt",
|
224
|
-
transport=bearer_transport,
|
225
|
-
get_strategy=get_jwt_strategy,
|
226
|
-
)
|
227
|
-
cookie_backend = AuthenticationBackend(
|
228
|
-
name="cookie-jwt",
|
229
|
-
transport=cookie_transport,
|
230
|
-
get_strategy=get_jwt_cookie_strategy,
|
231
|
-
)
|
232
|
-
|
233
|
-
|
234
|
-
fastapi_users = FastAPIUsers[User, int](
|
235
|
-
get_user_manager,
|
236
|
-
[token_backend, cookie_backend],
|
237
|
-
)
|
238
|
-
|
239
|
-
|
240
|
-
# Create dependencies for users
|
241
|
-
current_active_user = fastapi_users.current_user(active=True)
|
242
|
-
current_active_verified_user = fastapi_users.current_user(
|
243
|
-
active=True, verified=True
|
244
|
-
)
|
245
|
-
current_active_superuser = fastapi_users.current_user(
|
246
|
-
active=True, superuser=True
|
247
|
-
)
|
248
|
-
|
249
233
|
get_async_session_context = contextlib.asynccontextmanager(get_async_db)
|
250
234
|
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
|
251
235
|
get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
|
@@ -286,9 +270,9 @@ async def _create_first_user(
|
|
286
270
|
|
287
271
|
if is_superuser is True:
|
288
272
|
# If a superuser already exists, exit
|
289
|
-
stm = select(
|
290
|
-
|
291
|
-
)
|
273
|
+
stm = select(UserOAuth).where( # noqa
|
274
|
+
UserOAuth.is_superuser == True # noqa
|
275
|
+
) # noqa
|
292
276
|
res = await session.execute(stm)
|
293
277
|
existing_superuser = res.scalars().first()
|
294
278
|
if existing_superuser is not None:
|
@@ -319,3 +303,23 @@ async def _create_first_user(
|
|
319
303
|
|
320
304
|
except UserAlreadyExists:
|
321
305
|
logger.warning(f"User {email} already exists")
|
306
|
+
|
307
|
+
|
308
|
+
def _create_first_group():
    """
    Create the default user group (`FRACTAL_DEFAULT_GROUP_NAME`), if missing.

    The check looks for a group with that specific name. Any other existing
    groups neither prevent the creation of the default one nor make the
    lookup fail. A name-agnostic `one_or_none()` would raise as soon as two
    groups exist.
    """
    logger.info(
        f"START _create_first_group, with name {FRACTAL_DEFAULT_GROUP_NAME}"
    )
    with next(get_sync_db()) as db:
        stm = select(UserGroup).where(
            UserGroup.name == FRACTAL_DEFAULT_GROUP_NAME
        )
        existing_default_group = db.execute(stm).scalars().first()
        if existing_default_group is None:
            first_group = UserGroup(name=FRACTAL_DEFAULT_GROUP_NAME)
            db.add(first_group)
            db.commit()
            logger.info(f"Created group {FRACTAL_DEFAULT_GROUP_NAME}")
        else:
            logger.info(
                f"Group {FRACTAL_DEFAULT_GROUP_NAME} already exists, skip."
            )
    logger.info(
        f"END _create_first_group, with name {FRACTAL_DEFAULT_GROUP_NAME}"
    )
|
@@ -0,0 +1,61 @@
|
|
1
|
+
import logging
|
2
|
+
|
3
|
+
from packaging.version import parse
|
4
|
+
from sqlalchemy import select
|
5
|
+
|
6
|
+
import fractal_server
|
7
|
+
from fractal_server.app.db import get_sync_db
|
8
|
+
from fractal_server.app.models import LinkUserGroup
|
9
|
+
from fractal_server.app.models import UserGroup
|
10
|
+
from fractal_server.app.models import UserOAuth
|
11
|
+
from fractal_server.app.security import FRACTAL_DEFAULT_GROUP_NAME
|
12
|
+
|
13
|
+
|
14
|
+
def _check_current_version(*, expected_version: str):
    """
    Verify that the installed fractal-server version fits this module.

    Only the (major, minor, micro) release components are compared, so that
    e.g. a pre-release such as `2.4.0a0` matches `expected_version="2.4.0"`.

    Raises:
        RuntimeError: If any of major, minor or micro differ.
    """
    expected = parse(expected_version)
    installed = parse(fractal_server.__VERSION__)
    expected_release = (expected.major, expected.minor, expected.micro)
    installed_release = (installed.major, installed.minor, installed.micro)
    if installed_release != expected_release:
        raise RuntimeError(
            f"{fractal_server.__VERSION__=} not matching with {__file__=}"
        )
|
26
|
+
|
27
|
+
|
28
|
+
def fix_db():
    """
    Data migration for fractal-server 2.4.0: add all users to the default group.

    Every existing user gets linked to the group named
    `FRACTAL_DEFAULT_GROUP_NAME`. Users that are already linked to it (e.g.
    users registered through `UserManager.on_after_register`, or a previous
    run of this function) are skipped, so that no duplicate
    `LinkUserGroup` rows are inserted and the migration can be re-run.

    Raises:
        RuntimeError: If the installed version does not match 2.4.0, or if
            the default group does not exist.
    """
    logger = logging.getLogger("fix_db")
    logger.warning("START execution of fix_db function")
    _check_current_version(expected_version="2.4.0")

    with next(get_sync_db()) as db:
        # Find default group
        stm = select(UserGroup).where(
            UserGroup.name == FRACTAL_DEFAULT_GROUP_NAME
        )
        res = db.execute(stm)
        default_group = res.scalar_one_or_none()
        if default_group is None:
            raise RuntimeError("Default group not found, exit.")
        logger.warning(
            "Default user group exists: "
            f"{default_group.id=}, {default_group.name=}."
        )

        # Find users which are already members of the default group
        stm = select(LinkUserGroup.user_id).where(
            LinkUserGroup.group_id == default_group.id
        )
        already_linked_user_ids = set(db.execute(stm).scalars().all())

        # Find all users, and add the missing ones to the default group
        stm = select(UserOAuth)
        users = db.execute(stm).scalars().unique().all()
        for user in sorted(users, key=lambda x: x.id):
            if user.id in already_linked_user_ids:
                logger.warning(
                    f"SKIP {user.id=} ({user.email=}), "
                    "already in default group."
                )
                continue
            logger.warning(
                f"START adding {user.id=} ({user.email=}) to default group."
            )
            link = LinkUserGroup(user_id=user.id, group_id=default_group.id)
            db.add(link)
            db.commit()
            logger.warning(
                f"END adding {user.id=} ({user.email=}) to default group."
            )

    logger.warning("END of execution of fix_db function")
|