fractal-server 2.3.11__py3-none-any.whl → 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fractal_server/__init__.py +1 -1
  2. fractal_server/__main__.py +25 -2
  3. fractal_server/app/models/__init__.py +11 -5
  4. fractal_server/app/models/linkusergroup.py +11 -0
  5. fractal_server/app/models/security.py +24 -3
  6. fractal_server/app/models/v1/project.py +1 -1
  7. fractal_server/app/models/v2/project.py +3 -3
  8. fractal_server/app/routes/admin/v1.py +14 -14
  9. fractal_server/app/routes/admin/v2.py +12 -12
  10. fractal_server/app/routes/api/__init__.py +2 -2
  11. fractal_server/app/routes/api/v1/_aux_functions.py +2 -2
  12. fractal_server/app/routes/api/v1/dataset.py +17 -15
  13. fractal_server/app/routes/api/v1/job.py +11 -9
  14. fractal_server/app/routes/api/v1/project.py +9 -9
  15. fractal_server/app/routes/api/v1/task.py +8 -8
  16. fractal_server/app/routes/api/v1/task_collection.py +5 -5
  17. fractal_server/app/routes/api/v1/workflow.py +13 -11
  18. fractal_server/app/routes/api/v1/workflowtask.py +6 -6
  19. fractal_server/app/routes/api/v2/_aux_functions.py +2 -2
  20. fractal_server/app/routes/api/v2/dataset.py +11 -11
  21. fractal_server/app/routes/api/v2/images.py +6 -6
  22. fractal_server/app/routes/api/v2/job.py +9 -9
  23. fractal_server/app/routes/api/v2/project.py +7 -7
  24. fractal_server/app/routes/api/v2/status.py +3 -3
  25. fractal_server/app/routes/api/v2/submit.py +3 -3
  26. fractal_server/app/routes/api/v2/task.py +8 -8
  27. fractal_server/app/routes/api/v2/task_collection.py +5 -5
  28. fractal_server/app/routes/api/v2/task_collection_custom.py +3 -3
  29. fractal_server/app/routes/api/v2/task_legacy.py +9 -9
  30. fractal_server/app/routes/api/v2/workflow.py +11 -11
  31. fractal_server/app/routes/api/v2/workflowtask.py +6 -6
  32. fractal_server/app/routes/auth/__init__.py +55 -0
  33. fractal_server/app/routes/auth/_aux_auth.py +107 -0
  34. fractal_server/app/routes/auth/current_user.py +60 -0
  35. fractal_server/app/routes/auth/group.py +176 -0
  36. fractal_server/app/routes/auth/group_names.py +34 -0
  37. fractal_server/app/routes/auth/login.py +25 -0
  38. fractal_server/app/routes/auth/oauth.py +63 -0
  39. fractal_server/app/routes/auth/register.py +23 -0
  40. fractal_server/app/routes/auth/router.py +19 -0
  41. fractal_server/app/routes/auth/users.py +192 -0
  42. fractal_server/app/schemas/user.py +11 -0
  43. fractal_server/app/schemas/user_group.py +65 -0
  44. fractal_server/app/security/__init__.py +72 -75
  45. fractal_server/data_migrations/2_4_0.py +61 -0
  46. fractal_server/main.py +1 -9
  47. fractal_server/migrations/versions/091b01f51f88_add_usergroup_and_linkusergroup_table.py +53 -0
  48. {fractal_server-2.3.11.dist-info → fractal_server-2.4.0.dist-info}/METADATA +1 -1
  49. {fractal_server-2.3.11.dist-info → fractal_server-2.4.0.dist-info}/RECORD +52 -39
  50. fractal_server/app/routes/auth.py +0 -165
  51. {fractal_server-2.3.11.dist-info → fractal_server-2.4.0.dist-info}/LICENSE +0 -0
  52. {fractal_server-2.3.11.dist-info → fractal_server-2.4.0.dist-info}/WHEEL +0 -0
  53. {fractal_server-2.3.11.dist-info → fractal_server-2.4.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,34 @@
1
+ """
2
+ Definition of `/auth/group-names/` endpoints
3
+ """
4
+ from fastapi import APIRouter
5
+ from fastapi import Depends
6
+ from fastapi import status
7
+ from sqlalchemy.ext.asyncio import AsyncSession
8
+ from sqlmodel import select
9
+
10
+ from . import current_active_user
11
+ from ...db import get_async_db
12
+ from fractal_server.app.models import UserGroup
13
+ from fractal_server.app.models import UserOAuth
14
+
15
+ router_group_names = APIRouter()
16
+
17
+
18
+ @router_group_names.get(
19
+ "/group-names/", response_model=list[str], status_code=status.HTTP_200_OK
20
+ )
21
+ async def get_list_user_group_names(
22
+ user: UserOAuth = Depends(current_active_user),
23
+ db: AsyncSession = Depends(get_async_db),
24
+ ) -> list[str]:
25
+ """
26
+ Return the available group names.
27
+
28
+ This endpoint is not restricted to superusers.
29
+ """
30
+ stm_all_groups = select(UserGroup)
31
+ res = await db.execute(stm_all_groups)
32
+ groups = res.scalars().all()
33
+ group_names = [group.name for group in groups]
34
+ return group_names
@@ -0,0 +1,25 @@
1
+ """
2
+ Definition of `/auth/{login,logout}/`, `/auth/token/{login,logout}/` routes.
3
+ """
4
+ from fastapi import APIRouter
5
+
6
+ from . import cookie_backend
7
+ from . import fastapi_users
8
+ from . import token_backend
9
+
10
+ router_login = APIRouter()
11
+
12
+
13
+ router_login.include_router(
14
+ fastapi_users.get_auth_router(token_backend),
15
+ prefix="/token",
16
+ )
17
+ router_login.include_router(
18
+ fastapi_users.get_auth_router(cookie_backend),
19
+ )
20
+
21
+
22
+ # Add trailing slash to all routes' paths
23
+ for route in router_login.routes:
24
+ if not route.path.endswith("/"):
25
+ route.path = f"{route.path}/"
@@ -0,0 +1,63 @@
1
+ from fastapi import APIRouter
2
+
3
+ from . import cookie_backend
4
+ from . import fastapi_users
5
+ from ....config import get_settings
6
+ from ....syringe import Inject
7
+
8
+ router_oauth = APIRouter()
9
+
10
+
11
+ # OAUTH CLIENTS
12
+
13
+ # NOTE: settings.OAUTH_CLIENTS_CONFIG is collected by
14
+ # Settings.collect_oauth_clients(). If no specific client is specified in the
15
+ # environment variables (e.g. by setting OAUTH_FOO_CLIENT_ID and
16
+ # OAUTH_FOO_CLIENT_SECRET), this list is empty
17
+
18
+ # FIXME: Dependency injection should be wrapped within a function call to make
19
+ # it truly lazy. This function could then be called on startup of the FastAPI
20
+ # app (cf. fractal_server.main)
21
+ settings = Inject(get_settings)
22
+
23
+ for client_config in settings.OAUTH_CLIENTS_CONFIG:
24
+ client_name = client_config.CLIENT_NAME.lower()
25
+
26
+ if client_name == "google":
27
+ from httpx_oauth.clients.google import GoogleOAuth2
28
+
29
+ client = GoogleOAuth2(
30
+ client_config.CLIENT_ID, client_config.CLIENT_SECRET
31
+ )
32
+ elif client_name == "github":
33
+ from httpx_oauth.clients.github import GitHubOAuth2
34
+
35
+ client = GitHubOAuth2(
36
+ client_config.CLIENT_ID, client_config.CLIENT_SECRET
37
+ )
38
+ else:
39
+ from httpx_oauth.clients.openid import OpenID
40
+
41
+ client = OpenID(
42
+ client_config.CLIENT_ID,
43
+ client_config.CLIENT_SECRET,
44
+ client_config.OIDC_CONFIGURATION_ENDPOINT,
45
+ )
46
+
47
+ router_oauth.include_router(
48
+ fastapi_users.get_oauth_router(
49
+ client,
50
+ cookie_backend,
51
+ settings.JWT_SECRET_KEY,
52
+ is_verified_by_default=False,
53
+ associate_by_email=True,
54
+ redirect_url=client_config.REDIRECT_URL,
55
+ ),
56
+ prefix=f"/{client_name}",
57
+ )
58
+
59
+
60
+ # Add trailing slash to all routes' paths
61
+ for route in router_oauth.routes:
62
+ if not route.path.endswith("/"):
63
+ route.path = f"{route.path}/"
@@ -0,0 +1,23 @@
1
+ """
2
+ Definition of `/auth/register/` routes.
3
+ """
4
+ from fastapi import APIRouter
5
+ from fastapi import Depends
6
+
7
+ from . import current_active_superuser
8
+ from . import fastapi_users
9
+ from ...schemas.user import UserCreate
10
+ from ...schemas.user import UserRead
11
+
12
+ router_register = APIRouter()
13
+
14
+ router_register.include_router(
15
+ fastapi_users.get_register_router(UserRead, UserCreate),
16
+ dependencies=[Depends(current_active_superuser)],
17
+ )
18
+
19
+
20
+ # Add trailing slash to all routes' paths
21
+ for route in router_register.routes:
22
+ if not route.path.endswith("/"):
23
+ route.path = f"{route.path}/"
@@ -0,0 +1,19 @@
1
+ from fastapi import APIRouter
2
+
3
+ from .current_user import router_current_user
4
+ from .group import router_group
5
+ from .group_names import router_group_names
6
+ from .login import router_login
7
+ from .oauth import router_oauth
8
+ from .register import router_register
9
+ from .users import router_users
10
+
11
+ router_auth = APIRouter()
12
+
13
+ router_auth.include_router(router_register)
14
+ router_auth.include_router(router_current_user)
15
+ router_auth.include_router(router_login)
16
+ router_auth.include_router(router_group_names)
17
+ router_auth.include_router(router_users)
18
+ router_auth.include_router(router_group)
19
+ router_auth.include_router(router_oauth)
@@ -0,0 +1,192 @@
1
+ """
2
+ Definition of `/auth/users/` routes
3
+ """
4
+ from fastapi import APIRouter
5
+ from fastapi import Depends
6
+ from fastapi import HTTPException
7
+ from fastapi import status
8
+ from fastapi_users import exceptions
9
+ from fastapi_users import schemas
10
+ from fastapi_users.router.common import ErrorCode
11
+ from sqlalchemy.exc import IntegrityError
12
+ from sqlalchemy.ext.asyncio import AsyncSession
13
+ from sqlmodel import col
14
+ from sqlmodel import func
15
+ from sqlmodel import select
16
+
17
+ from . import current_active_superuser
18
+ from ...db import get_async_db
19
+ from ...schemas.user import UserRead
20
+ from ...schemas.user import UserUpdate
21
+ from ...schemas.user import UserUpdateWithNewGroupIds
22
+ from ._aux_auth import _get_single_user_with_group_ids
23
+ from fractal_server.app.models import LinkUserGroup
24
+ from fractal_server.app.models import UserGroup
25
+ from fractal_server.app.models import UserOAuth
26
+ from fractal_server.app.routes.auth._aux_auth import _user_or_404
27
+ from fractal_server.app.security import get_user_manager
28
+ from fractal_server.app.security import UserManager
29
+ from fractal_server.logger import set_logger
30
+
31
+ router_users = APIRouter()
32
+
33
+
34
+ logger = set_logger(__name__)
35
+
36
+
37
+ @router_users.get("/users/{user_id}/", response_model=UserRead)
38
+ async def get_user(
39
+ user_id: int,
40
+ group_ids: bool = True,
41
+ superuser: UserOAuth = Depends(current_active_superuser),
42
+ db: AsyncSession = Depends(get_async_db),
43
+ ) -> UserRead:
44
+ user = await _user_or_404(user_id, db)
45
+ if group_ids:
46
+ user_with_group_ids = await _get_single_user_with_group_ids(user, db)
47
+ return user_with_group_ids
48
+ else:
49
+ return user
50
+
51
+
52
+ @router_users.patch("/users/{user_id}/", response_model=UserRead)
53
+ async def patch_user(
54
+ user_id: int,
55
+ user_update: UserUpdateWithNewGroupIds,
56
+ current_superuser: UserOAuth = Depends(current_active_superuser),
57
+ user_manager: UserManager = Depends(get_user_manager),
58
+ db: AsyncSession = Depends(get_async_db),
59
+ ):
60
+ """
61
+ Custom version of the PATCH-user route from `fastapi-users`.
62
+
63
+ In order to keep the fastapi-users logic in place (which is convenient to
64
+ update user attributes), we split the endpoint into two branches. We either
65
+ go through the fastapi-users-based attribute-update branch, or through the
66
+ branch where we establish new user/group relationships.
67
+
68
+ Note that we prevent making both changes at the same time, since it would
69
+ be more complex to guarantee that endpoint error would leave the database
70
+ in the same state as before the API call.
71
+ """
72
+
73
+ # We prevent simultaneous editing of both user attributes and user/group
74
+ # associations
75
+ user_update_dict_without_groups = user_update.dict(
76
+ exclude_unset=True, exclude={"new_group_ids"}
77
+ )
78
+ edit_attributes = user_update_dict_without_groups != {}
79
+ edit_groups = user_update.new_group_ids is not None
80
+ if edit_attributes and edit_groups:
81
+ raise HTTPException(
82
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
83
+ detail=(
84
+ "Cannot modify both user attributes and group membership. "
85
+ "Please make two independent PATCH calls"
86
+ ),
87
+ )
88
+
89
+ # Check that user exists
90
+ user_to_patch = await _user_or_404(user_id, db)
91
+
92
+ if edit_groups:
93
+ # Establish new user/group relationships
94
+
95
+ # Check that all required groups exist
96
+ # Note: The reason for introducing `col` is as in
97
+ # https://sqlmodel.tiangolo.com/tutorial/where/#type-annotations-and-errors
98
+ stm = select(func.count()).where(
99
+ col(UserGroup.id).in_(user_update.new_group_ids)
100
+ )
101
+ res = await db.execute(stm)
102
+ number_matching_groups = res.scalar()
103
+ if number_matching_groups != len(user_update.new_group_ids):
104
+ raise HTTPException(
105
+ status_code=status.HTTP_404_NOT_FOUND,
106
+ detail=(
107
+ "Not all requested groups (IDs: "
108
+ f"{user_update.new_group_ids}) exist."
109
+ ),
110
+ )
111
+
112
+ for new_group_id in user_update.new_group_ids:
113
+ link = LinkUserGroup(user_id=user_id, group_id=new_group_id)
114
+ db.add(link)
115
+
116
+ try:
117
+ await db.commit()
118
+ except IntegrityError as e:
119
+ error_msg = (
120
+ f"Cannot link groups with IDs {user_update.new_group_ids} "
121
+ f"to user {user_id}. "
122
+ "Likely reason: one of these links already exists.\n"
123
+ f"Original error: {str(e)}"
124
+ )
125
+ logger.info(error_msg)
126
+ raise HTTPException(
127
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
128
+ detail=error_msg,
129
+ )
130
+
131
+ patched_user = user_to_patch
132
+
133
+ elif edit_attributes:
134
+ # Modify user attributes
135
+ try:
136
+ user_update_without_groups = UserUpdate(
137
+ **user_update_dict_without_groups
138
+ )
139
+ user = await user_manager.update(
140
+ user_update_without_groups,
141
+ user_to_patch,
142
+ safe=False,
143
+ request=None,
144
+ )
145
+ patched_user = schemas.model_validate(UserOAuth, user)
146
+ except exceptions.InvalidPasswordException as e:
147
+ raise HTTPException(
148
+ status_code=status.HTTP_400_BAD_REQUEST,
149
+ detail={
150
+ "code": ErrorCode.UPDATE_USER_INVALID_PASSWORD,
151
+ "reason": e.reason,
152
+ },
153
+ )
154
+ else:
155
+ # Nothing to do, just continue
156
+ patched_user = user_to_patch
157
+
158
+ # Enrich user object with `group_ids` attribute
159
+ patched_user_with_group_ids = await _get_single_user_with_group_ids(
160
+ patched_user, db
161
+ )
162
+
163
+ return patched_user_with_group_ids
164
+
165
+
166
+ @router_users.get("/users/", response_model=list[UserRead])
167
+ async def list_users(
168
+ user: UserOAuth = Depends(current_active_superuser),
169
+ db: AsyncSession = Depends(get_async_db),
170
+ ):
171
+ """
172
+ Return list of all users
173
+ """
174
+ stm = select(UserOAuth)
175
+ res = await db.execute(stm)
176
+ user_list = res.scalars().unique().all()
177
+
178
+ # Get all user/group links
179
+ stm_all_links = select(LinkUserGroup)
180
+ res = await db.execute(stm_all_links)
181
+ links = res.scalars().all()
182
+
183
+ # TODO: possible optimizations for this construction are listed in
184
+ # https://github.com/fractal-analytics-platform/fractal-server/issues/1742
185
+ for ind, user in enumerate(user_list):
186
+ user_list[ind] = dict(
187
+ user.model_dump(),
188
+ group_ids=[
189
+ link.group_id for link in links if link.user_id == user.id
190
+ ],
191
+ )
192
+ return user_list
@@ -16,6 +16,7 @@ __all__ = (
16
16
  "UserRead",
17
17
  "UserUpdate",
18
18
  "UserCreate",
19
+ "UserUpdateWithNewGroupIds",
19
20
  )
20
21
 
21
22
 
@@ -34,6 +35,8 @@ class UserRead(schemas.BaseUser[int]):
34
35
  cache_dir: Optional[str]
35
36
  username: Optional[str]
36
37
  slurm_accounts: list[str]
38
+ group_names: Optional[list[str]] = None
39
+ group_ids: Optional[list[int]] = None
37
40
 
38
41
 
39
42
  class UserUpdate(schemas.BaseUserUpdate):
@@ -100,6 +103,14 @@ class UserUpdateStrict(BaseModel, extra=Extra.forbid):
100
103
  )
101
104
 
102
105
 
106
+ class UserUpdateWithNewGroupIds(UserUpdate):
107
+ new_group_ids: Optional[list[int]] = None
108
+
109
+ _val_unique = validator("new_group_ids", allow_reuse=True)(
110
+ val_unique_list("new_group_ids")
111
+ )
112
+
113
+
103
114
  class UserCreate(schemas.BaseUserCreate):
104
115
  """
105
116
  Schema for `User` creation.
@@ -0,0 +1,65 @@
1
+ from datetime import datetime
2
+ from typing import Optional
3
+
4
+ from pydantic import BaseModel
5
+ from pydantic import Extra
6
+ from pydantic import Field
7
+ from pydantic import validator
8
+
9
+ from ._validators import val_unique_list
10
+
11
+
12
+ __all__ = (
13
+ "UserGroupRead",
14
+ "UserGroupUpdate",
15
+ "UserGroupCreate",
16
+ )
17
+
18
+
19
+ class UserGroupRead(BaseModel):
20
+ """
21
+ Schema for `UserGroup` read
22
+
23
+ NOTE: `user_ids` does not correspond to a column of the `UserGroup` table,
24
+ but it is rather computed dynamically in relevant endpoints.
25
+
26
+ Attributes:
27
+ id: Group ID
28
+ name: Group name
29
+ timestamp_created: Creation timestamp
30
+ user_ids: IDs of users of this group
31
+ """
32
+
33
+ id: int
34
+ name: str
35
+ timestamp_created: datetime
36
+ user_ids: Optional[list[int]] = None
37
+
38
+
39
+ class UserGroupCreate(BaseModel, extra=Extra.forbid):
40
+ """
41
+ Schema for `UserGroup` creation
42
+
43
+ Attributes:
44
+ name: Group name
45
+ """
46
+
47
+ name: str
48
+
49
+
50
+ class UserGroupUpdate(BaseModel, extra=Extra.forbid):
51
+ """
52
+ Schema for `UserGroup` update
53
+
54
+ NOTE: `new_user_ids` does not correspond to a column of the `UserGroup`
55
+ table, but it is rather used to create new `LinkUserGroup` rows.
56
+
57
+ Attributes:
58
+ new_user_ids: IDs of users to be associated to group.
59
+ """
60
+
61
+ new_user_ids: list[int] = Field(default_factory=list)
62
+
63
+ _val_unique = validator("new_user_ids", allow_reuse=True)(
64
+ val_unique_list("new_user_ids")
65
+ )
@@ -29,41 +29,38 @@ All routes are registered under the `auth/` prefix.
29
29
  import contextlib
30
30
  from typing import Any
31
31
  from typing import AsyncGenerator
32
- from typing import Dict
33
32
  from typing import Generic
34
33
  from typing import Optional
35
34
  from typing import Type
36
35
 
37
36
  from fastapi import Depends
37
+ from fastapi import Request
38
38
  from fastapi_users import BaseUserManager
39
- from fastapi_users import FastAPIUsers
40
39
  from fastapi_users import IntegerIDMixin
41
- from fastapi_users.authentication import AuthenticationBackend
42
- from fastapi_users.authentication import BearerTransport
43
- from fastapi_users.authentication import CookieTransport
44
- from fastapi_users.authentication import JWTStrategy
45
40
  from fastapi_users.db.base import BaseUserDatabase
46
41
  from fastapi_users.exceptions import InvalidPasswordException
47
42
  from fastapi_users.exceptions import UserAlreadyExists
48
43
  from fastapi_users.models import ID
49
44
  from fastapi_users.models import OAP
50
45
  from fastapi_users.models import UP
51
- from sqlalchemy.exc import IntegrityError
52
46
  from sqlalchemy.ext.asyncio import AsyncSession
53
47
  from sqlalchemy.orm import selectinload
54
48
  from sqlmodel import func
55
49
  from sqlmodel import select
56
50
 
57
- from ...config import get_settings
58
- from ...syringe import Inject
59
51
  from ..db import get_async_db
60
- from fractal_server.app.models.security import OAuthAccount
61
- from fractal_server.app.models.security import UserOAuth as User
52
+ from fractal_server.app.db import get_sync_db
53
+ from fractal_server.app.models import LinkUserGroup
54
+ from fractal_server.app.models import OAuthAccount
55
+ from fractal_server.app.models import UserGroup
56
+ from fractal_server.app.models import UserOAuth
62
57
  from fractal_server.app.schemas.user import UserCreate
63
58
  from fractal_server.logger import get_logger
64
59
 
65
60
  logger = get_logger(__name__)
66
61
 
62
+ FRACTAL_DEFAULT_GROUP_NAME = "All"
63
+
67
64
 
68
65
  class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
69
66
  """
@@ -125,7 +122,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
125
122
  return user
126
123
  return None
127
124
 
128
- async def create(self, create_dict: Dict[str, Any]) -> UP:
125
+ async def create(self, create_dict: dict[str, Any]) -> UP:
129
126
  """Create a user."""
130
127
  user = self.user_model(**create_dict)
131
128
  self.session.add(user)
@@ -133,7 +130,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
133
130
  await self.session.refresh(user)
134
131
  return user
135
132
 
136
- async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP:
133
+ async def update(self, user: UP, update_dict: dict[str, Any]) -> UP:
137
134
  for key, value in update_dict.items():
138
135
  setattr(user, key, value)
139
136
  self.session.add(user)
@@ -146,7 +143,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
146
143
  await self.session.commit()
147
144
 
148
145
  async def add_oauth_account(
149
- self, user: UP, create_dict: Dict[str, Any]
146
+ self, user: UP, create_dict: dict[str, Any]
150
147
  ) -> UP: # noqa
151
148
  if self.oauth_account_model is None:
152
149
  raise NotImplementedError()
@@ -160,7 +157,7 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
160
157
  return user
161
158
 
162
159
  async def update_oauth_account(
163
- self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any]
160
+ self, user: UP, oauth_account: OAP, update_dict: dict[str, Any]
164
161
  ) -> UP:
165
162
  if self.oauth_account_model is None:
166
163
  raise NotImplementedError()
@@ -176,13 +173,14 @@ class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]):
176
173
  async def get_user_db(
177
174
  session: AsyncSession = Depends(get_async_db),
178
175
  ) -> AsyncGenerator[SQLModelUserDatabaseAsync, None]:
179
- yield SQLModelUserDatabaseAsync(session, User, OAuthAccount)
176
+ yield SQLModelUserDatabaseAsync(session, UserOAuth, OAuthAccount)
180
177
 
181
178
 
182
- class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
183
- async def validate_password(self, password: str, user: User) -> None:
179
+ class UserManager(IntegerIDMixin, BaseUserManager[UserOAuth, int]):
180
+ async def validate_password(self, password: str, user: UserOAuth) -> None:
184
181
  # check password length
185
- min_length, max_length = 4, 100
182
+ min_length = 4
183
+ max_length = 100
186
184
  if len(password) < min_length:
187
185
  raise InvalidPasswordException(
188
186
  f"The password is too short (minimum length: {min_length})."
@@ -192,6 +190,38 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
192
190
  f"The password is too long (maximum length: {min_length})."
193
191
  )
194
192
 
193
+ async def on_after_register(
194
+ self, user: UserOAuth, request: Optional[Request] = None
195
+ ):
196
+ logger.info(
197
+ f"New-user registration completed ({user.id=}, {user.email=})."
198
+ )
199
+ async for db in get_async_db():
200
+ # Find default group
201
+ stm = select(UserGroup).where(
202
+ UserGroup.name == FRACTAL_DEFAULT_GROUP_NAME
203
+ )
204
+ res = await db.execute(stm)
205
+ default_group = res.scalar_one_or_none()
206
+ if default_group is None:
207
+ logger.error(
208
+ f"No group found with name {FRACTAL_DEFAULT_GROUP_NAME}"
209
+ )
210
+ else:
211
+ logger.warning(
212
+ f"START adding {user.email} user to group "
213
+ f"{default_group.id=}."
214
+ )
215
+ link = LinkUserGroup(
216
+ user_id=user.id, group_id=default_group.id
217
+ )
218
+ db.add(link)
219
+ await db.commit()
220
+ logger.warning(
221
+ f"END adding {user.email} user to group "
222
+ f"{default_group.id=}."
223
+ )
224
+
195
225
 
196
226
  async def get_user_manager(
197
227
  user_db: SQLModelUserDatabaseAsync = Depends(get_user_db),
@@ -199,53 +229,6 @@ async def get_user_manager(
199
229
  yield UserManager(user_db)
200
230
 
201
231
 
202
- bearer_transport = BearerTransport(tokenUrl="/auth/token/login")
203
- cookie_transport = CookieTransport(cookie_samesite="none")
204
-
205
-
206
- def get_jwt_strategy() -> JWTStrategy:
207
- settings = Inject(get_settings)
208
- return JWTStrategy(
209
- secret=settings.JWT_SECRET_KEY, # type: ignore
210
- lifetime_seconds=settings.JWT_EXPIRE_SECONDS,
211
- )
212
-
213
-
214
- def get_jwt_cookie_strategy() -> JWTStrategy:
215
- settings = Inject(get_settings)
216
- return JWTStrategy(
217
- secret=settings.JWT_SECRET_KEY, # type: ignore
218
- lifetime_seconds=settings.COOKIE_EXPIRE_SECONDS,
219
- )
220
-
221
-
222
- token_backend = AuthenticationBackend(
223
- name="bearer-jwt",
224
- transport=bearer_transport,
225
- get_strategy=get_jwt_strategy,
226
- )
227
- cookie_backend = AuthenticationBackend(
228
- name="cookie-jwt",
229
- transport=cookie_transport,
230
- get_strategy=get_jwt_cookie_strategy,
231
- )
232
-
233
-
234
- fastapi_users = FastAPIUsers[User, int](
235
- get_user_manager,
236
- [token_backend, cookie_backend],
237
- )
238
-
239
-
240
- # Create dependencies for users
241
- current_active_user = fastapi_users.current_user(active=True)
242
- current_active_verified_user = fastapi_users.current_user(
243
- active=True, verified=True
244
- )
245
- current_active_superuser = fastapi_users.current_user(
246
- active=True, superuser=True
247
- )
248
-
249
232
  get_async_session_context = contextlib.asynccontextmanager(get_async_db)
250
233
  get_user_db_context = contextlib.asynccontextmanager(get_user_db)
251
234
  get_user_manager_context = contextlib.asynccontextmanager(get_user_manager)
@@ -286,9 +269,9 @@ async def _create_first_user(
286
269
 
287
270
  if is_superuser is True:
288
271
  # If a superuser already exists, exit
289
- stm = select(User).where(
290
- User.is_superuser == True # noqa E712
291
- )
272
+ stm = select(UserOAuth).where( # noqa
273
+ UserOAuth.is_superuser == True # noqa
274
+ ) # noqa
292
275
  res = await session.execute(stm)
293
276
  existing_superuser = res.scalars().first()
294
277
  if existing_superuser is not None:
@@ -311,11 +294,25 @@ async def _create_first_user(
311
294
  user = await user_manager.create(UserCreate(**kwargs))
312
295
  logger.info(f"User {user.email} created")
313
296
 
314
- except IntegrityError:
315
- logger.warning(
316
- f"Creation of user {email} failed with IntegrityError "
317
- "(likely due to concurrent attempts from different workers)."
318
- )
319
-
320
297
  except UserAlreadyExists:
321
298
  logger.warning(f"User {email} already exists")
299
+
300
+
301
+ def _create_first_group():
302
+ logger.info(
303
+ f"START _create_first_group, with name {FRACTAL_DEFAULT_GROUP_NAME}"
304
+ )
305
+ with next(get_sync_db()) as db:
306
+ group_all = db.execute(select(UserGroup))
307
+ if group_all.scalars().one_or_none() is None:
308
+ first_group = UserGroup(name=FRACTAL_DEFAULT_GROUP_NAME)
309
+ db.add(first_group)
310
+ db.commit()
311
+ logger.info(f"Created group {FRACTAL_DEFAULT_GROUP_NAME}")
312
+ else:
313
+ logger.info(
314
+ f"Group {FRACTAL_DEFAULT_GROUP_NAME} already exists, skip."
315
+ )
316
+ logger.info(
317
+ f"END _create_first_group, with name {FRACTAL_DEFAULT_GROUP_NAME}"
318
+ )