zrb 1.0.0b8__py3-none-any.whl → 1.0.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. zrb/__main__.py +3 -0
  2. zrb/builtin/project/add/fastapp/fastapp_task.py +1 -0
  3. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/.coveragerc +11 -0
  4. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/.gitignore +4 -0
  5. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/column/add_column_task.py +4 -4
  6. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/config.py +5 -0
  7. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_task.py +108 -1
  8. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +67 -4
  9. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/my_entity_service.py +5 -5
  10. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/schema/my_entity.py +1 -0
  11. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/test/my_module/my_entity/test_create_my_entity.py +53 -0
  12. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/test/my_module/my_entity/test_delete_my_entity.py +62 -0
  13. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/test/my_module/my_entity/test_read_my_entity.py +65 -0
  14. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/test/my_module/my_entity/test_update_my_entity.py +61 -0
  15. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/gateway_subroute.py +57 -13
  16. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/input.py +8 -0
  17. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_util.py +2 -2
  18. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/gateway/subroute/my_module.py +6 -1
  19. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/module_task_definition.py +10 -6
  20. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task.py +65 -14
  21. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task_util.py +106 -0
  22. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/util.py +6 -86
  23. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_db_repository.py +27 -11
  24. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_service.py +140 -51
  25. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/error.py +15 -0
  26. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/util/parser.py +1 -1
  27. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/config.py +22 -4
  28. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client.py +21 -0
  29. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_auth_tables.py +106 -61
  30. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/8ed025bcc845_create_permissions.py +69 -0
  31. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration_metadata.py +3 -4
  32. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/route.py +15 -14
  33. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/permission_service.py +4 -4
  34. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_db_repository.py +24 -5
  35. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service.py +14 -12
  36. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_db_repository.py +134 -97
  37. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository.py +28 -11
  38. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service.py +215 -13
  39. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service_factory.py +30 -2
  40. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/subroute/auth.py +216 -41
  41. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/util/auth.py +57 -0
  42. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/requirements.txt +7 -1
  43. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/permission.py +2 -0
  44. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/role.py +13 -12
  45. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/user.py +64 -12
  46. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/_util/access_token.py +19 -0
  47. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/auth/permission/test_create_permission.py +59 -0
  48. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/auth/permission/test_delete_permission.py +68 -0
  49. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/auth/permission/test_read_permission.py +71 -0
  50. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/auth/permission/test_update_permission.py +66 -0
  51. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/auth/test_user_session.py +195 -0
  52. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/test_health_and_readiness.py +28 -0
  53. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/test_homepage.py +17 -0
  54. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test/test_not_found_error.py +16 -0
  55. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/test.sh +7 -0
  56. zrb/task/base_task.py +10 -10
  57. zrb/task/cmd_task.py +2 -5
  58. zrb/util/cmd/command.py +39 -48
  59. zrb/util/codemod/modification_mode.py +3 -0
  60. zrb/util/codemod/modify_class.py +58 -0
  61. zrb/util/codemod/modify_class_parent.py +68 -0
  62. zrb/util/codemod/modify_class_property.py +128 -0
  63. zrb/util/codemod/modify_dict.py +75 -0
  64. zrb/util/codemod/modify_function.py +65 -0
  65. zrb/util/codemod/modify_function_call.py +68 -0
  66. zrb/util/codemod/modify_method.py +88 -0
  67. zrb/util/codemod/{prepend_code_to_module.py → modify_module.py} +2 -3
  68. zrb/util/file.py +3 -2
  69. {zrb-1.0.0b8.dist-info → zrb-1.0.0b10.dist-info}/METADATA +2 -1
  70. {zrb-1.0.0b8.dist-info → zrb-1.0.0b10.dist-info}/RECORD +72 -55
  71. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/migrate.py +0 -3
  72. zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/session.py +0 -48
  73. zrb/util/codemod/append_code_to_class.py +0 -35
  74. zrb/util/codemod/append_code_to_function.py +0 -38
  75. zrb/util/codemod/append_code_to_method.py +0 -55
  76. zrb/util/codemod/append_key_to_dict.py +0 -51
  77. zrb/util/codemod/append_param_to_function_call.py +0 -39
  78. zrb/util/codemod/prepend_parent_to_class.py +0 -38
  79. zrb/util/codemod/prepend_property_to_class.py +0 -55
  80. {zrb-1.0.0b8.dist-info → zrb-1.0.0b10.dist-info}/WHEEL +0 -0
  81. {zrb-1.0.0b8.dist-info → zrb-1.0.0b10.dist-info}/entry_points.txt +0 -0
@@ -1,7 +1,6 @@
1
1
  from my_app_name.schema.permission import Permission
2
2
  from my_app_name.schema.role import Role, RolePermission
3
- from my_app_name.schema.session import Session
4
- from my_app_name.schema.user import User, UserRole
3
+ from my_app_name.schema.user import User, UserRole, UserSession
5
4
  from sqlalchemy import MetaData
6
5
 
7
6
  metadata = MetaData()
@@ -19,5 +18,5 @@ User.__table__.tometadata(metadata)
19
18
  UserRole.metadata = metadata
20
19
  UserRole.__table__.tometadata(metadata)
21
20
 
22
- Session.metadata = metadata
23
- Session.__table__.tometadata(metadata)
21
+ UserSession.metadata = metadata
22
+ UserSession.__table__.tometadata(metadata)
@@ -5,10 +5,23 @@ from my_app_name.config import APP_MAIN_MODULE, APP_MODE, APP_MODULES
5
5
  from my_app_name.module.auth.service.permission.permission_service_factory import (
6
6
  permission_service,
7
7
  )
8
+ from my_app_name.module.auth.service.role.role_service_factory import role_service
8
9
  from my_app_name.module.auth.service.user.user_service_factory import user_service
9
10
 
10
11
 
11
- def serve_health_check(app: FastAPI):
12
+ def serve_route(app: FastAPI):
13
+ if APP_MODE != "microservices" or "auth" not in APP_MODULES:
14
+ return
15
+ if APP_MAIN_MODULE == "auth":
16
+ _serve_health_check(app)
17
+ _serve_readiness_check(app)
18
+
19
+ permission_service.serve_route(app)
20
+ role_service.serve_route(app)
21
+ user_service.serve_route(app)
22
+
23
+
24
+ def _serve_health_check(app: FastAPI):
12
25
  @app.api_route("/health", methods=["GET", "HEAD"], response_model=BasicResponse)
13
26
  async def health():
14
27
  """
@@ -17,7 +30,7 @@ def serve_health_check(app: FastAPI):
17
30
  return BasicResponse(message="ok")
18
31
 
19
32
 
20
- def serve_readiness_check(app: FastAPI):
33
+ def _serve_readiness_check(app: FastAPI):
21
34
  @app.api_route("/readiness", methods=["GET", "HEAD"], response_model=BasicResponse)
22
35
  async def readiness():
23
36
  """
@@ -26,16 +39,4 @@ def serve_readiness_check(app: FastAPI):
26
39
  return BasicResponse(message="ok")
27
40
 
28
41
 
29
- def serve_route(app: FastAPI):
30
- if APP_MODE != "microservices" or "auth" not in APP_MODULES:
31
- return
32
- if APP_MAIN_MODULE == "auth":
33
- serve_health_check(app)
34
- serve_readiness_check(app)
35
-
36
- # Serve user endpoints for APIClient
37
- user_service.serve_route(app)
38
- permission_service.serve_route(app)
39
-
40
-
41
42
  serve_route(app)
@@ -71,11 +71,11 @@ class PermissionService(BaseService):
71
71
  @BaseService.route(
72
72
  "/api/v1/permissions/bulk",
73
73
  methods=["put"],
74
- response_model=PermissionResponse,
74
+ response_model=list[PermissionResponse],
75
75
  )
76
76
  async def update_permission_bulk(
77
77
  self, permission_ids: list[str], data: PermissionUpdateWithAudit
78
- ) -> PermissionResponse:
78
+ ) -> list[PermissionResponse]:
79
79
  await self.permission_repository.update_bulk(permission_ids, data)
80
80
  return await self.permission_repository.get_by_ids(permission_ids)
81
81
 
@@ -93,11 +93,11 @@ class PermissionService(BaseService):
93
93
  @BaseService.route(
94
94
  "/api/v1/permissions/bulk",
95
95
  methods=["delete"],
96
- response_model=PermissionResponse,
96
+ response_model=list[PermissionResponse],
97
97
  )
98
98
  async def delete_permission_bulk(
99
99
  self, permission_ids: list[str], deleted_by: str
100
- ) -> PermissionResponse:
100
+ ) -> list[PermissionResponse]:
101
101
  permissions = await self.permission_repository.get_by_ids(permission_ids)
102
102
  await self.permission_repository.delete_bulk(permission_ids)
103
103
  return permissions
@@ -54,28 +54,47 @@ class RoleDBRepository(
54
54
  and permission.id not in role_permission_map[role.id]
55
55
  ):
56
56
  role_permission_map[role.id].append(permission.id)
57
- role_map[role.id]["permissions"].append(permission.model_dump())
57
+ role_map[role.id]["permissions"].append(permission)
58
58
  return [
59
- RoleResponse(**data["role"].model_dump(), permissions=data["permissions"])
59
+ RoleResponse(
60
+ **data["role"].model_dump(),
61
+ permission_names=[
62
+ permission.name for permission in data["permissions"]
63
+ ],
64
+ )
60
65
  for data in role_map.values()
61
66
  ]
62
67
 
63
68
  async def add_permissions(self, data: dict[str, list[str]], created_by: str):
64
69
  now = datetime.datetime.now(datetime.timezone.utc)
70
+ # get mapping from permission names to permission ids
71
+ all_permission_names = {
72
+ name for permission_names in data.values() for name in permission_names
73
+ }
74
+ async with self._session_scope() as session:
75
+ result = await self._execute_statement(
76
+ session,
77
+ select(Permission.id, Permission.name).where(
78
+ Permission.name.in_(all_permission_names)
79
+ ),
80
+ )
81
+ permission_mapping = {row.name: row.id for row in result}
82
+ # Assemble data dict
65
83
  data_dict_list: list[dict[str, Any]] = []
66
- for role_id, permission_ids in data.items():
67
- for permission_id in permission_ids:
84
+ for role_id, permission_names in data.items():
85
+ for permission_name in permission_names:
68
86
  data_dict_list.append(
69
87
  self._model_to_data_dict(
70
88
  RolePermission(
71
89
  id=ulid.new().str,
72
90
  role_id=role_id,
73
- permission_id=permission_id,
91
+ permission_id=permission_mapping.get(permission_name),
74
92
  created_at=now,
75
93
  created_by=created_by,
76
94
  )
77
95
  )
78
96
  )
97
+ # Insert rolePermissions
79
98
  async with self._session_scope() as session:
80
99
  await self._execute_statement(
81
100
  session, insert(RolePermission).values(data_dict_list)
@@ -50,13 +50,13 @@ class RoleService(BaseService):
50
50
  async def create_role_bulk(
51
51
  self, data: list[RoleCreateWithPermissionsAndAudit]
52
52
  ) -> list[RoleResponse]:
53
- permission_ids = [row.get_permission_ids() for row in data]
53
+ permission_names = [row.get_permission_names() for row in data]
54
54
  data = [row.get_role_create_with_audit() for row in data]
55
55
  roles = await self.role_repository.create_bulk(data)
56
56
  if len(roles) > 0:
57
57
  created_by = roles[0].created_by
58
58
  await self.role_repository.add_permissions(
59
- data={role.id: permission_ids[i] for i, role in enumerate(roles)},
59
+ data={role.id: permission_names[i] for i, role in enumerate(roles)},
60
60
  created_by=created_by,
61
61
  )
62
62
  return await self.role_repository.get_by_ids([role.id for role in roles])
@@ -69,30 +69,32 @@ class RoleService(BaseService):
69
69
  async def create_role(
70
70
  self, data: RoleCreateWithPermissionsAndAudit
71
71
  ) -> RoleResponse:
72
- permission_ids = data.get_permission_ids()
72
+ permission_names = data.get_permission_names()
73
73
  data = data.get_role_create_with_audit()
74
74
  role = await self.role_repository.create(data)
75
75
  await self.role_repository.add_permissions(
76
- data={role.id: permission_ids}, created_by=role.created_by
76
+ data={role.id: permission_names}, created_by=role.created_by
77
77
  )
78
78
  return await self.role_repository.get_by_id(role.id)
79
79
 
80
80
  @BaseService.route(
81
81
  "/api/v1/roles/bulk",
82
82
  methods=["put"],
83
- response_model=RoleResponse,
83
+ response_model=list[RoleResponse],
84
84
  )
85
85
  async def update_role_bulk(
86
86
  self, role_ids: list[str], data: RoleUpdateWithPermissionsAndAudit
87
- ) -> RoleResponse:
88
- permission_ids = [row.get_permission_ids() for row in data]
87
+ ) -> list[RoleResponse]:
88
+ permission_names = [row.get_permission_names() for row in data]
89
89
  data = [row.get_role_update_with_audit() for row in data]
90
90
  await self.role_repository.update_bulk(role_ids, data)
91
91
  if len(role_ids) > 0:
92
92
  updated_by = data[0].updated_by
93
93
  await self.role_repository.remove_all_permissions(role_ids)
94
94
  await self.role_repository.add_permissions(
95
- data={role_id: permission_ids[i] for i, role_id in enumerate(role_ids)},
95
+ data={
96
+ role_id: permission_names[i] for i, role_id in enumerate(role_ids)
97
+ },
96
98
  created_by=updated_by,
97
99
  )
98
100
  return await self.role_repository.get_by_ids(role_ids)
@@ -105,23 +107,23 @@ class RoleService(BaseService):
105
107
  async def update_role(
106
108
  self, role_id: str, data: RoleUpdateWithPermissionsAndAudit
107
109
  ) -> RoleResponse:
108
- permission_ids = data.get_permission_ids()
110
+ permission_names = data.get_permission_names()
109
111
  role_data = data.get_role_update_with_audit()
110
112
  await self.role_repository.update(role_id, role_data)
111
113
  await self.role_repository.remove_all_permissions([role_id])
112
114
  await self.role_repository.add_permissions(
113
- data={role_id: permission_ids}, created_by=role_data.updated_by
115
+ data={role_id: permission_names}, created_by=role_data.updated_by
114
116
  )
115
117
  return await self.role_repository.get_by_id(role_id)
116
118
 
117
119
  @BaseService.route(
118
120
  "/api/v1/roles/bulk",
119
121
  methods=["delete"],
120
- response_model=RoleResponse,
122
+ response_model=list[RoleResponse],
121
123
  )
122
124
  async def delete_role_bulk(
123
125
  self, role_ids: list[str], deleted_by: str
124
- ) -> RoleResponse:
126
+ ) -> list[RoleResponse]:
125
127
  roles = await self.role_repository.get_by_ids(role_ids)
126
128
  await self.role_repository.delete_bulk(role_ids)
127
129
  await self.role_repository.remove_all_permissions(role_ids)
@@ -1,35 +1,27 @@
1
1
  import datetime
2
- from typing import Any, Callable
2
+ from typing import Any
3
3
 
4
4
  import ulid
5
5
  from my_app_name.common.base_db_repository import BaseDBRepository
6
- from my_app_name.common.error import NotFoundError
7
- from my_app_name.config import (
8
- APP_AUTH_GUEST_USER,
9
- APP_AUTH_GUEST_USER_PERMISSIONS,
10
- APP_AUTH_SUPER_USER,
11
- APP_AUTH_SUPER_USER_PASSWORD,
12
- APP_MAX_PARALLEL_SESSION,
13
- APP_SESSION_EXPIRE_MINUTES,
14
- )
6
+ from my_app_name.common.error import NotFoundError, UnauthorizedError
15
7
  from my_app_name.module.auth.service.user.repository.user_repository import (
16
8
  UserRepository,
17
9
  )
18
10
  from my_app_name.schema.permission import Permission
19
11
  from my_app_name.schema.role import Role, RolePermission
20
- from my_app_name.schema.session import Session, SessionResponse
21
12
  from my_app_name.schema.user import (
22
13
  User,
23
14
  UserCreateWithAudit,
24
15
  UserResponse,
25
16
  UserRole,
17
+ UserSession,
18
+ UserSessionResponse,
19
+ UserTokenData,
26
20
  UserUpdateWithAudit,
27
21
  )
28
22
  from passlib.context import CryptContext
29
- from sqlalchemy.engine import Engine
30
- from sqlalchemy.ext.asyncio import AsyncEngine
31
- from sqlalchemy.sql import ClauseElement, ColumnElement, Select
32
- from sqlmodel import SQLModel, delete, insert, select
23
+ from sqlalchemy.sql import Select
24
+ from sqlmodel import delete, insert, select, update
33
25
 
34
26
  # Password hashing context
35
27
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -39,6 +31,11 @@ def hash_password(password: str) -> str:
39
31
  return pwd_context.hash(password)
40
32
 
41
33
 
34
+ def verify_password(plain_password: str, hashed_password: str) -> bool:
35
+ """Verifies if a password matches the stored hash."""
36
+ return pwd_context.verify(plain_password, hashed_password)
37
+
38
+
42
39
  class UserDBRepository(
43
40
  BaseDBRepository[User, UserResponse, UserCreateWithAudit, UserUpdateWithAudit],
44
41
  UserRepository,
@@ -50,86 +47,65 @@ class UserDBRepository(
50
47
  entity_name = "user"
51
48
  column_preprocessors = {"password": hash_password}
52
49
 
53
- def __init__(
54
- self,
55
- engine: Engine | AsyncEngine,
56
- super_user_username: str = APP_AUTH_SUPER_USER,
57
- super_user_password: str = APP_AUTH_SUPER_USER_PASSWORD,
58
- guest_user_username: str = APP_AUTH_GUEST_USER,
59
- guest_user_password: str = APP_AUTH_SUPER_USER_PASSWORD,
60
- guest_user_permission_names: list[str] = APP_AUTH_GUEST_USER_PERMISSIONS,
61
- max_parallel_session: int = APP_MAX_PARALLEL_SESSION,
62
- session_expire_minutes: int = APP_SESSION_EXPIRE_MINUTES,
63
- filter_param_parser: (
64
- Callable[[SQLModel, str], list[ClauseElement]] | None
65
- ) = None,
66
- sort_param_parser: Callable[[SQLModel, str], list[ColumnElement]] | None = None,
67
- ):
68
- super().__init__(
69
- engine=engine,
70
- filter_param_parser=filter_param_parser,
71
- sort_param_parser=sort_param_parser,
72
- )
73
- self._super_user_username = super_user_username
74
- self._super_user_passwored = super_user_password
75
- self._guest_user_username = guest_user_username
76
- self._guest_user_password = guest_user_password
77
- self._guest_user_permission_names = guest_user_permission_names
78
- self._max_parallel_session = max_parallel_session
79
- self._session_expire_minutes = session_expire_minutes
80
- self._super_user: User | None = None
81
- self._guest_user: User | None = None
82
-
83
50
  def _select(self) -> Select:
84
51
  return (
85
- select(User, Role, Permission, Session)
52
+ select(User, Role, Permission)
86
53
  .join(UserRole, UserRole.user_id == User.id, isouter=True)
87
54
  .join(Role, Role.id == UserRole.role_id, isouter=True)
88
55
  .join(RolePermission, RolePermission.role_id == Role.id, isouter=True)
89
56
  .join(
90
57
  Permission, Permission.id == RolePermission.permission_id, isouter=True
91
58
  )
92
- .join(Session, Session.user_id == User.id)
93
59
  )
94
60
 
95
61
  def _rows_to_responses(self, rows: list[tuple[Any, ...]]) -> list[UserResponse]:
96
62
  user_map: dict[str, dict[str, Any]] = {}
97
63
  user_role_map: dict[str, list[str]] = {}
98
64
  user_permission_map: dict[str, list[str]] = {}
99
- for user, role, permission, _ in rows:
65
+ for user, role, permission in rows:
100
66
  if user.id not in user_map:
101
67
  user_map[user.id] = {"user": user, "roles": [], "permissions": []}
102
68
  user_role_map[user.id] = []
103
69
  user_permission_map[user.id] = []
104
70
  if role is not None and role.id not in user_role_map[user.id]:
105
71
  user_role_map[user.id].append(role.id)
106
- user_map[user.id]["roles"].append(role.model_dump())
72
+ user_map[user.id]["roles"].append(role)
107
73
  if (
108
74
  permission is not None
109
75
  and permission.id not in user_permission_map[user.id]
110
76
  ):
111
77
  user_permission_map[user.id].append(permission.id)
112
- user_map[user.id]["permissions"].append(permission.model_dump())
78
+ user_map[user.id]["permissions"].append(permission)
113
79
  return [
114
80
  UserResponse(
115
81
  **data["user"].model_dump(),
116
- roles=list(data["roles"]),
117
- permissions=list(data["permissions"]),
82
+ role_names=[role.name for role in data["roles"]],
83
+ permission_names=[
84
+ permission.name for permission in data["permissions"]
85
+ ],
118
86
  )
119
87
  for data in user_map.values()
120
88
  ]
121
89
 
122
90
  async def add_roles(self, data: dict[str, list[str]], created_by: str):
123
91
  now = datetime.datetime.now(datetime.timezone.utc)
92
+ # get mapping from role names to role ids
93
+ all_role_names = {name for role_names in data.values() for name in role_names}
94
+ async with self._session_scope() as session:
95
+ result = await self._execute_statement(
96
+ session, select(Role.id, Role.name).where(Role.name.in_(all_role_names))
97
+ )
98
+ role_mapping = {row.name: row.id for row in result}
99
+ # Assemble data dict
124
100
  data_dict_list: list[dict[str, Any]] = []
125
- for user_id, role_ids in data.items():
126
- for role_id in role_ids:
101
+ for user_id, role_names in data.items():
102
+ for role_name in role_names:
127
103
  data_dict_list.append(
128
104
  self._model_to_data_dict(
129
105
  UserRole(
130
106
  id=ulid.new().str,
131
107
  user_id=user_id,
132
- role_id=role_id,
108
+ role_id=role_mapping.get(role_name),
133
109
  created_at=now,
134
110
  created_by=created_by,
135
111
  )
@@ -148,65 +124,126 @@ class UserDBRepository(
148
124
  )
149
125
 
150
126
  async def get_by_credentials(self, username: str, password: str) -> UserResponse:
151
- rows = await self._select_to_response(
152
- lambda q: q.where(
153
- User.username == username, User.password == hash_password(password)
127
+ async with self._session_scope() as session:
128
+ result = await self._execute_statement(
129
+ session, select(User).where(User.username == username, User.active)
154
130
  )
155
- )
156
- return self._ensure_one(rows)
131
+ user = result.scalar_one_or_none()
132
+ if user is None or not verify_password(password, user.password):
133
+ raise UnauthorizedError("Invalid username or password")
134
+ return await self.get_by_id(user.id)
157
135
 
158
- async def get_by_token(self, token: str) -> UserResponse:
159
- rows = await self._select_tor_response(
160
- lambda q: q.where(Session.token == token)
161
- )
162
- return self._ensure_one(rows)
163
-
164
- async def add_token(self, user_id: str, token: str):
136
+ async def delete_expired_user_sessions(self, user_id: str):
137
+ now = datetime.datetime.now(datetime.timezone.utc)
165
138
  async with self._session_scope() as session:
166
139
  await self._execute_statement(
167
140
  session,
168
- insert(Session).values(
169
- {
170
- "id": ulid.new().str,
171
- "user_id": user_id,
172
- "token": token,
173
- "created_by": "system",
174
- "created_at": datetime.datetime.now(datetime.timezone.utc),
175
- }
141
+ delete(UserSession).where(
142
+ UserSession.user_id == user_id,
143
+ UserSession.refresh_token_expired_at < now,
176
144
  ),
177
145
  )
178
146
 
179
- async def remove_token(self, user_id: str, token: str):
147
+ async def get_active_user_sessions(self, user_id: str) -> list[UserSessionResponse]:
148
+ now = datetime.datetime.now(datetime.timezone.utc)
180
149
  async with self._session_scope() as session:
181
- await self._execute_statement(
150
+ result = await self._execute_statement(
182
151
  session,
183
- delete(Session).where(
184
- Session.token == token, Session.user_id == user_id
152
+ select(UserSession).where(
153
+ UserSession.user_id == user_id,
154
+ UserSession.refresh_token_expired_at > now,
185
155
  ),
186
156
  )
157
+ return [self._user_session_to_response(row[0]) for row in result.all()]
187
158
 
188
- async def get_sessions(self, user_id: str) -> list[SessionResponse]:
159
+ async def get_user_session_by_access_token(
160
+ self, access_token: str
161
+ ) -> UserSessionResponse:
162
+ now = datetime.datetime.now(datetime.timezone.utc)
163
+ async with self._session_scope() as session:
164
+ result = await self._execute_statement(
165
+ session,
166
+ select(UserSession).where(
167
+ UserSession.access_token == access_token,
168
+ UserSession.access_token_expired_at > now,
169
+ ),
170
+ )
171
+ user_session = result.scalar_one_or_none()
172
+ if user_session is None:
173
+ raise NotFoundError("User session not found")
174
+ return self._user_session_to_response(user_session)
175
+
176
+ async def get_user_session_by_refresh_token(
177
+ self, refresh_token: str
178
+ ) -> UserSessionResponse:
179
+ now = datetime.datetime.now(datetime.timezone.utc)
189
180
  async with self._session_scope() as session:
190
- statement = select(Session).where(Session.user_id == user_id)
191
- result = await self._execute_statement(session, statement)
192
- return [
193
- SessionResponse(**session.model_dump())
194
- for session in result.scalars().all()
195
- ]
196
-
197
- async def remove_session(self, user_id: str, session_id: str) -> SessionResponse:
181
+ result = await self._execute_statement(
182
+ session,
183
+ select(UserSession).where(
184
+ UserSession.refresh_token == refresh_token,
185
+ UserSession.refresh_token_expired_at > now,
186
+ ),
187
+ )
188
+ user_session = result.scalar_one_or_none()
189
+ if user_session is None:
190
+ raise NotFoundError("User session not found")
191
+ return self._user_session_to_response(user_session)
192
+
193
+ async def create_user_session(
194
+ self, user_id: str, token_data: UserTokenData
195
+ ) -> UserSessionResponse:
196
+ data_dict = self._model_to_data_dict(
197
+ token_data, user_id=user_id, id=ulid.new().str
198
+ )
198
199
  async with self._session_scope() as session:
199
- statement = select(Session).where(
200
- Session.user_id == user_id, Session.id == session_id
200
+ await self._execute_statement(
201
+ session, insert(UserSession).values(**data_dict)
202
+ )
203
+ result = await self._execute_statement(
204
+ session, select(UserSession).where(UserSession.id == data_dict["id"])
201
205
  )
202
- result = await self._execute_statement(session, statement)
203
- session = result.scalar_one_or_none()
204
- if not session:
205
- raise NotFoundError(f"{self.entity_name} not found")
206
+ user_session = result.scalar_one_or_none()
207
+ if user_session is None:
208
+ raise NotFoundError("User session not found after created")
209
+ return self._user_session_to_response(user_session)
210
+
211
+ async def update_user_session(
212
+ self, user_id: str, session_id: str, token_data: UserTokenData
213
+ ) -> UserSessionResponse:
214
+ data_dict = self._model_to_data_dict(token_data, user_id=user_id)
215
+ async with self._session_scope() as session:
206
216
  await self._execute_statement(
207
217
  session,
208
- delete(Session).where(
209
- Session.id == session_id, Session.user_id == user_id
218
+ (
219
+ update(UserSession)
220
+ .where(UserSession.id == session_id)
221
+ .values(**data_dict)
210
222
  ),
211
223
  )
212
- return SessionResponse(**session.model_dump())
224
+ result = await self._execute_statement(
225
+ session, select(UserSession).where(UserSession.id == session_id)
226
+ )
227
+ user_session = result.scalar_one_or_none()
228
+ if user_session is None:
229
+ raise NotFoundError("User session not found after created")
230
+ return self._user_session_to_response(user_session)
231
+
232
+ async def delete_user_sessions(self, session_ids: list[str]):
233
+ async with self._session_scope() as session:
234
+ await self._execute_statement(
235
+ session, delete(UserSession).where(UserSession.id.in_(session_ids))
236
+ )
237
+
238
+ def _user_session_to_response(
239
+ self, user_session: UserSession
240
+ ) -> UserSessionResponse:
241
+ return UserSessionResponse(
242
+ id=user_session.id,
243
+ user_id=user_session.user_id,
244
+ access_token=user_session.access_token,
245
+ access_token_expired_at=user_session.access_token_expired_at,
246
+ refresh_token=user_session.refresh_token,
247
+ refresh_token_expired_at=user_session.refresh_token_expired_at,
248
+ token_type="bearer",
249
+ )
@@ -1,10 +1,11 @@
1
1
  from abc import ABC, abstractmethod
2
2
 
3
- from my_app_name.schema.session import SessionResponse
4
3
  from my_app_name.schema.user import (
5
4
  User,
6
5
  UserCreateWithAudit,
7
6
  UserResponse,
7
+ UserSessionResponse,
8
+ UserTokenData,
8
9
  UserUpdateWithAudit,
9
10
  )
10
11
 
@@ -72,21 +73,37 @@ class UserRepository(ABC):
72
73
  """Get user by credential"""
73
74
 
74
75
  @abstractmethod
75
- async def get_by_token(self, token: str) -> UserResponse:
76
- """Get user by token"""
76
+ async def get_active_user_sessions(self, user_id: str) -> list[UserSessionResponse]:
77
+ """Get user sessions"""
77
78
 
78
79
  @abstractmethod
79
- async def add_token(self, user_id: str, token: str):
80
- """Add token to user"""
80
+ async def get_user_session_by_access_token(
81
+ self, access_token: str
82
+ ) -> UserSessionResponse:
83
+ """Get user session by access token"""
81
84
 
82
85
  @abstractmethod
83
- async def remove_token(self, user_id: str, token: str):
84
- """Remove token from user"""
86
+ async def get_user_session_by_refresh_token(
87
+ self, refresh_token: str
88
+ ) -> UserSessionResponse:
89
+ """Get user session by refresh token"""
85
90
 
86
91
  @abstractmethod
87
- async def get_sessions(self, user_id: str) -> list[SessionResponse]:
88
- """Get sessions"""
92
+ async def create_user_session(
93
+ self, user_id: str, token_data: UserTokenData
94
+ ) -> UserSessionResponse:
95
+ """Create new user session"""
89
96
 
90
97
  @abstractmethod
91
- async def remove_session(self, user_id: str, session_id: str) -> SessionResponse:
92
- """Remove a session"""
98
+ async def update_user_session(
99
+ self, user_id: str, session_id: str, token_data: UserTokenData
100
+ ) -> UserSessionResponse:
101
+ """Update user session"""
102
+
103
+ @abstractmethod
104
+ async def delete_expired_user_sessions(self, user_id: str):
105
+ """Delete expired user sessions"""
106
+
107
+ @abstractmethod
108
+ async def delete_user_sessions(self, session_ids: list[str]):
109
+ """Delete user session"""