trovesuite 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trovesuite/__init__.py +16 -0
- trovesuite/auth/__init__.py +16 -0
- trovesuite/auth/auth_base.py +4 -0
- trovesuite/auth/auth_controller.py +10 -0
- trovesuite/auth/auth_read_dto.py +18 -0
- trovesuite/auth/auth_service.py +334 -0
- trovesuite/auth/auth_write_dto.py +10 -0
- trovesuite/configs/__init__.py +16 -0
- trovesuite/configs/database.py +221 -0
- trovesuite/configs/logging.py +261 -0
- trovesuite/configs/settings.py +153 -0
- trovesuite/entities/__init__.py +11 -0
- trovesuite/entities/health.py +84 -0
- trovesuite/entities/sh_response.py +61 -0
- trovesuite/utils/__init__.py +11 -0
- trovesuite/utils/helper.py +36 -0
- trovesuite-1.0.0.dist-info/METADATA +572 -0
- trovesuite-1.0.0.dist-info/RECORD +21 -0
- trovesuite-1.0.0.dist-info/WHEEL +5 -0
- trovesuite-1.0.0.dist-info/licenses/LICENSE +21 -0
- trovesuite-1.0.0.dist-info/top_level.txt +1 -0
trovesuite/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
TroveSuite Auth Package

An authentication and authorization service for ERP systems.
Provides JWT token validation, user authorization, and permission checking.
The main entry point, ``AuthService``, is re-exported at the package level.
"""
|
|
7
|
+
|
|
8
|
+
from .auth import AuthService
|
|
9
|
+
|
|
10
|
+
from importlib.metadata import PackageNotFoundError as _PackageNotFoundError
from importlib.metadata import version as _dist_version

try:
    # Single source of truth: the version recorded in the installed
    # distribution metadata, so ``__version__`` cannot drift from the wheel.
    __version__ = _dist_version("trovesuite")
except _PackageNotFoundError:
    # Imported from a source checkout without an installed distribution.
    __version__ = "1.0.0"

__author__ = "Bright Debrah Owusu"
__email__ = "owusu.debrah@deladetech.com"

# Public API of the top-level package.
__all__ = [
    "AuthService",
]
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
TroveSuite Auth Module

Authentication and authorization services for ERP systems.
Re-exports ``AuthService``, ``AuthBase`` and the read DTOs
(``AuthServiceReadDto``, ``AuthControllerReadDto``).
"""
|
|
6
|
+
|
|
7
|
+
from .auth_service import AuthService
|
|
8
|
+
from .auth_base import AuthBase
|
|
9
|
+
from .auth_read_dto import AuthServiceReadDto, AuthControllerReadDto
|
|
10
|
+
|
|
11
|
+
# Names exported by ``from trovesuite.auth import *``.
__all__ = ["AuthService", "AuthBase", "AuthServiceReadDto", "AuthControllerReadDto"]
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from fastapi import APIRouter
from fastapi.concurrency import run_in_threadpool

# Package-relative imports: the installed wheel's top-level package is
# ``trovesuite``; there is no ``src`` package at runtime.
from ..entities.sh_response import Respons
from .auth_write_dto import AuthControllerWriteDto
from .auth_read_dto import AuthControllerReadDto
from .auth_service import AuthService

auth_router = APIRouter()


# ``AuthService.authorize`` returns a ``Respons`` envelope; declaring the bare
# DTO as response_model would make FastAPI re-validate the envelope into an
# all-None ``AuthControllerReadDto`` and drop the actual result.
@auth_router.post("/auth", response_model=Respons[AuthControllerReadDto])
async def authorize(data: AuthControllerWriteDto):
    """Authorize a user in a tenant and return their role assignments.

    ``AuthService.authorize`` performs blocking psycopg2 queries, so it runs in
    FastAPI's thread pool instead of blocking the event loop.

    NOTE(review): failures are reported through the ``status_code`` field of
    the response body; the HTTP status of this route stays at FastAPI's
    default (200) — confirm this is intended.
    """
    return await run_in_threadpool(AuthService.authorize, data=data)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from typing import Optional, List
|
|
2
|
+
from pydantic import BaseModel
|
|
3
|
+
|
|
4
|
+
class AuthControllerReadDto(BaseModel):
    """One role assignment returned by an authorization check.

    ``AuthService.authorize`` builds these from role-assignment rows plus the
    tenant id and the role's permission ids. Every field is optional. A
    ``None`` scope field (org/bus/app/resource/shared resource) is treated by
    ``AuthService.check_permission`` as "applies to everything at that level".
    """

    org_id: Optional[str] = None
    bus_id: Optional[str] = None
    app_id: Optional[str] = None
    # Declared once; an earlier revision repeated this field further down.
    shared_resource_id: Optional[str] = None
    user_id: Optional[str] = None
    group_id: Optional[str] = None
    role_id: Optional[str] = None
    tenant_id: Optional[str] = None
    # Permission ids granted by ``role_id``.
    permissions: Optional[List[str]] = None
    resource_id: Optional[str] = None
|
|
16
|
+
|
|
17
|
+
class AuthServiceReadDto(AuthControllerReadDto):
    """Service-layer counterpart of ``AuthControllerReadDto``; adds no fields."""
|
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
"""Auth Service - business logic for user authentication and authorization.

Provides JWT decoding, tenant/user login checks against the database, and
helpers for evaluating the permissions attached to a user's roles.
"""
|
|
2
|
+
from datetime import datetime, timezone
|
|
3
|
+
from typing import Annotated
|
|
4
|
+
from fastapi import HTTPException
|
|
5
|
+
from fastapi.params import Depends
|
|
6
|
+
from fastapi.security import OAuth2PasswordBearer
|
|
7
|
+
from ..utils.helper import Helper
|
|
8
|
+
from ..configs.settings import db_settings
|
|
9
|
+
from ..configs.database import DatabaseManager
|
|
10
|
+
from ..configs.logging import get_logger
|
|
11
|
+
from ..entities.sh_response import Respons
|
|
12
|
+
from .auth_read_dto import AuthServiceReadDto
|
|
13
|
+
from .auth_write_dto import AuthServiceWriteDto
|
|
14
|
+
import jwt
|
|
15
|
+
|
|
16
|
+
# Module logger for the auth service.
logger = get_logger("auth_service")
# FastAPI OAuth2 bearer scheme: when used as a dependency it extracts the token
# from the "Authorization: Bearer <token>" header. tokenUrl="token" is the
# token endpoint advertised in the OpenAPI docs.
# NOTE(review): no "token" route is visible in this package — confirm it exists.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
18
|
+
|
|
19
|
+
class AuthService:
    """Service class for authentication and authorization operations.

    Every method is a ``@staticmethod``: the class acts as a namespace and
    keeps no per-instance state.
    """

    # Upper-case weekday names indexed by ``datetime.weekday()`` (Monday == 0).
    # Used instead of ``strftime("%A")``, whose output follows the process locale.
    _WEEKDAYS = ("MONDAY", "TUESDAY", "WEDNESDAY", "THURSDAY", "FRIDAY", "SATURDAY", "SUNDAY")

    def __init__(self) -> None:
        """Initialize the service (there is nothing to set up)."""

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return *name* quoted as a PostgreSQL identifier.

        Embedded double quotes are doubled, so the value can only name a
        schema and cannot close the identifier early to inject SQL.
        """
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _error_response(detail, status_code: int, error: str) -> Respons[AuthServiceReadDto]:
        """Build an unsuccessful authorization response that carries no data."""
        return Respons[AuthServiceReadDto](
            detail=detail,
            data=[],
            success=False,
            status_code=status_code,
            error=error,
        )

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* as a timezone-aware datetime.

        A naive value would raise ``TypeError`` when compared with the aware
        current time. NOTE(review): naive DB timestamps are assumed to be UTC.
        """
        return value.replace(tzinfo=timezone.utc) if value.tzinfo is None else value

    @staticmethod
    def _is_within_login_schedule(settings, user_id: str) -> bool:
        """Apply the working-day / login-window rules of a login-settings row.

        Login is allowed when today's weekday (UTC) appears in
        ``working_days``. Otherwise it is allowed only if the current UTC time,
        truncated to the minute, lies within ``[login_on, logout_on]``. A
        missing bound is treated as unbounded.

        Args:
            settings: Login-settings row (mapping) for the user.
            user_id: Used for logging only.

        Returns:
            bool: ``True`` if the user may log in now.
        """
        # One clock reading (UTC) for both the weekday and the window check.
        now = datetime.now(timezone.utc)
        current_day = AuthService._WEEKDAYS[now.weekday()]

        # ``working_days`` may be NULL; treat that as "no working days".
        if current_day in (settings['working_days'] or ()):
            return True

        logger.warning("Authorization failed - outside working days for user: %s checking custom login period", user_id)

        current_datetime = now.replace(microsecond=0, second=0)
        login_on = settings['login_on']
        logout_on = settings['logout_on']
        login_on = AuthService._as_utc(login_on) if login_on else datetime.min.replace(tzinfo=timezone.utc)
        logout_on = AuthService._as_utc(logout_on) if logout_on else datetime.max.replace(tzinfo=timezone.utc)

        # Compare full datetime objects (both date and time).
        return login_on <= current_datetime <= logout_on

    @staticmethod
    def _fetch_role_assignments(schema: str, user_id: str, group_ids: list) -> list:
        """Return distinct active role assignments of the user and their groups.

        Args:
            schema: Already-quoted tenant schema identifier.
            user_id: User whose direct assignments are included.
            group_ids: Groups whose assignments are also included (may be empty).
        """
        if group_ids:
            return DatabaseManager.execute_query(
                f"""
                SELECT DISTINCT ON (org_id, group_id, bus_id, app_id, shared_resource_id, resource_id, user_id, role_id)
                org_id, group_id, bus_id, app_id, shared_resource_id, resource_id, user_id, role_id
                FROM {schema}.{db_settings.ASSIGN_ROLES_TABLE}
                WHERE delete_status = 'NOT_DELETED'
                AND is_active = true
                AND (user_id = %s OR group_id = ANY(%s))
                ORDER BY org_id, group_id, bus_id, app_id, shared_resource_id, resource_id, user_id, role_id;
                """,
                (user_id, group_ids),
            )

        # No groups: only roles assigned directly to the user.
        return DatabaseManager.execute_query(
            f"""
            SELECT DISTINCT ON (org_id, bus_id, app_id, shared_resource_id, resource_id, user_id, role_id)
            org_id, bus_id, app_id, shared_resource_id, resource_id, user_id, role_id
            FROM {schema}.{db_settings.ASSIGN_ROLES_TABLE}
            WHERE delete_status = 'NOT_DELETED'
            AND is_active = true
            AND user_id = %s
            ORDER BY org_id, bus_id, app_id, shared_resource_id, resource_id, user_id, role_id;
            """,
            (user_id,),
        )

    @staticmethod
    def _attach_permissions(roles: list, tenant_id: str) -> list:
        """Return copies of *roles* with ``tenant_id`` and ``permissions`` added.

        Permission ids are loaded from ``ROLE_PERMISSIONS_TABLE`` once per
        distinct ``role_id``, because one role can appear in several assignments.
        """
        permissions_by_role: dict = {}
        enriched = []
        for role in roles:
            role_id = role["role_id"]
            if role_id not in permissions_by_role:
                rows = DatabaseManager.execute_query(
                    f"""SELECT permission_id FROM {db_settings.ROLE_PERMISSIONS_TABLE} WHERE role_id = %s""",
                    params=(role_id,),
                )
                permissions_by_role[role_id] = [p['permission_id'] for p in rows]
            # Copy the list so assignments never share one mutable object.
            enriched.append({**role, "tenant_id": tenant_id, "permissions": list(permissions_by_role[role_id])})
        return enriched

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def decode_token(token: Annotated[str, Depends(oauth2_scheme)]) -> dict:
        """Decode a JWT and extract the caller's identity.

        Usable directly or as a FastAPI dependency, in which case the token is
        taken from the bearer ``Authorization`` header via ``oauth2_scheme``.

        Args:
            token: Encoded JWT, verified with ``SECRET_KEY`` / ``ALGORITHM``.

        Returns:
            dict: ``{"user_id": ..., "tenant_id": ...}`` from the token claims.

        Raises:
            HTTPException: 401 if the token is invalid or lacks either claim.
        """
        credentials_exception = HTTPException(
            status_code=401,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
        try:
            payload = jwt.decode(token, db_settings.SECRET_KEY, algorithms=[db_settings.ALGORITHM])
        except jwt.InvalidTokenError as exc:
            raise credentials_exception from exc

        user_id = payload.get("user_id")
        tenant_id = payload.get("tenant_id")
        if user_id is None or tenant_id is None:
            raise credentials_exception

        return {"user_id": user_id, "tenant_id": tenant_id}

    @staticmethod
    def authorize(data: AuthServiceWriteDto) -> Respons[AuthServiceReadDto]:
        """Check whether a user may log in and return their role assignments.

        Steps:
          1. Validate ``user_id`` / ``tenant_id`` (non-empty strings).
          2. Require the tenant to exist (not deleted) and be verified.
          3. Load the user's active login settings from the tenant schema and
             reject suspended users or logins outside the allowed schedule
             (unless ``can_always_login`` is set).
          4. Collect role assignments granted directly or through the user's
             active groups, each with its permission ids and the tenant id.

        Args:
            data: Object exposing ``user_id`` and ``tenant_id`` attributes.

        Returns:
            Respons[AuthServiceReadDto]: ``success=True`` with one DTO per role
            assignment; otherwise ``success=False`` with an error code and a
            status code of 400, 403, 404 or 500.

        Raises:
            HTTPException: propagated unchanged if raised while processing;
                any other exception is logged and turned into a 500 response.
        """
        user_id: str = data.user_id
        tenant_id: str = data.tenant_id

        # Input validation
        if not user_id or not isinstance(user_id, str):
            return AuthService._error_response(
                "Invalid user_id: must be a non-empty string", 400, "INVALID_USER_ID"
            )
        if not tenant_id or not isinstance(tenant_id, str):
            return AuthService._error_response(
                "Invalid tenant_id: must be a non-empty string", 400, "INVALID_TENANT_ID"
            )

        try:
            tenant_rows = DatabaseManager.execute_query(
                f"SELECT is_verified FROM {db_settings.MAIN_TENANTS_TABLE} WHERE delete_status = 'NOT_DELETED' AND id = %s",
                (tenant_id,),
            )
            if not tenant_rows:
                logger.warning("Login failed - tenant not found: %s", tenant_id)
                return AuthService._error_response(
                    f"Tenant '{tenant_id}' not found or has been deleted", 404, "TENANT_NOT_FOUND"
                )
            if not tenant_rows[0]['is_verified']:
                logger.warning("Login failed - tenant not verified for user: %s, tenant: %s", user_id, tenant_id)
                return AuthService._error_response(
                    f"Tenant '{tenant_id}' is not verified. Please contact your administrator.",
                    403,
                    "TENANT_NOT_VERIFIED",
                )

            # Tenant tables live in a schema named after the tenant id; quote it
            # safely because the id can come from a request body.
            schema = AuthService._quote_ident(tenant_id)

            login_settings_details = DatabaseManager.execute_query(
                f"""SELECT user_id, group_id, is_suspended, can_always_login,
                is_multi_factor_enabled, is_login_before, working_days,
                login_on, logout_on FROM {schema}.{db_settings.TENANT_LOGIN_SETTINGS_TABLE}
                WHERE (delete_status = 'NOT_DELETED' AND is_active = true ) AND user_id = %s""",
                (user_id,),
            )
            if not login_settings_details:
                logger.warning("Authorization failed - user not found: %s in tenant: %s", user_id, tenant_id)
                return AuthService._error_response(
                    f"User '{user_id}' not found in tenant '{tenant_id}' or account is inactive",
                    404,
                    "USER_NOT_FOUND",
                )

            settings = login_settings_details[0]
            if settings['is_suspended']:
                logger.warning("Authorization failed - user suspended: %s", user_id)
                return AuthService._error_response(
                    "Your account has been suspended. Please contact your administrator.",
                    403,
                    "USER_SUSPENDED",
                )

            if not settings['can_always_login'] and not AuthService._is_within_login_schedule(settings, user_id):
                logger.warning("Authorization failed - outside allowed period for user: %s", user_id)
                return AuthService._error_response(
                    "Login is not allowed at this time. Please check your access schedule.",
                    403,
                    "LOGIN_TIME_RESTRICTED",
                )

            # Groups the user belongs to; their role assignments apply too.
            user_groups = DatabaseManager.execute_query(
                f"""SELECT group_id FROM {schema}.{db_settings.USER_GROUPS_TABLE}
                WHERE delete_status = 'NOT_DELETED' AND is_active = true AND user_id = %s""",
                (user_id,),
            )
            group_ids = [g["group_id"] for g in user_groups] if user_groups else []

            roles = AuthService._fetch_role_assignments(schema, user_id, group_ids)
            roles_with_permissions = AuthService._attach_permissions(roles, tenant_id)
            roles_dto = Helper.map_to_dto(roles_with_permissions, AuthServiceReadDto)

            return Respons[AuthServiceReadDto](
                detail="Authorized",
                data=roles_dto,
                success=True,
                status_code=200,
                error=None,
            )

        except HTTPException:
            raise

        except Exception as e:
            logger.error(
                "Authorization check failed for user %s in tenant %s: %s",
                user_id, tenant_id, e, exc_info=True,
            )
            return AuthService._error_response(
                None, 500, "Authorization check failed due to an internal error"
            )

    @staticmethod
    def check_permission(users_data: list, action=None, org_id=None, bus_id=None, app_id=None,
                         resource_id=None, shared_resource_id=None) -> bool:
        """Return True if any role assignment grants *action* on the given target.

        Scope hierarchy: organization > business > app > resource / shared
        resource. A scope field that is ``None`` on an assignment matches any
        requested value at that level. A non-``None`` field must equal the
        requested value.

        Args:
            users_data: Role assignments (e.g. the ``data`` of an ``authorize``
                response) exposing scope fields and ``permissions``.
            action: Permission id to look for.
            org_id, bus_id, app_id, resource_id, shared_resource_id: Target scope.

        Returns:
            bool: ``True`` as soon as a matching assignment is found.
        """
        for user_data in users_data:
            # Check hierarchy: None means "all"
            if user_data.org_id not in (None, org_id):
                continue
            if user_data.bus_id not in (None, bus_id):
                continue
            if user_data.app_id not in (None, app_id):
                continue
            if user_data.resource_id not in (None, resource_id):
                continue
            if user_data.shared_resource_id not in (None, shared_resource_id):
                continue

            # ``permissions`` is Optional on the DTO; treat None as "no permissions".
            if action in (user_data.permissions or ()):
                return True

        return False

    @staticmethod
    def get_user_info_from_token(token: str) -> dict:
        """Return ``{"user_id", "tenant_id"}`` decoded from a JWT.

        Args:
            token: JWT token string.

        Returns:
            dict: User information including user_id and tenant_id.

        Raises:
            HTTPException: If the token is invalid.
        """
        return AuthService.decode_token(token)

    @staticmethod
    def authorize_user_from_token(token: str) -> Respons[AuthServiceReadDto]:
        """Authorize a user directly from a JWT.

        Args:
            token: JWT token string.

        Returns:
            Respons[AuthServiceReadDto]: Authorization result with user roles and permissions.

        Raises:
            HTTPException: If the token is invalid.
        """
        user_info = AuthService.decode_token(token)
        # ``authorize`` takes a single object exposing user_id / tenant_id.
        # NOTE(review): assumes AuthServiceWriteDto accepts these two keyword fields.
        return AuthService.authorize(
            AuthServiceWriteDto(user_id=user_info["user_id"], tenant_id=user_info["tenant_id"])
        )

    @staticmethod
    def get_user_permissions(user_roles: list) -> list:
        """Return the unique permission ids across all *user_roles* (order unspecified).

        Args:
            user_roles: Role assignments from ``authorize``.

        Returns:
            list: Unique list of permissions.
        """
        permissions = set()
        for role in user_roles:
            if role.permissions:
                permissions.update(role.permissions)
        return list(permissions)

    @staticmethod
    def has_any_permission(user_roles: list, required_permissions: list) -> bool:
        """Return True if the user holds at least one of *required_permissions*.

        Args:
            user_roles: Role assignments from ``authorize``.
            required_permissions: Permission ids to check for.
        """
        user_permissions = set(AuthService.get_user_permissions(user_roles))
        return any(perm in user_permissions for perm in required_permissions)

    @staticmethod
    def has_all_permissions(user_roles: list, required_permissions: list) -> bool:
        """Return True if the user holds every one of *required_permissions*.

        Args:
            user_roles: Role assignments from ``authorize``.
            required_permissions: Permission ids to check for.
        """
        user_permissions = set(AuthService.get_user_permissions(user_roles))
        return all(perm in user_permissions for perm in required_permissions)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
TroveSuite Configuration Module

Configuration settings, database connection management and logging helpers
for TroveSuite services.
"""
|
|
6
|
+
|
|
7
|
+
from .settings import db_settings
|
|
8
|
+
from .database import DatabaseManager
|
|
9
|
+
from .logging import setup_logging, get_logger
|
|
10
|
+
|
|
11
|
+
# Names exported by ``from trovesuite.configs import *``.
__all__ = ["db_settings", "DatabaseManager", "setup_logging", "get_logger"]
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""
Database configuration and connection management.

Wraps a psycopg2 ``ThreadedConnectionPool`` and exposes context managers
(``get_db_connection``, ``get_db_cursor``) plus ``DatabaseManager`` query
helpers. ``initialize_database()`` must be called before first use.
"""
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from typing import Generator, Optional
|
|
6
|
+
import psycopg2
|
|
7
|
+
import psycopg2.pool
|
|
8
|
+
from psycopg2.extras import RealDictCursor
|
|
9
|
+
from .settings import db_settings
|
|
10
|
+
from .logging import get_logger
|
|
11
|
+
|
|
12
|
+
logger = get_logger("database")

# Shared psycopg2 connection pool: assigned by initialize_database(), read by
# get_connection_pool(); stays None until initialization succeeds.
_connection_pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None
# NOTE(review): nothing in this module assigns this engine, so
# get_sqlmodel_engine() always raises — confirm whether it is still needed.
_sqlmodel_engine = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DatabaseConfig:
    """Database configuration and psycopg2 connection management."""

    def __init__(self):
        """Bind the global settings and default pool sizing."""
        self.settings = db_settings
        # Maximum number of connections kept by the pool (``maxconn``).
        self.pool_size = 5
        # NOTE(review): not read anywhere in this module; kept so existing
        # attribute access continues to work.
        self.max_overflow = 10

    @property
    def database_url(self):
        """Get database URL (lazy evaluation)."""
        return self.settings.database_url

    def get_connection_params(self) -> dict:
        """Return keyword arguments for ``psycopg2.connect`` and the pool.

        Rows are returned as dicts (``RealDictCursor``), and the session's
        ``application_name`` is ``<APP_NAME>_<ENVIRONMENT>``.
        """
        return {
            "host": self.settings.DB_HOST,
            "port": self.settings.DB_PORT,
            "database": self.settings.DB_NAME,
            "user": self.settings.DB_USER,
            "password": self.settings.DB_PASSWORD,
            "cursor_factory": RealDictCursor,
            "application_name": f"{self.settings.APP_NAME}_{self.settings.ENVIRONMENT}"
        }

    def create_connection_pool(self) -> psycopg2.pool.ThreadedConnectionPool:
        """Create a ``ThreadedConnectionPool`` holding 1..``pool_size`` connections.

        Raises:
            Exception: whatever psycopg2 raises; it is logged and re-raised.
        """
        try:
            pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=1,
                maxconn=self.pool_size,
                **self.get_connection_params()
            )
        except Exception as e:
            logger.error("Failed to create database connection pool: %s", e)
            raise
        logger.info("Database connection pool created with %s connections", self.pool_size)
        return pool

    def test_connection(self) -> bool:
        """Open a dedicated connection, run ``SELECT 1`` and close it again.

        Returns:
            bool: ``True`` if the query returned a row; ``False`` on any error
            (which is logged) or an empty result.
        """
        conn = None
        try:
            conn = psycopg2.connect(**self.get_connection_params())
            with conn.cursor() as cursor:
                cursor.execute("SELECT 1")
                result = cursor.fetchone()
            if result:
                logger.info("Database connection test successful")
                return True
        except Exception as e:
            logger.error("Database connection test failed: %s", e)
            return False
        finally:
            # ``with connection`` in psycopg2 only ends the transaction; it does
            # not close the connection, so close it explicitly to avoid a leak.
            if conn is not None:
                conn.close()
        return False
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Module-wide DatabaseConfig used by initialize_database() to test connectivity
# and to build the shared connection pool.
db_config = DatabaseConfig()
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def initialize_database():
|
|
79
|
+
"""Initialize database connections and pool"""
|
|
80
|
+
global _connection_pool
|
|
81
|
+
|
|
82
|
+
try:
|
|
83
|
+
# Test connection first
|
|
84
|
+
if not db_config.test_connection():
|
|
85
|
+
raise Exception("Database connection test failed")
|
|
86
|
+
|
|
87
|
+
# Create connection pool
|
|
88
|
+
_connection_pool = db_config.create_connection_pool()
|
|
89
|
+
|
|
90
|
+
logger.info("Database initialization completed successfully")
|
|
91
|
+
|
|
92
|
+
except Exception as e:
|
|
93
|
+
logger.error(f"Database initialization failed: {str(e)}")
|
|
94
|
+
raise
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def get_connection_pool() -> psycopg2.pool.ThreadedConnectionPool:
    """Return the shared pool created by ``initialize_database()``.

    Raises:
        Exception: if the pool has not been created yet.
    """
    pool = _connection_pool
    if pool is not None:
        return pool
    raise Exception("Database not initialized. Call initialize_database() first.")
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def get_sqlmodel_engine():
    """Return the module-level SQLModel engine.

    Raises:
        Exception: while ``_sqlmodel_engine`` is still ``None``.
            ``initialize_database()`` does not assign it.
    """
    engine = _sqlmodel_engine
    if engine is not None:
        return engine
    raise Exception("Database not initialized. Call initialize_database() first.")
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@contextmanager
def get_db_connection():
    """Borrow a pooled connection for the duration of a ``with`` block.

    If the block raises, the transaction is rolled back and the original
    exception is re-raised. The connection is always handed back to the pool.
    If it is closed, or the rollback itself failed, the pool is told to
    discard it rather than hand it out again.

    Yields:
        The pooled psycopg2 connection.

    Raises:
        Exception: if the pool is not initialized, a connection cannot be
            obtained, or the ``with`` body raises (re-raised unchanged).
    """
    pool = get_connection_pool()
    conn = None
    discard = False
    try:
        conn = pool.getconn()
        logger.debug("Database connection acquired from pool")
        yield conn
    except Exception as e:
        logger.error("Database connection error: %s", e)
        if conn is not None:
            try:
                conn.rollback()
            except Exception as rollback_error:
                # A failed rollback means the connection is unusable. Keep the
                # original error as the one that propagates.
                logger.error("Rollback failed; discarding connection: %s", rollback_error)
                discard = True
        raise
    finally:
        if conn is not None:
            # ``closed`` is non-zero once the connection is closed or broken.
            pool.putconn(conn, close=discard or bool(conn.closed))
            logger.debug("Database connection returned to pool")
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@contextmanager
def get_db_cursor():
    """Yield a cursor on a pooled connection and commit when the block succeeds.

    An exception from the ``with`` body (or from the commit) rolls the
    transaction back, is logged and is re-raised. The cursor is closed in
    every case.
    """
    with get_db_connection() as connection:
        with connection.cursor() as cur:
            try:
                yield cur
                connection.commit()
            except Exception as exc:
                connection.rollback()
                logger.error(f"Database cursor error: {str(exc)}")
                raise
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class DatabaseManager:
    """Static helpers that run a single statement through ``get_db_cursor``.

    Each call borrows a pooled connection, executes one statement and commits
    on success. ``get_db_cursor`` rolls back and re-raises on error.
    """

    @staticmethod
    def execute_query(query: str, params: Optional[tuple] = None) -> list:
        """Execute a SELECT query and return all rows.

        Args:
            query: SQL text; use ``%s`` placeholders for values.
            params: Values bound to the placeholders, or ``None``.

        Returns:
            list: The result of ``cursor.fetchall()``.
        """
        with get_db_cursor() as cursor:
            cursor.execute(query, params)
            return cursor.fetchall()

    @staticmethod
    def execute_update(query: str, params: Optional[tuple] = None) -> int:
        """Execute an INSERT/UPDATE/DELETE query and return ``cursor.rowcount``."""
        with get_db_cursor() as cursor:
            cursor.execute(query, params)
            return cursor.rowcount

    @staticmethod
    def execute_scalar(query: str, params: Optional[tuple] = None):
        """Execute a query and return the first column of the first row.

        Handles both dict-like rows (``RealDictCursor``) and tuple rows.

        Returns:
            The value, or ``None`` when no (or an empty) row is returned.
        """
        with get_db_cursor() as cursor:
            cursor.execute(query, params)
            row = cursor.fetchone()
            if not row:
                return None
            if hasattr(row, 'get'):
                # Dict-like row (RealDictRow keeps SELECT column order).
                return next(iter(row.values()), None)
            # Tuple row; non-empty because it is truthy.
            return row[0]

    @staticmethod
    def health_check() -> dict:
        """Query server version, database and user to report database health.

        Returns:
            dict: ``{"status": "healthy", "database", "user", "version"}`` on
            success; otherwise ``{"status": "unhealthy", "error": <message>}``.
            Errors are logged, never raised.
        """
        try:
            with get_db_cursor() as cursor:
                cursor.execute("SELECT version(), current_database(), current_user")
                result = cursor.fetchone()

                if not result:
                    return {
                        "status": "unhealthy",
                        "error": "No result from database query"
                    }

                if hasattr(result, 'get'):
                    # Dict-like row keyed by the function-call column names.
                    return {
                        "status": "healthy",
                        "database": result.get('current_database', 'unknown'),
                        "user": result.get('current_user', 'unknown'),
                        "version": result.get('version', 'unknown')
                    }

                # Tuple row in SELECT order: version, database, user.
                return {
                    "status": "healthy",
                    "database": result[1] if len(result) > 1 else "unknown",
                    "user": result[2] if len(result) > 2 else "unknown",
                    "version": result[0] if len(result) > 0 else "unknown"
                }
        except Exception as e:
            logger.error("Database health check failed: %s", e)
            return {
                "status": "unhealthy",
                "error": str(e)
            }
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
# The database is not initialized at import time and nothing initializes it
# implicitly: call initialize_database() once before using the helpers above.
|