mdb-engine 0.1.6__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. mdb_engine/__init__.py +104 -11
  2. mdb_engine/auth/ARCHITECTURE.md +112 -0
  3. mdb_engine/auth/README.md +648 -11
  4. mdb_engine/auth/__init__.py +136 -29
  5. mdb_engine/auth/audit.py +592 -0
  6. mdb_engine/auth/base.py +252 -0
  7. mdb_engine/auth/casbin_factory.py +264 -69
  8. mdb_engine/auth/config_helpers.py +7 -6
  9. mdb_engine/auth/cookie_utils.py +3 -7
  10. mdb_engine/auth/csrf.py +373 -0
  11. mdb_engine/auth/decorators.py +3 -10
  12. mdb_engine/auth/dependencies.py +47 -50
  13. mdb_engine/auth/helpers.py +3 -3
  14. mdb_engine/auth/integration.py +53 -80
  15. mdb_engine/auth/jwt.py +2 -6
  16. mdb_engine/auth/middleware.py +77 -34
  17. mdb_engine/auth/oso_factory.py +18 -38
  18. mdb_engine/auth/provider.py +270 -171
  19. mdb_engine/auth/rate_limiter.py +504 -0
  20. mdb_engine/auth/restrictions.py +8 -24
  21. mdb_engine/auth/session_manager.py +14 -29
  22. mdb_engine/auth/shared_middleware.py +600 -0
  23. mdb_engine/auth/shared_users.py +759 -0
  24. mdb_engine/auth/token_store.py +14 -28
  25. mdb_engine/auth/users.py +54 -113
  26. mdb_engine/auth/utils.py +213 -15
  27. mdb_engine/cli/commands/generate.py +545 -9
  28. mdb_engine/cli/commands/validate.py +3 -7
  29. mdb_engine/cli/utils.py +3 -3
  30. mdb_engine/config.py +7 -21
  31. mdb_engine/constants.py +65 -0
  32. mdb_engine/core/README.md +117 -6
  33. mdb_engine/core/__init__.py +39 -7
  34. mdb_engine/core/app_registration.py +22 -41
  35. mdb_engine/core/app_secrets.py +290 -0
  36. mdb_engine/core/connection.py +18 -9
  37. mdb_engine/core/encryption.py +223 -0
  38. mdb_engine/core/engine.py +1057 -93
  39. mdb_engine/core/index_management.py +12 -16
  40. mdb_engine/core/manifest.py +459 -150
  41. mdb_engine/core/ray_integration.py +435 -0
  42. mdb_engine/core/seeding.py +10 -18
  43. mdb_engine/core/service_initialization.py +12 -23
  44. mdb_engine/core/types.py +2 -5
  45. mdb_engine/database/README.md +140 -17
  46. mdb_engine/database/__init__.py +17 -6
  47. mdb_engine/database/abstraction.py +25 -37
  48. mdb_engine/database/connection.py +11 -18
  49. mdb_engine/database/query_validator.py +367 -0
  50. mdb_engine/database/resource_limiter.py +204 -0
  51. mdb_engine/database/scoped_wrapper.py +713 -196
  52. mdb_engine/dependencies.py +426 -0
  53. mdb_engine/di/__init__.py +34 -0
  54. mdb_engine/di/container.py +248 -0
  55. mdb_engine/di/providers.py +205 -0
  56. mdb_engine/di/scopes.py +139 -0
  57. mdb_engine/embeddings/README.md +54 -24
  58. mdb_engine/embeddings/__init__.py +31 -24
  59. mdb_engine/embeddings/dependencies.py +37 -154
  60. mdb_engine/embeddings/service.py +11 -25
  61. mdb_engine/exceptions.py +92 -0
  62. mdb_engine/indexes/README.md +30 -13
  63. mdb_engine/indexes/__init__.py +1 -0
  64. mdb_engine/indexes/helpers.py +1 -1
  65. mdb_engine/indexes/manager.py +50 -114
  66. mdb_engine/memory/README.md +2 -2
  67. mdb_engine/memory/__init__.py +1 -2
  68. mdb_engine/memory/service.py +30 -87
  69. mdb_engine/observability/README.md +4 -2
  70. mdb_engine/observability/__init__.py +26 -9
  71. mdb_engine/observability/health.py +8 -9
  72. mdb_engine/observability/metrics.py +32 -12
  73. mdb_engine/repositories/__init__.py +34 -0
  74. mdb_engine/repositories/base.py +325 -0
  75. mdb_engine/repositories/mongo.py +233 -0
  76. mdb_engine/repositories/unit_of_work.py +166 -0
  77. mdb_engine/routing/README.md +1 -1
  78. mdb_engine/routing/__init__.py +1 -3
  79. mdb_engine/routing/websockets.py +25 -60
  80. mdb_engine-0.2.0.dist-info/METADATA +313 -0
  81. mdb_engine-0.2.0.dist-info/RECORD +96 -0
  82. mdb_engine-0.1.6.dist-info/METADATA +0 -213
  83. mdb_engine-0.1.6.dist-info/RECORD +0 -75
  84. {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/WHEEL +0 -0
  85. {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/entry_points.txt +0 -0
  86. {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/licenses/LICENSE +0 -0
  87. {mdb_engine-0.1.6.dist-info → mdb_engine-0.2.0.dist-info}/top_level.txt +0 -0
@@ -12,9 +12,12 @@ from typing import Any, Dict
12
12
 
13
13
  from fastapi import Request
14
14
 
15
- from .config_defaults import (CORS_DEFAULTS, OBSERVABILITY_DEFAULTS,
16
- SECURITY_CONFIG_DEFAULTS,
17
- TOKEN_MANAGEMENT_DEFAULTS)
15
+ from .config_defaults import (
16
+ CORS_DEFAULTS,
17
+ OBSERVABILITY_DEFAULTS,
18
+ SECURITY_CONFIG_DEFAULTS,
19
+ TOKEN_MANAGEMENT_DEFAULTS,
20
+ )
18
21
 
19
22
  logger = logging.getLogger(__name__)
20
23
 
@@ -132,9 +135,7 @@ def get_ip_validation_config(request: Request) -> Dict[str, Any]:
132
135
  IP validation configuration dictionary
133
136
  """
134
137
  security_config = get_security_config(request)
135
- return security_config.get(
136
- "ip_validation", SECURITY_CONFIG_DEFAULTS["ip_validation"].copy()
137
- )
138
+ return security_config.get("ip_validation", SECURITY_CONFIG_DEFAULTS["ip_validation"].copy())
138
139
 
139
140
 
140
141
  def get_token_fingerprinting_config(request: Request) -> Dict[str, Any]:
@@ -51,8 +51,7 @@ def get_secure_cookie_settings(
51
51
  # Auto-detect: secure if HTTPS or production environment
52
52
  is_https = request.url.scheme == "https"
53
53
  is_production = (
54
- os.getenv("G_NOME_ENV") == "production"
55
- or os.getenv("ENVIRONMENT") == "production"
54
+ os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
56
55
  )
57
56
  secure = is_https or is_production
58
57
  elif cookie_secure == "true":
@@ -63,8 +62,7 @@ def get_secure_cookie_settings(
63
62
  # No config - use environment-based defaults
64
63
  is_https = request.url.scheme == "https"
65
64
  is_production = (
66
- os.getenv("G_NOME_ENV") == "production"
67
- or os.getenv("ENVIRONMENT") == "production"
65
+ os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
68
66
  )
69
67
  secure = is_https or is_production
70
68
 
@@ -153,6 +151,4 @@ def clear_auth_cookies(response, request: Optional[Request] = None):
153
151
  response.delete_cookie(key="token", httponly=True, secure=secure, samesite=samesite)
154
152
 
155
153
  # Delete refresh token cookie
156
- response.delete_cookie(
157
- key="refresh_token", httponly=True, secure=secure, samesite=samesite
158
- )
154
+ response.delete_cookie(key="refresh_token", httponly=True, secure=secure, samesite=samesite)
@@ -0,0 +1,373 @@
1
+ """
2
+ CSRF Protection Middleware
3
+
4
+ Implements the Double-Submit Cookie pattern for Cross-Site Request Forgery protection.
5
+ Auto-enabled for shared auth mode, with manifest-configurable options.
6
+
7
+ This module is part of MDB_ENGINE - MongoDB Engine.
8
+
9
+ Security Features:
10
+ - Double-submit cookie pattern (industry standard)
11
+ - Cryptographically secure token generation
12
+ - Configurable exempt routes for APIs
13
+ - SameSite cookie attribute for additional protection
14
+ - Token rotation on each request (optional)
15
+
16
+ Usage:
17
+ # Auto-enabled for shared auth mode in engine.create_app()
18
+
19
+ # Or manual usage:
20
+ from mdb_engine.auth.csrf import CSRFMiddleware
21
+ app.add_middleware(CSRFMiddleware, exempt_routes=["/api/*"])
22
+
23
+ # In templates, expose the token (note: the middleware currently validates only the X-CSRF-Token header, not this form field):
24
+ <input type="hidden" name="csrf_token" value="{{ csrf_token }}">
25
+
26
+ # Or in JavaScript:
27
+ fetch('/endpoint', {
28
+ headers: {'X-CSRF-Token': getCookie('csrf_token')}
29
+ })
30
+ """
31
+
32
+ import fnmatch
33
+ import hashlib
34
+ import hmac
35
+ import logging
36
+ import os
37
+ import secrets
38
+ import time
39
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, Set
40
+
41
+ from fastapi import Request, Response, status
42
+ from fastapi.responses import JSONResponse
43
+ from starlette.middleware.base import BaseHTTPMiddleware
44
+
45
logger = logging.getLogger(__name__)

# --- Token settings ----------------------------------------------------------
# Bytes of randomness handed to secrets.token_urlsafe() (32 bytes = 256 bits).
# validate_csrf_token also uses it as the minimum length of an unsigned token.
CSRF_TOKEN_LENGTH = 32

# Default cookie / header / form-field names (overridable per CSRFMiddleware).
CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_FORM_FIELD = "csrf_token"

# Signed-token lifetime and CSRF cookie max-age, in seconds.
DEFAULT_TOKEN_TTL = 60 * 60  # 1 hour

# Methods that require CSRF validation.
# NOTE(review): not referenced in this module -- CSRFMiddleware decides via its
# ``exempt_methods`` set. Presumably kept for external importers; verify before
# removing.
UNSAFE_METHODS = {"POST", "PUT", "DELETE", "PATCH"}
+
57
+
58
def generate_csrf_token(secret: Optional[str] = None) -> str:
    """
    Create a fresh, unpredictable CSRF token.

    Without a secret the token is a bare ``secrets.token_urlsafe`` string.
    With a secret it takes the form ``<random>:<unix-timestamp>:<sig>``, where
    ``sig`` is the first 16 hex characters of
    HMAC-SHA256(secret, "<random>:<unix-timestamp>"). This lets
    ``validate_csrf_token`` detect tampering and enforce a maximum age.

    Args:
        secret: Optional HMAC key; when falsy, no timestamp/signature is added.

    Returns:
        The token string (URL-safe base64, optionally followed by the
        timestamp and truncated signature).
    """
    nonce = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
    if not secret:
        return nonce

    issued_at = str(int(time.time()))
    payload = ":".join((nonce, issued_at))
    digest = hmac.new(secret.encode(), payload.encode(), hashlib.sha256).hexdigest()
    return ":".join((payload, digest[:16]))
78
+
79
+
80
def validate_csrf_token(
    token: str,
    secret: Optional[str] = None,
    max_age: int = DEFAULT_TOKEN_TTL,
) -> bool:
    """
    Validate a CSRF token.

    When ``secret`` is given, the token MUST be in the signed
    ``<raw>:<timestamp>:<signature>`` format produced by
    ``generate_csrf_token``. Unsigned tokens are rejected so that a configured
    secret cannot be bypassed by presenting an arbitrary random string.
    Without a secret, only a minimum-length sanity check is applied.

    Args:
        token: The token to validate
        secret: Optional secret for HMAC verification
        max_age: Maximum token age in seconds (signed tokens only)

    Returns:
        True if valid, False otherwise. Malformed input never raises.
    """
    if not token:
        return False

    if not secret:
        # Unsigned mode: nothing to verify cryptographically; just reject
        # obviously truncated values.
        return len(token) >= CSRF_TOKEN_LENGTH

    parts = token.split(":")
    if len(parts) != 3:
        # With a secret configured, anything not in signed form is invalid.
        logger.debug("CSRF token is not in signed format")
        return False

    raw_token, timestamp_str, signature = parts
    try:
        timestamp = int(timestamp_str)
    except ValueError as e:
        logger.warning(f"CSRF token validation error: {e}")
        return False

    # Check age
    if time.time() - timestamp > max_age:
        logger.debug("CSRF token expired")
        return False

    # Verify signature
    message = f"{raw_token}:{timestamp_str}"
    expected_sig = hmac.new(secret.encode(), message.encode(), hashlib.sha256).hexdigest()[:16]

    # Compare as bytes: hmac.compare_digest raises TypeError for str inputs
    # containing non-ASCII characters, and the cookie value is client-controlled.
    if not hmac.compare_digest(signature.encode(), expected_sig.encode()):
        logger.warning("CSRF token signature mismatch")
        return False

    return True
131
+
132
+
133
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    CSRF Protection Middleware using Double-Submit Cookie pattern.

    The double-submit cookie pattern works by:
    1. Setting a CSRF token in a cookie (with HttpOnly=False so JS can read it)
    2. Requiring the same token in a request header on state-changing requests
    3. Since attackers can't read cookies from other origins, they can't forge
       a matching header value

    The cookie is issued with SameSite=Lax, so browsers also withhold it on
    cross-site subresource and form POST requests (top-level GET navigations
    still carry it).

    When a ``secret`` is configured, tokens are HMAC-signed and time-limited
    (see ``validate_csrf_token``). A stale or tampered cookie is replaced on
    the next safe-method request.
    """

    def __init__(
        self,
        app,
        secret: Optional[str] = None,
        exempt_routes: Optional[List[str]] = None,
        exempt_methods: Optional[Set[str]] = None,
        cookie_name: str = CSRF_COOKIE_NAME,
        header_name: str = CSRF_HEADER_NAME,
        form_field: str = CSRF_FORM_FIELD,
        token_ttl: int = DEFAULT_TOKEN_TTL,
        rotate_tokens: bool = False,
        secure_cookies: bool = True,
    ):
        """
        Initialize CSRF middleware.

        Args:
            app: ASGI application being wrapped
            secret: Secret for HMAC token signing (recommended for production);
                falls back to the MDB_ENGINE_CSRF_SECRET environment variable
            exempt_routes: fnmatch-style path patterns exempt from CSRF (e.g. "/api/*")
            exempt_methods: HTTP methods that skip validation
                (default: GET, HEAD, OPTIONS, TRACE)
            cookie_name: Name of the CSRF cookie
            header_name: Name of the request header carrying the token
            form_field: Name of the CSRF form field (stored; form submission
                is not validated yet)
            token_ttl: Signed-token lifetime and cookie max-age, in seconds
            rotate_tokens: Issue a new token after every validated unsafe request
            secure_cookies: Allow the Secure cookie flag; it is applied only
                over HTTPS or in a production environment
        """
        super().__init__(app)
        self.secret = secret or os.getenv("MDB_ENGINE_CSRF_SECRET")
        self.exempt_routes = exempt_routes or []
        self.exempt_methods = exempt_methods or {"GET", "HEAD", "OPTIONS", "TRACE"}
        self.cookie_name = cookie_name
        self.header_name = header_name
        self.form_field = form_field
        self.token_ttl = token_ttl
        self.rotate_tokens = rotate_tokens
        self.secure_cookies = secure_cookies

        logger.info(
            f"CSRFMiddleware initialized (exempt_routes={self.exempt_routes}, "
            f"rotate_tokens={rotate_tokens})"
        )

    def _is_exempt(self, path: str) -> bool:
        """Check if a path matches any exempt fnmatch pattern."""
        return any(fnmatch.fnmatch(path, pattern) for pattern in self.exempt_routes)

    @staticmethod
    def _forbidden(detail: str) -> JSONResponse:
        """Build the 403 JSON response used for every CSRF rejection."""
        return JSONResponse(
            status_code=status.HTTP_403_FORBIDDEN,
            content={"detail": detail},
        )

    def _cookie_token_is_reusable(self, token: Optional[str]) -> bool:
        """Return True if an existing cookie token can keep being served."""
        if not token:
            return False
        if self.secret:
            # Expired or tampered signed tokens would be rejected on the next
            # unsafe request, so they must be replaced now.
            return validate_csrf_token(token, self.secret, self.token_ttl)
        return True

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """
        Process request through CSRF middleware.

        Requests are handled as follows:

        - Exempt paths pass straight through.
        - Safe (exempt) methods get a token exposed on
          ``request.state.csrf_token``. A cookie is issued if none is present
          or the existing one is no longer valid.
        - Any other method must echo the cookie value in ``self.header_name``,
          or it receives a 403 JSON response.
        """
        if self._is_exempt(request.url.path):
            return await call_next(request)

        if request.method in self.exempt_methods:
            return await self._handle_safe_request(request, call_next)

        return await self._handle_unsafe_request(request, call_next)

    async def _handle_safe_request(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Expose a usable token to the route and (re)issue the cookie if needed."""
        existing = request.cookies.get(self.cookie_name)
        if self._cookie_token_is_reusable(existing):
            token, issue_cookie = existing, False
        else:
            token, issue_cookie = generate_csrf_token(self.secret), True

        # Set BEFORE call_next so route handlers, dependencies and templates
        # see exactly the value that ends up in the cookie.
        request.state.csrf_token = token
        response = await call_next(request)

        if issue_cookie:
            self._set_csrf_cookie(request, response, token)
        return response

    async def _handle_unsafe_request(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Validate the double-submitted token, then forward the request."""
        path = request.url.path
        method = request.method

        cookie_token = request.cookies.get(self.cookie_name)
        if not cookie_token:
            logger.warning(f"CSRF cookie missing for {method} {path}")
            return self._forbidden("CSRF token missing")

        # Note: Form-based CSRF token extraction (self.form_field) is not
        # implemented; only the header is checked.
        # TODO: Implement request.form() based extraction if needed.
        submitted_token = request.headers.get(self.header_name)
        if not submitted_token:
            logger.warning(f"CSRF token not submitted for {method} {path}")
            return self._forbidden("CSRF token not provided in header or form")

        # Constant-time comparison on bytes: hmac.compare_digest raises
        # TypeError for str arguments containing non-ASCII characters, and
        # both values here are client-controlled.
        if not hmac.compare_digest(cookie_token.encode(), submitted_token.encode()):
            logger.warning(f"CSRF token mismatch for {method} {path}")
            return self._forbidden("CSRF token invalid")

        # Validate token (check signature/age if secret is used)
        if self.secret and not validate_csrf_token(cookie_token, self.secret, self.token_ttl):
            logger.warning(f"CSRF token validation failed for {method} {path}")
            return self._forbidden("CSRF token expired or invalid")

        # When rotating, generate the replacement up front so anything rendered
        # by the route embeds the token the browser will hold afterwards.
        new_token = generate_csrf_token(self.secret) if self.rotate_tokens else None
        request.state.csrf_token = new_token or cookie_token

        response = await call_next(request)

        if new_token:
            self._set_csrf_cookie(request, response, new_token)
        return response

    def _set_csrf_cookie(
        self,
        request: Request,
        response: Response,
        token: str,
    ) -> None:
        """Set the CSRF token cookie (JS-readable, SameSite=Lax, path=/)."""
        is_https = request.url.scheme == "https"
        # Same environment variables the auth-cookie helpers consult
        # (G_NOME_ENV / ENVIRONMENT), compared case-insensitively.
        is_production = any(
            os.getenv(var, "").lower() == "production" for var in ("ENVIRONMENT", "G_NOME_ENV")
        )

        response.set_cookie(
            key=self.cookie_name,
            value=token,
            httponly=False,  # Must be readable by JavaScript
            secure=self.secure_cookies and (is_https or is_production),
            samesite="lax",  # Provides CSRF protection + allows top-level navigation
            max_age=self.token_ttl,
            path="/",
        )
300
+
301
+
302
def create_csrf_middleware(
    manifest_auth: Dict[str, Any],
    secret: Optional[str] = None,
) -> type:
    """
    Build a CSRF middleware class from the manifest's auth section.

    The ``csrf_protection`` key selects the behaviour:

    - missing or ``True``: CSRF enabled, with the manifest's ``public_routes``
      as exempt routes.
    - ``False``: a pass-through middleware that performs no checks.
    - a mapping: may set ``exempt_routes`` (default: ``public_routes``),
      ``rotate_tokens`` (default ``False``) and ``token_ttl``
      (default ``DEFAULT_TOKEN_TTL``).

    Args:
        manifest_auth: Auth section from manifest
        secret: Optional CSRF secret; when falsy, the
            MDB_ENGINE_CSRF_SECRET environment variable is read when the
            middleware is instantiated

    Returns:
        A middleware class whose constructor takes only ``app``.
    """
    csrf_config = manifest_auth.get("csrf_protection", True)

    if csrf_config is False:
        # CSRF explicitly disabled: return a pass-through middleware.
        class NoOpMiddleware(BaseHTTPMiddleware):
            async def dispatch(self, request, call_next):
                return await call_next(request)

        return NoOpMiddleware

    if csrf_config is True:
        middleware_options = {
            "exempt_routes": manifest_auth.get("public_routes", []),
            "rotate_tokens": False,
            "token_ttl": DEFAULT_TOKEN_TTL,
        }
    else:
        # Object configuration; individual keys fall back to the defaults.
        middleware_options = {
            "exempt_routes": csrf_config.get(
                "exempt_routes", manifest_auth.get("public_routes", [])
            ),
            "rotate_tokens": csrf_config.get("rotate_tokens", False),
            "token_ttl": csrf_config.get("token_ttl", DEFAULT_TOKEN_TTL),
        }

    class ConfiguredCSRFMiddleware(CSRFMiddleware):
        def __init__(self, app):
            super().__init__(
                app,
                secret=secret or os.getenv("MDB_ENGINE_CSRF_SECRET"),
                **middleware_options,
            )

    return ConfiguredCSRFMiddleware
350
+
351
+
352
# Dependency for getting CSRF token in routes
def get_csrf_token(request: Request) -> str:
    """
    Return the CSRF token to embed in a rendered page.

    Sources are tried in this order:

    1. ``request.state.csrf_token`` (populated by the CSRF middleware).
    2. The ``CSRF_COOKIE_NAME`` cookie.
    3. A newly generated token, signed with MDB_ENGINE_CSRF_SECRET when that
       environment variable is set.

    A newly generated token is not written to any cookie here.

    Usage in FastAPI route:
        @app.get("/form")
        def form_page(csrf_token: str = Depends(get_csrf_token)):
            return templates.TemplateResponse("form.html", {"csrf_token": csrf_token})
    """
    unset = object()
    state_token = getattr(request.state, "csrf_token", unset)
    if state_token is not unset:
        return state_token

    cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
    if cookie_token:
        return cookie_token

    return generate_csrf_token(os.getenv("MDB_ENGINE_CSRF_SECRET"))
@@ -70,10 +70,7 @@ def _is_production_environment() -> bool:
70
70
  """Check if running in production environment."""
71
71
  import os
72
72
 
73
- return (
74
- os.getenv("G_NOME_ENV") == "production"
75
- or os.getenv("ENVIRONMENT") == "production"
76
- )
73
+ return os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
77
74
 
78
75
 
79
76
  def _validate_https(request: Request) -> None:
@@ -189,17 +186,13 @@ def rate_limit_auth(
189
186
 
190
187
  # Use provided values or config values or defaults
191
188
  if max_attempts is None:
192
- max_attempts_val = (
193
- rate_limit_config.get("max_attempts") if rate_limit_config else 5
194
- )
189
+ max_attempts_val = rate_limit_config.get("max_attempts") if rate_limit_config else 5
195
190
  else:
196
191
  max_attempts_val = max_attempts
197
192
 
198
193
  if window_seconds is None:
199
194
  window_seconds_val = (
200
- rate_limit_config.get("window_seconds")
201
- if rate_limit_config
202
- else 300
195
+ rate_limit_config.get("window_seconds") if rate_limit_config else 300
203
196
  )
204
197
  else:
205
198
  window_seconds_val = window_seconds
@@ -14,9 +14,11 @@ from typing import Any, Dict, Mapping, Optional, Tuple
14
14
 
15
15
  import jwt
16
16
  from fastapi import Cookie, Depends, HTTPException, Request, status
17
+ from pymongo.errors import PyMongoError
17
18
 
18
19
  from ..exceptions import ConfigurationError
19
20
  from .jwt import decode_jwt_token, extract_token_metadata
21
+
20
22
  # Import from local modules
21
23
  from .provider import AuthorizationProvider
22
24
  from .session_manager import SessionManager
@@ -39,15 +41,20 @@ def _get_secret_key() -> str:
39
41
  if _SECRET_KEY_CACHE is not None:
40
42
  return _SECRET_KEY_CACHE
41
43
 
42
- secret_key = os.environ.get("FLASK_SECRET_KEY") or os.environ.get("SECRET_KEY")
44
+ secret_key = (
45
+ os.environ.get("FLASK_SECRET_KEY")
46
+ or os.environ.get("SECRET_KEY")
47
+ or os.environ.get("APP_SECRET_KEY")
48
+ )
43
49
 
44
50
  if not secret_key:
45
51
  raise ConfigurationError(
46
- "FLASK_SECRET_KEY environment variable is required for JWT token security. "
47
- "Set a strong secret key (minimum 32 characters, cryptographically random). "
48
- "Example: export FLASK_SECRET_KEY=$(python -c "
52
+ "SECRET_KEY environment variable is required for JWT token security. "
53
+ "Set FLASK_SECRET_KEY, SECRET_KEY, or APP_SECRET_KEY with a strong secret key "
54
+ "(minimum 32 characters, cryptographically random). "
55
+ "Example: export SECRET_KEY=$(python -c "
49
56
  "'import secrets; print(secrets.token_urlsafe(32))')",
50
- config_key="FLASK_SECRET_KEY",
57
+ config_key="SECRET_KEY",
51
58
  )
52
59
 
53
60
  if len(secret_key) < 32:
@@ -164,9 +171,7 @@ async def get_current_user(
164
171
  if blacklist:
165
172
  is_revoked = await blacklist.is_revoked(jti)
166
173
  if is_revoked:
167
- logger.info(
168
- f"get_current_user: Token {jti} is blacklisted (revoked)"
169
- )
174
+ logger.info(f"get_current_user: Token {jti} is blacklisted (revoked)")
170
175
  return None
171
176
 
172
177
  # Also check user-level revocation
@@ -174,9 +179,7 @@ async def get_current_user(
174
179
  if user_id:
175
180
  user_revoked = await blacklist.is_user_revoked(user_id)
176
181
  if user_revoked:
177
- logger.info(
178
- f"get_current_user: All tokens for user {user_id} are revoked"
179
- )
182
+ logger.info(f"get_current_user: All tokens for user {user_id} are revoked")
180
183
  return None
181
184
 
182
185
  payload = decode_jwt_token(token, str(SECRET_KEY))
@@ -184,9 +187,7 @@ async def get_current_user(
184
187
  # Verify token type (should be access token for backward compatibility, or no type)
185
188
  token_type = payload.get("type")
186
189
  if token_type and token_type not in ("access", None):
187
- logger.warning(
188
- f"get_current_user: Invalid token type '{token_type}' for access token"
189
- )
190
+ logger.warning(f"get_current_user: Invalid token type '{token_type}' for access token")
190
191
  return None
191
192
 
192
193
  logger.debug(
@@ -203,10 +204,12 @@ async def get_current_user(
203
204
  except (ValueError, TypeError):
204
205
  logger.exception("Validation error decoding JWT token")
205
206
  return None
206
- except Exception:
207
- logger.exception("Unexpected error decoding JWT token")
208
- # Re-raise unexpected errors for debugging
209
- raise
207
+ except PyMongoError:
208
+ logger.exception("Database error checking token blacklist")
209
+ return None
210
+ except (AttributeError, KeyError):
211
+ logger.exception("State access error in get_current_user")
212
+ return None
210
213
 
211
214
 
212
215
  async def get_current_user_from_request(request: Request) -> Optional[Dict[str, Any]]:
@@ -276,10 +279,12 @@ async def get_current_user_from_request(request: Request) -> Optional[Dict[str,
276
279
  except (ValueError, TypeError):
277
280
  logger.exception("Validation error decoding JWT token from request")
278
281
  return None
279
- except Exception:
280
- logger.exception("Unexpected error decoding JWT token from request")
281
- # Re-raise unexpected errors for debugging
282
- raise
282
+ except PyMongoError:
283
+ logger.exception("Database error checking token blacklist from request")
284
+ return None
285
+ except (AttributeError, KeyError):
286
+ logger.exception("State access error in get_current_user_from_request")
287
+ return None
283
288
 
284
289
 
285
290
  async def get_refresh_token(
@@ -314,9 +319,7 @@ async def get_refresh_token(
314
319
  if blacklist:
315
320
  is_revoked = await blacklist.is_revoked(jti)
316
321
  if is_revoked:
317
- logger.info(
318
- f"get_refresh_token: Refresh token {jti} is blacklisted"
319
- )
322
+ logger.info(f"get_refresh_token: Refresh token {jti} is blacklisted")
320
323
  return None
321
324
 
322
325
  payload = decode_jwt_token(refresh_token, str(SECRET_KEY))
@@ -350,13 +353,9 @@ async def get_refresh_token(
350
353
  if stored_fingerprint:
351
354
  from .utils import generate_session_fingerprint
352
355
 
353
- device_id = request.cookies.get("device_id") or payload.get(
354
- "device_id"
355
- )
356
+ device_id = request.cookies.get("device_id") or payload.get("device_id")
356
357
  if device_id:
357
- current_fingerprint = generate_session_fingerprint(
358
- request, device_id
359
- )
358
+ current_fingerprint = generate_session_fingerprint(request, device_id)
360
359
  if current_fingerprint != stored_fingerprint:
361
360
  logger.warning(
362
361
  f"get_refresh_token: Session fingerprint mismatch "
@@ -377,10 +376,12 @@ async def get_refresh_token(
377
376
  except (ValueError, TypeError):
378
377
  logger.exception("Validation error decoding refresh token")
379
378
  return None
380
- except Exception:
381
- logger.exception("Unexpected error decoding refresh token")
382
- # Re-raise unexpected errors for debugging
383
- raise
379
+ except PyMongoError:
380
+ logger.exception("Database error checking refresh token")
381
+ return None
382
+ except (AttributeError, KeyError):
383
+ logger.exception("State access error in get_refresh_token")
384
+ return None
384
385
 
385
386
 
386
387
  async def require_admin(
@@ -504,14 +505,14 @@ async def get_current_user_or_redirect(
504
505
  headers={"Location": redirect_url},
505
506
  detail="Not authenticated. Redirecting to login.",
506
507
  )
507
- except (ValueError, KeyError, AttributeError):
508
+ except (ValueError, KeyError, AttributeError) as e:
508
509
  logger.exception(
509
510
  f"Failed to generate login redirect URL for route '{login_route_name}'"
510
511
  )
511
512
  raise HTTPException(
512
513
  status_code=status.HTTP_401_UNAUTHORIZED,
513
514
  detail="Authentication required, but redirect failed.",
514
- )
515
+ ) from e
515
516
  return dict(user)
516
517
 
517
518
 
@@ -619,9 +620,7 @@ async def refresh_access_token(
619
620
  from ..config import TOKEN_ROTATION_ENABLED
620
621
  from .jwt import generate_token_pair
621
622
 
622
- user_id = refresh_token_payload.get("user_id") or refresh_token_payload.get(
623
- "email"
624
- )
623
+ user_id = refresh_token_payload.get("user_id") or refresh_token_payload.get("email")
625
624
  old_refresh_jti = refresh_token_payload.get("jti")
626
625
  device_id = refresh_token_payload.get("device_id")
627
626
 
@@ -653,9 +652,7 @@ async def refresh_access_token(
653
652
 
654
653
  device_id = device_id or request.cookies.get("device_id")
655
654
  if device_id:
656
- current_fingerprint = generate_session_fingerprint(
657
- request, device_id
658
- )
655
+ current_fingerprint = generate_session_fingerprint(request, device_id)
659
656
  if current_fingerprint != stored_fingerprint:
660
657
  logger.warning(
661
658
  f"refresh_access_token: Session fingerprint mismatch "
@@ -671,9 +668,7 @@ async def refresh_access_token(
671
668
 
672
669
  # Use existing device_id or generate new one
673
670
  if not device_id:
674
- device_id = (
675
- str(uuid.uuid4()) if not device_info else device_info.get("device_id")
676
- )
671
+ device_id = str(uuid.uuid4()) if not device_info else device_info.get("device_id")
677
672
 
678
673
  if device_info:
679
674
  device_info["device_id"] = device_id
@@ -741,7 +736,9 @@ async def refresh_access_token(
741
736
  except (ValueError, TypeError, jwt.InvalidTokenError):
742
737
  logger.exception("Validation error refreshing token")
743
738
  return None
744
- except Exception:
745
- logger.exception("Unexpected error refreshing token")
746
- # Re-raise unexpected errors for debugging
747
- raise
739
+ except PyMongoError:
740
+ logger.exception("Database error refreshing token")
741
+ return None
742
+ except (AttributeError, KeyError):
743
+ logger.exception("State access error refreshing token")
744
+ return None
@@ -19,7 +19,7 @@ async def initialize_token_management(app, db):
19
19
 
20
20
  Args:
21
21
  app: FastAPI application instance
22
- db: MongoDB database instance (Motor AsyncIOMotorDatabase)
22
+ db: Scoped MongoDB database instance (ScopedMongoWrapper)
23
23
 
24
24
  Example:
25
25
  from mdb_engine.auth.helpers import initialize_token_management
@@ -27,8 +27,8 @@ async def initialize_token_management(app, db):
27
27
 
28
28
  @app.on_event("startup")
29
29
  async def startup():
30
- # Get database from engine
31
- db = engine.get_database()
30
+ # Get scoped database from engine
31
+ db = engine.get_scoped_db("my_app")
32
32
 
33
33
  # Initialize token management
34
34
  await initialize_token_management(app, db)