mdb-engine 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. mdb_engine/__init__.py +38 -6
  2. mdb_engine/auth/README.md +534 -11
  3. mdb_engine/auth/__init__.py +129 -28
  4. mdb_engine/auth/audit.py +592 -0
  5. mdb_engine/auth/casbin_factory.py +10 -14
  6. mdb_engine/auth/config_helpers.py +7 -6
  7. mdb_engine/auth/cookie_utils.py +3 -7
  8. mdb_engine/auth/csrf.py +373 -0
  9. mdb_engine/auth/decorators.py +3 -10
  10. mdb_engine/auth/dependencies.py +37 -45
  11. mdb_engine/auth/helpers.py +3 -3
  12. mdb_engine/auth/integration.py +30 -73
  13. mdb_engine/auth/jwt.py +2 -6
  14. mdb_engine/auth/middleware.py +77 -34
  15. mdb_engine/auth/oso_factory.py +16 -36
  16. mdb_engine/auth/provider.py +17 -38
  17. mdb_engine/auth/rate_limiter.py +504 -0
  18. mdb_engine/auth/restrictions.py +8 -24
  19. mdb_engine/auth/session_manager.py +14 -29
  20. mdb_engine/auth/shared_middleware.py +600 -0
  21. mdb_engine/auth/shared_users.py +759 -0
  22. mdb_engine/auth/token_store.py +14 -28
  23. mdb_engine/auth/users.py +54 -113
  24. mdb_engine/auth/utils.py +213 -15
  25. mdb_engine/cli/commands/generate.py +545 -9
  26. mdb_engine/cli/commands/validate.py +3 -7
  27. mdb_engine/cli/utils.py +3 -3
  28. mdb_engine/config.py +7 -21
  29. mdb_engine/constants.py +65 -0
  30. mdb_engine/core/README.md +117 -6
  31. mdb_engine/core/__init__.py +39 -7
  32. mdb_engine/core/app_registration.py +22 -41
  33. mdb_engine/core/app_secrets.py +290 -0
  34. mdb_engine/core/connection.py +18 -9
  35. mdb_engine/core/encryption.py +223 -0
  36. mdb_engine/core/engine.py +758 -95
  37. mdb_engine/core/index_management.py +12 -16
  38. mdb_engine/core/manifest.py +424 -135
  39. mdb_engine/core/ray_integration.py +435 -0
  40. mdb_engine/core/seeding.py +10 -18
  41. mdb_engine/core/service_initialization.py +12 -23
  42. mdb_engine/core/types.py +2 -5
  43. mdb_engine/database/README.md +112 -16
  44. mdb_engine/database/__init__.py +17 -6
  45. mdb_engine/database/abstraction.py +25 -37
  46. mdb_engine/database/connection.py +11 -18
  47. mdb_engine/database/query_validator.py +367 -0
  48. mdb_engine/database/resource_limiter.py +204 -0
  49. mdb_engine/database/scoped_wrapper.py +713 -196
  50. mdb_engine/embeddings/__init__.py +17 -9
  51. mdb_engine/embeddings/dependencies.py +1 -3
  52. mdb_engine/embeddings/service.py +11 -25
  53. mdb_engine/exceptions.py +92 -0
  54. mdb_engine/indexes/README.md +30 -13
  55. mdb_engine/indexes/__init__.py +1 -0
  56. mdb_engine/indexes/helpers.py +1 -1
  57. mdb_engine/indexes/manager.py +50 -114
  58. mdb_engine/memory/README.md +2 -2
  59. mdb_engine/memory/__init__.py +1 -2
  60. mdb_engine/memory/service.py +30 -87
  61. mdb_engine/observability/README.md +4 -2
  62. mdb_engine/observability/__init__.py +26 -9
  63. mdb_engine/observability/health.py +8 -9
  64. mdb_engine/observability/metrics.py +32 -12
  65. mdb_engine/routing/README.md +1 -1
  66. mdb_engine/routing/__init__.py +1 -3
  67. mdb_engine/routing/websockets.py +25 -60
  68. mdb_engine-0.1.7.dist-info/METADATA +285 -0
  69. mdb_engine-0.1.7.dist-info/RECORD +85 -0
  70. mdb_engine-0.1.6.dist-info/METADATA +0 -213
  71. mdb_engine-0.1.6.dist-info/RECORD +0 -75
  72. {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/WHEEL +0 -0
  73. {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/entry_points.txt +0 -0
  74. {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/licenses/LICENSE +0 -0
  75. {mdb_engine-0.1.6.dist-info → mdb_engine-0.1.7.dist-info}/top_level.txt +0 -0
@@ -51,8 +51,7 @@ def get_secure_cookie_settings(
51
51
  # Auto-detect: secure if HTTPS or production environment
52
52
  is_https = request.url.scheme == "https"
53
53
  is_production = (
54
- os.getenv("G_NOME_ENV") == "production"
55
- or os.getenv("ENVIRONMENT") == "production"
54
+ os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
56
55
  )
57
56
  secure = is_https or is_production
58
57
  elif cookie_secure == "true":
@@ -63,8 +62,7 @@ def get_secure_cookie_settings(
63
62
  # No config - use environment-based defaults
64
63
  is_https = request.url.scheme == "https"
65
64
  is_production = (
66
- os.getenv("G_NOME_ENV") == "production"
67
- or os.getenv("ENVIRONMENT") == "production"
65
+ os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
68
66
  )
69
67
  secure = is_https or is_production
70
68
 
@@ -153,6 +151,4 @@ def clear_auth_cookies(response, request: Optional[Request] = None):
153
151
  response.delete_cookie(key="token", httponly=True, secure=secure, samesite=samesite)
154
152
 
155
153
  # Delete refresh token cookie
156
- response.delete_cookie(
157
- key="refresh_token", httponly=True, secure=secure, samesite=samesite
158
- )
154
+ response.delete_cookie(key="refresh_token", httponly=True, secure=secure, samesite=samesite)
@@ -0,0 +1,373 @@
1
+ """
2
+ CSRF Protection Middleware
3
+
4
+ Implements the Double-Submit Cookie pattern for Cross-Site Request Forgery protection.
5
+ Auto-enabled for shared auth mode, with manifest-configurable options.
6
+
7
+ This module is part of MDB_ENGINE - MongoDB Engine.
8
+
9
+ Security Features:
10
+ - Double-submit cookie pattern (industry standard)
11
+ - Cryptographically secure token generation
12
+ - Configurable exempt routes for APIs
13
+ - SameSite cookie attribute for additional protection
14
+ - Token rotation on each request (optional)
15
+
16
+ Usage:
17
+ # Auto-enabled for shared auth mode in engine.create_app()
18
+
19
+ # Or manual usage:
20
+ from mdb_engine.auth.csrf import CSRFMiddleware
21
+ app.add_middleware(CSRFMiddleware, exempt_routes=["/api/*"])
22
+
23
+ # In templates, include the token:
24
+ <input type="hidden" name="csrf_token" value="{{ csrf_token }}">
25
+
26
+ # Or in JavaScript:
27
+ fetch('/endpoint', {
28
+ headers: {'X-CSRF-Token': getCookie('csrf_token')}
29
+ })
30
+ """
31
+
32
+ import fnmatch
33
+ import hashlib
34
+ import hmac
35
+ import logging
36
+ import os
37
+ import secrets
38
+ import time
39
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, Set
40
+
41
+ from fastapi import Request, Response, status
42
+ from fastapi.responses import JSONResponse
43
+ from starlette.middleware.base import BaseHTTPMiddleware
44
+
45
logger = logging.getLogger(__name__)

# --- CSRF token configuration --------------------------------------------
# Passed to secrets.token_urlsafe(), which takes a byte count: 32 bytes = 256 bits.
CSRF_TOKEN_LENGTH = 32

# Where the token travels: cookie set by the server, header/form field echoed
# back by the client.
CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_FORM_FIELD = "csrf_token"

# Default token lifetime in seconds (one hour); also used as cookie max-age.
DEFAULT_TOKEN_TTL = 60 * 60

# State-changing HTTP methods that require a CSRF token.
UNSAFE_METHODS = {"DELETE", "PATCH", "POST", "PUT"}
56
+
57
+
58
def generate_csrf_token(secret: Optional[str] = None) -> str:
    """
    Create a new random CSRF token.

    The random part comes from ``secrets.token_urlsafe(CSRF_TOKEN_LENGTH)``
    (URL-safe base64 text). When ``secret`` is truthy, the token is extended to
    ``"<random>:<unix-timestamp>:<signature>"``. The signature is the first 16
    hex characters of an HMAC-SHA256 over ``"<random>:<unix-timestamp>"``, keyed
    with ``secret``. This lets a validator check both age and integrity.

    Args:
        secret: Optional HMAC key. ``None`` or ``""`` produce an unsigned token.

    Returns:
        The token string.
    """
    nonce = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
    if not secret:
        return nonce

    issued_at = str(int(time.time()))
    payload = ":".join((nonce, issued_at))
    mac = hmac.new(secret.encode(), payload.encode(), hashlib.sha256)
    return ":".join((payload, mac.hexdigest()[:16]))
78
+
79
+
80
def validate_csrf_token(
    token: str,
    secret: Optional[str] = None,
    max_age: int = DEFAULT_TOKEN_TTL,
) -> bool:
    """
    Validate a CSRF token.

    Without a secret, only a minimal sanity check is done: the token must be
    non-empty and at least ``CSRF_TOKEN_LENGTH`` characters long. The
    double-submit cookie/header comparison is what actually protects the
    request.

    With a secret, the token MUST be in the signed form
    ``"<random>:<unix-timestamp>:<signature>"``. It must be no older than
    ``max_age`` seconds, and its truncated HMAC-SHA256 signature must match.
    Unsigned tokens are rejected in this mode.

    Args:
        token: The token to validate.
        secret: Optional secret for HMAC verification.
        max_age: Maximum token age in seconds (signed mode only).

    Returns:
        True if valid, False otherwise. This function never raises for
        malformed input.
    """
    if not token:
        return False

    if not secret:
        # Unsigned mode: only reject values too short to be a generated token.
        return len(token) >= CSRF_TOKEN_LENGTH

    # Signed mode. Falling back to the unsigned length check here would let
    # anyone able to plant a cookie skip the signature check entirely.
    parts = token.split(":")
    if len(parts) != 3:
        logger.warning("CSRF token is not in the expected signed format")
        return False
    raw_token, timestamp_str, signature = parts

    try:
        timestamp = int(timestamp_str)
    except ValueError as e:
        logger.warning(f"CSRF token validation error: {e}")
        return False

    # Check age
    if time.time() - timestamp > max_age:
        logger.debug("CSRF token expired")
        return False

    # Verify signature (same truncation as used when the token is issued).
    message = f"{raw_token}:{timestamp_str}"
    expected_sig = hmac.new(secret.encode(), message.encode(), hashlib.sha256).hexdigest()[:16]

    # Compare bytes, not str: hmac.compare_digest raises TypeError for str
    # arguments containing non-ASCII characters, and `signature` is client data.
    if not hmac.compare_digest(signature.encode("utf-8"), expected_sig.encode("utf-8")):
        logger.warning("CSRF token signature mismatch")
        return False

    return True
131
+
132
+
133
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    CSRF Protection Middleware using the Double-Submit Cookie pattern.

    How the double-submit cookie pattern works here:
    1. On exempt ("safe") methods the middleware ensures the client has a CSRF
       token cookie. The cookie uses HttpOnly=False so JavaScript can read it.
       The same token is exposed to handlers/templates as
       ``request.state.csrf_token`` before the route runs.
    2. On all other methods the same token must be echoed in the
       ``header_name`` request header. It is compared to the cookie in
       constant time.
    3. Pages on other sites cannot read this site's cookie, so they cannot
       supply a matching header.

    The cookie is also set with SameSite=Lax, so browsers omit it from most
    cross-site subrequests. Top-level navigations still carry it.

    When a secret is configured, tokens are HMAC-signed and time-limited. A
    cookie that no longer validates is replaced on the next safe request.

    Note: ``form_field`` is stored but form-body extraction is not
    implemented; only the header is checked.
    """

    def __init__(
        self,
        app,
        secret: Optional[str] = None,
        exempt_routes: Optional[List[str]] = None,
        exempt_methods: Optional[Set[str]] = None,
        cookie_name: str = CSRF_COOKIE_NAME,
        header_name: str = CSRF_HEADER_NAME,
        form_field: str = CSRF_FORM_FIELD,
        token_ttl: int = DEFAULT_TOKEN_TTL,
        rotate_tokens: bool = False,
        secure_cookies: bool = True,
    ):
        """
        Initialize CSRF middleware.

        Args:
            app: ASGI application to wrap.
            secret: Secret for HMAC token signing (recommended for production).
                Falls back to the MDB_ENGINE_CSRF_SECRET environment variable.
            exempt_routes: Path patterns exempt from CSRF (fnmatch globs, e.g. /api/*).
            exempt_methods: HTTP methods exempt from CSRF (default: GET, HEAD,
                OPTIONS, TRACE). Case-insensitive.
            cookie_name: Name of the CSRF cookie.
            header_name: Name of the CSRF header.
            form_field: Name of the CSRF form field (currently unused).
            token_ttl: Token time-to-live in seconds (also the cookie max-age).
            rotate_tokens: Issue a fresh token after each validated unsafe request.
            secure_cookies: Allow the Secure cookie flag. The flag is applied
                when the request is HTTPS or the environment is production.
        """
        super().__init__(app)
        self.secret = secret or os.getenv("MDB_ENGINE_CSRF_SECRET")
        self.exempt_routes = exempt_routes or []
        # request.method is upper case, so normalize: {"get"} behaves like {"GET"}.
        self.exempt_methods = (
            {m.upper() for m in exempt_methods}
            if exempt_methods
            else {"GET", "HEAD", "OPTIONS", "TRACE"}
        )
        self.cookie_name = cookie_name
        self.header_name = header_name
        self.form_field = form_field
        self.token_ttl = token_ttl
        self.rotate_tokens = rotate_tokens
        self.secure_cookies = secure_cookies

        logger.info(
            f"CSRFMiddleware initialized (exempt_routes={self.exempt_routes}, "
            f"rotate_tokens={rotate_tokens})"
        )

    def _is_exempt(self, path: str) -> bool:
        """Return True if ``path`` matches any exempt route pattern."""
        return any(fnmatch.fnmatch(path, pattern) for pattern in self.exempt_routes)

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """
        Process a request through the CSRF middleware.

        Exempt paths pass straight through. Exempt methods receive or keep a
        token. All other methods must present a matching token or get a 403.
        """
        path = request.url.path
        method = request.method

        if self._is_exempt(path):
            return await call_next(request)

        if method in self.exempt_methods:
            return await self._handle_safe_request(request, call_next)

        rejection = self._check_submitted_token(request, method, path)
        if rejection is not None:
            return rejection

        response = await call_next(request)

        # Optionally rotate token
        if self.rotate_tokens:
            self._set_csrf_cookie(request, response, generate_csrf_token(self.secret))

        return response

    async def _handle_safe_request(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Expose a CSRF token to the handler and issue a cookie when needed."""
        existing = request.cookies.get(self.cookie_name)
        if existing and self._cookie_token_is_usable(existing):
            token, needs_cookie = existing, False
        else:
            # One token serves both the cookie and request.state. Generating
            # them separately would make form submissions fail the comparison.
            token, needs_cookie = generate_csrf_token(self.secret), True

        # Must be set BEFORE call_next so route handlers and the templates they
        # render can read it.
        request.state.csrf_token = token

        response = await call_next(request)
        if needs_cookie:
            self._set_csrf_cookie(request, response, token)
        return response

    def _cookie_token_is_usable(self, token: str) -> bool:
        """Without a secret any existing cookie is reused; with one it must still validate."""
        return not self.secret or validate_csrf_token(token, self.secret, self.token_ttl)

    @staticmethod
    def _forbidden(detail: str) -> JSONResponse:
        """Build the 403 response returned for every CSRF failure."""
        return JSONResponse(
            status_code=status.HTTP_403_FORBIDDEN,
            content={"detail": detail},
        )

    def _check_submitted_token(
        self,
        request: Request,
        method: str,
        path: str,
    ) -> Optional[JSONResponse]:
        """Validate the token for an unsafe request; return a 403 response on failure, else None."""
        cookie_token = request.cookies.get(self.cookie_name)
        if not cookie_token:
            logger.warning(f"CSRF cookie missing for {method} {path}")
            return self._forbidden("CSRF token missing")

        # Form-field extraction is not implemented; only the header is checked.
        # TODO: read self.form_field from the request body if needed.
        submitted_token = request.headers.get(self.header_name)
        if not submitted_token:
            logger.warning(f"CSRF token not submitted for {method} {path}")
            return self._forbidden("CSRF token not provided in header or form")

        # Constant-time comparison on bytes: hmac.compare_digest raises
        # TypeError for str arguments containing non-ASCII characters, and both
        # values are client-controlled.
        if not hmac.compare_digest(cookie_token.encode("utf-8"), submitted_token.encode("utf-8")):
            logger.warning(f"CSRF token mismatch for {method} {path}")
            return self._forbidden("CSRF token invalid")

        # Validate token (check signature and age if a secret is used)
        if self.secret and not validate_csrf_token(cookie_token, self.secret, self.token_ttl):
            logger.warning(f"CSRF token validation failed for {method} {path}")
            return self._forbidden("CSRF token expired or invalid")

        return None

    def _set_csrf_cookie(
        self,
        request: Request,
        response: Response,
        token: str,
    ) -> None:
        """Set the CSRF token cookie on ``response``."""
        is_https = request.url.scheme == "https"
        # Honor both environment variables used elsewhere in the package to
        # detect production (case-insensitive, as this method always was).
        is_production = any(
            os.getenv(var, "").lower() == "production" for var in ("G_NOME_ENV", "ENVIRONMENT")
        )

        response.set_cookie(
            key=self.cookie_name,
            value=token,
            httponly=False,  # Must be readable by JavaScript
            secure=self.secure_cookies and (is_https or is_production),
            samesite="lax",  # Provides CSRF protection + allows top-level navigation
            max_age=self.token_ttl,
            path="/",
        )
300
+
301
+
302
def create_csrf_middleware(
    manifest_auth: Dict[str, Any],
    secret: Optional[str] = None,
) -> type:
    """
    Build a CSRF middleware class from the manifest's auth section.

    ``manifest_auth["csrf_protection"]`` may be one of:

    * ``False``: a pass-through middleware class is returned.
    * ``True`` or absent: defaults are used, and ``public_routes`` are exempt.
    * a mapping: ``exempt_routes`` (defaults to ``public_routes``),
      ``rotate_tokens`` (default False) and ``token_ttl`` (default
      ``DEFAULT_TOKEN_TTL``) are read from it.

    Args:
        manifest_auth: Auth section from the manifest.
        secret: Optional CSRF secret. When falsy, MDB_ENGINE_CSRF_SECRET is
            read when the middleware is instantiated.

    Returns:
        A middleware class whose constructor takes only ``app``.
    """
    csrf_config = manifest_auth.get("csrf_protection", True)
    public_routes = manifest_auth.get("public_routes", [])

    if csrf_config is False:

        class NoOpMiddleware(BaseHTTPMiddleware):
            """Pass-through middleware used when CSRF protection is disabled."""

            async def dispatch(self, request, call_next):
                return await call_next(request)

        return NoOpMiddleware

    # `True` means "all defaults"; otherwise the config is expected to be a mapping.
    options = {} if csrf_config is True else csrf_config

    # NOTE(review): public routes are CSRF-exempt by default. Public endpoints can
    # still be state-changing (e.g. login forms); confirm this default is intended.
    exempt_routes = options.get("exempt_routes", public_routes)
    rotate_tokens = options.get("rotate_tokens", False)
    token_ttl = options.get("token_ttl", DEFAULT_TOKEN_TTL)

    class ConfiguredCSRFMiddleware(CSRFMiddleware):
        """CSRFMiddleware pre-bound to the manifest-derived settings."""

        def __init__(self, app):
            # Resolve the secret at instantiation time so the env var is read then.
            resolved_secret = secret or os.getenv("MDB_ENGINE_CSRF_SECRET")
            super().__init__(
                app,
                secret=resolved_secret,
                exempt_routes=exempt_routes,
                rotate_tokens=rotate_tokens,
                token_ttl=token_ttl,
            )

    return ConfiguredCSRFMiddleware
350
+
351
+
352
# FastAPI dependency: resolve the CSRF token for the current request.
def get_csrf_token(request: Request) -> str:
    """
    Return the CSRF token to embed in a template or response.

    Lookup order:
    1. ``request.state.csrf_token``, if the attribute exists (set by middleware).
    2. The ``CSRF_COOKIE_NAME`` cookie, if non-empty.
    3. A freshly generated token, signed when MDB_ENGINE_CSRF_SECRET is set.
       This function does not set it as a cookie.

    Usage in FastAPI route:
        @app.get("/form")
        def form_page(csrf_token: str = Depends(get_csrf_token)):
            return templates.TemplateResponse("form.html", {"csrf_token": csrf_token})
    """
    not_set = object()
    state_token = getattr(request.state, "csrf_token", not_set)
    if state_token is not not_set:
        return state_token

    cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
    return cookie_token or generate_csrf_token(os.getenv("MDB_ENGINE_CSRF_SECRET"))
@@ -70,10 +70,7 @@ def _is_production_environment() -> bool:
70
70
  """Check if running in production environment."""
71
71
  import os
72
72
 
73
- return (
74
- os.getenv("G_NOME_ENV") == "production"
75
- or os.getenv("ENVIRONMENT") == "production"
76
- )
73
+ return os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
77
74
 
78
75
 
79
76
  def _validate_https(request: Request) -> None:
@@ -189,17 +186,13 @@ def rate_limit_auth(
189
186
 
190
187
  # Use provided values or config values or defaults
191
188
  if max_attempts is None:
192
- max_attempts_val = (
193
- rate_limit_config.get("max_attempts") if rate_limit_config else 5
194
- )
189
+ max_attempts_val = rate_limit_config.get("max_attempts") if rate_limit_config else 5
195
190
  else:
196
191
  max_attempts_val = max_attempts
197
192
 
198
193
  if window_seconds is None:
199
194
  window_seconds_val = (
200
- rate_limit_config.get("window_seconds")
201
- if rate_limit_config
202
- else 300
195
+ rate_limit_config.get("window_seconds") if rate_limit_config else 300
203
196
  )
204
197
  else:
205
198
  window_seconds_val = window_seconds
@@ -14,9 +14,11 @@ from typing import Any, Dict, Mapping, Optional, Tuple
14
14
 
15
15
  import jwt
16
16
  from fastapi import Cookie, Depends, HTTPException, Request, status
17
+ from pymongo.errors import PyMongoError
17
18
 
18
19
  from ..exceptions import ConfigurationError
19
20
  from .jwt import decode_jwt_token, extract_token_metadata
21
+
20
22
  # Import from local modules
21
23
  from .provider import AuthorizationProvider
22
24
  from .session_manager import SessionManager
@@ -164,9 +166,7 @@ async def get_current_user(
164
166
  if blacklist:
165
167
  is_revoked = await blacklist.is_revoked(jti)
166
168
  if is_revoked:
167
- logger.info(
168
- f"get_current_user: Token {jti} is blacklisted (revoked)"
169
- )
169
+ logger.info(f"get_current_user: Token {jti} is blacklisted (revoked)")
170
170
  return None
171
171
 
172
172
  # Also check user-level revocation
@@ -174,9 +174,7 @@ async def get_current_user(
174
174
  if user_id:
175
175
  user_revoked = await blacklist.is_user_revoked(user_id)
176
176
  if user_revoked:
177
- logger.info(
178
- f"get_current_user: All tokens for user {user_id} are revoked"
179
- )
177
+ logger.info(f"get_current_user: All tokens for user {user_id} are revoked")
180
178
  return None
181
179
 
182
180
  payload = decode_jwt_token(token, str(SECRET_KEY))
@@ -184,9 +182,7 @@ async def get_current_user(
184
182
  # Verify token type (should be access token for backward compatibility, or no type)
185
183
  token_type = payload.get("type")
186
184
  if token_type and token_type not in ("access", None):
187
- logger.warning(
188
- f"get_current_user: Invalid token type '{token_type}' for access token"
189
- )
185
+ logger.warning(f"get_current_user: Invalid token type '{token_type}' for access token")
190
186
  return None
191
187
 
192
188
  logger.debug(
@@ -203,10 +199,12 @@ async def get_current_user(
203
199
  except (ValueError, TypeError):
204
200
  logger.exception("Validation error decoding JWT token")
205
201
  return None
206
- except Exception:
207
- logger.exception("Unexpected error decoding JWT token")
208
- # Re-raise unexpected errors for debugging
209
- raise
202
+ except PyMongoError:
203
+ logger.exception("Database error checking token blacklist")
204
+ return None
205
+ except (AttributeError, KeyError):
206
+ logger.exception("State access error in get_current_user")
207
+ return None
210
208
 
211
209
 
212
210
  async def get_current_user_from_request(request: Request) -> Optional[Dict[str, Any]]:
@@ -276,10 +274,12 @@ async def get_current_user_from_request(request: Request) -> Optional[Dict[str,
276
274
  except (ValueError, TypeError):
277
275
  logger.exception("Validation error decoding JWT token from request")
278
276
  return None
279
- except Exception:
280
- logger.exception("Unexpected error decoding JWT token from request")
281
- # Re-raise unexpected errors for debugging
282
- raise
277
+ except PyMongoError:
278
+ logger.exception("Database error checking token blacklist from request")
279
+ return None
280
+ except (AttributeError, KeyError):
281
+ logger.exception("State access error in get_current_user_from_request")
282
+ return None
283
283
 
284
284
 
285
285
  async def get_refresh_token(
@@ -314,9 +314,7 @@ async def get_refresh_token(
314
314
  if blacklist:
315
315
  is_revoked = await blacklist.is_revoked(jti)
316
316
  if is_revoked:
317
- logger.info(
318
- f"get_refresh_token: Refresh token {jti} is blacklisted"
319
- )
317
+ logger.info(f"get_refresh_token: Refresh token {jti} is blacklisted")
320
318
  return None
321
319
 
322
320
  payload = decode_jwt_token(refresh_token, str(SECRET_KEY))
@@ -350,13 +348,9 @@ async def get_refresh_token(
350
348
  if stored_fingerprint:
351
349
  from .utils import generate_session_fingerprint
352
350
 
353
- device_id = request.cookies.get("device_id") or payload.get(
354
- "device_id"
355
- )
351
+ device_id = request.cookies.get("device_id") or payload.get("device_id")
356
352
  if device_id:
357
- current_fingerprint = generate_session_fingerprint(
358
- request, device_id
359
- )
353
+ current_fingerprint = generate_session_fingerprint(request, device_id)
360
354
  if current_fingerprint != stored_fingerprint:
361
355
  logger.warning(
362
356
  f"get_refresh_token: Session fingerprint mismatch "
@@ -377,10 +371,12 @@ async def get_refresh_token(
377
371
  except (ValueError, TypeError):
378
372
  logger.exception("Validation error decoding refresh token")
379
373
  return None
380
- except Exception:
381
- logger.exception("Unexpected error decoding refresh token")
382
- # Re-raise unexpected errors for debugging
383
- raise
374
+ except PyMongoError:
375
+ logger.exception("Database error checking refresh token")
376
+ return None
377
+ except (AttributeError, KeyError):
378
+ logger.exception("State access error in get_refresh_token")
379
+ return None
384
380
 
385
381
 
386
382
  async def require_admin(
@@ -504,14 +500,14 @@ async def get_current_user_or_redirect(
504
500
  headers={"Location": redirect_url},
505
501
  detail="Not authenticated. Redirecting to login.",
506
502
  )
507
- except (ValueError, KeyError, AttributeError):
503
+ except (ValueError, KeyError, AttributeError) as e:
508
504
  logger.exception(
509
505
  f"Failed to generate login redirect URL for route '{login_route_name}'"
510
506
  )
511
507
  raise HTTPException(
512
508
  status_code=status.HTTP_401_UNAUTHORIZED,
513
509
  detail="Authentication required, but redirect failed.",
514
- )
510
+ ) from e
515
511
  return dict(user)
516
512
 
517
513
 
@@ -619,9 +615,7 @@ async def refresh_access_token(
619
615
  from ..config import TOKEN_ROTATION_ENABLED
620
616
  from .jwt import generate_token_pair
621
617
 
622
- user_id = refresh_token_payload.get("user_id") or refresh_token_payload.get(
623
- "email"
624
- )
618
+ user_id = refresh_token_payload.get("user_id") or refresh_token_payload.get("email")
625
619
  old_refresh_jti = refresh_token_payload.get("jti")
626
620
  device_id = refresh_token_payload.get("device_id")
627
621
 
@@ -653,9 +647,7 @@ async def refresh_access_token(
653
647
 
654
648
  device_id = device_id or request.cookies.get("device_id")
655
649
  if device_id:
656
- current_fingerprint = generate_session_fingerprint(
657
- request, device_id
658
- )
650
+ current_fingerprint = generate_session_fingerprint(request, device_id)
659
651
  if current_fingerprint != stored_fingerprint:
660
652
  logger.warning(
661
653
  f"refresh_access_token: Session fingerprint mismatch "
@@ -671,9 +663,7 @@ async def refresh_access_token(
671
663
 
672
664
  # Use existing device_id or generate new one
673
665
  if not device_id:
674
- device_id = (
675
- str(uuid.uuid4()) if not device_info else device_info.get("device_id")
676
- )
666
+ device_id = str(uuid.uuid4()) if not device_info else device_info.get("device_id")
677
667
 
678
668
  if device_info:
679
669
  device_info["device_id"] = device_id
@@ -741,7 +731,9 @@ async def refresh_access_token(
741
731
  except (ValueError, TypeError, jwt.InvalidTokenError):
742
732
  logger.exception("Validation error refreshing token")
743
733
  return None
744
- except Exception:
745
- logger.exception("Unexpected error refreshing token")
746
- # Re-raise unexpected errors for debugging
747
- raise
734
+ except PyMongoError:
735
+ logger.exception("Database error refreshing token")
736
+ return None
737
+ except (AttributeError, KeyError):
738
+ logger.exception("State access error refreshing token")
739
+ return None
@@ -19,7 +19,7 @@ async def initialize_token_management(app, db):
19
19
 
20
20
  Args:
21
21
  app: FastAPI application instance
22
- db: MongoDB database instance (Motor AsyncIOMotorDatabase)
22
+ db: Scoped MongoDB database instance (ScopedMongoWrapper)
23
23
 
24
24
  Example:
25
25
  from mdb_engine.auth.helpers import initialize_token_management
@@ -27,8 +27,8 @@ async def initialize_token_management(app, db):
27
27
 
28
28
  @app.on_event("startup")
29
29
  async def startup():
30
- # Get database from engine
31
- db = engine.get_database()
30
+ # Get scoped database from engine
31
+ db = engine.get_scoped_db("my_app")
32
32
 
33
33
  # Initialize token management
34
34
  await initialize_token_management(app, db)