mdb-engine 0.1.6__py3-none-any.whl → 0.4.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. mdb_engine/__init__.py +116 -11
  2. mdb_engine/auth/ARCHITECTURE.md +112 -0
  3. mdb_engine/auth/README.md +654 -11
  4. mdb_engine/auth/__init__.py +136 -29
  5. mdb_engine/auth/audit.py +592 -0
  6. mdb_engine/auth/base.py +252 -0
  7. mdb_engine/auth/casbin_factory.py +265 -70
  8. mdb_engine/auth/config_defaults.py +5 -5
  9. mdb_engine/auth/config_helpers.py +19 -18
  10. mdb_engine/auth/cookie_utils.py +12 -16
  11. mdb_engine/auth/csrf.py +483 -0
  12. mdb_engine/auth/decorators.py +10 -16
  13. mdb_engine/auth/dependencies.py +69 -71
  14. mdb_engine/auth/helpers.py +3 -3
  15. mdb_engine/auth/integration.py +61 -88
  16. mdb_engine/auth/jwt.py +11 -15
  17. mdb_engine/auth/middleware.py +79 -35
  18. mdb_engine/auth/oso_factory.py +21 -41
  19. mdb_engine/auth/provider.py +270 -171
  20. mdb_engine/auth/rate_limiter.py +505 -0
  21. mdb_engine/auth/restrictions.py +21 -36
  22. mdb_engine/auth/session_manager.py +24 -41
  23. mdb_engine/auth/shared_middleware.py +977 -0
  24. mdb_engine/auth/shared_users.py +775 -0
  25. mdb_engine/auth/token_lifecycle.py +10 -12
  26. mdb_engine/auth/token_store.py +17 -32
  27. mdb_engine/auth/users.py +99 -159
  28. mdb_engine/auth/utils.py +236 -42
  29. mdb_engine/cli/commands/generate.py +546 -10
  30. mdb_engine/cli/commands/validate.py +3 -7
  31. mdb_engine/cli/utils.py +7 -7
  32. mdb_engine/config.py +13 -28
  33. mdb_engine/constants.py +65 -0
  34. mdb_engine/core/README.md +117 -6
  35. mdb_engine/core/__init__.py +39 -7
  36. mdb_engine/core/app_registration.py +31 -50
  37. mdb_engine/core/app_secrets.py +289 -0
  38. mdb_engine/core/connection.py +20 -12
  39. mdb_engine/core/encryption.py +222 -0
  40. mdb_engine/core/engine.py +2862 -115
  41. mdb_engine/core/index_management.py +12 -16
  42. mdb_engine/core/manifest.py +628 -204
  43. mdb_engine/core/ray_integration.py +436 -0
  44. mdb_engine/core/seeding.py +13 -21
  45. mdb_engine/core/service_initialization.py +20 -30
  46. mdb_engine/core/types.py +40 -43
  47. mdb_engine/database/README.md +140 -17
  48. mdb_engine/database/__init__.py +17 -6
  49. mdb_engine/database/abstraction.py +37 -50
  50. mdb_engine/database/connection.py +51 -30
  51. mdb_engine/database/query_validator.py +367 -0
  52. mdb_engine/database/resource_limiter.py +204 -0
  53. mdb_engine/database/scoped_wrapper.py +747 -237
  54. mdb_engine/dependencies.py +427 -0
  55. mdb_engine/di/__init__.py +34 -0
  56. mdb_engine/di/container.py +247 -0
  57. mdb_engine/di/providers.py +206 -0
  58. mdb_engine/di/scopes.py +139 -0
  59. mdb_engine/embeddings/README.md +54 -24
  60. mdb_engine/embeddings/__init__.py +31 -24
  61. mdb_engine/embeddings/dependencies.py +38 -155
  62. mdb_engine/embeddings/service.py +78 -75
  63. mdb_engine/exceptions.py +104 -12
  64. mdb_engine/indexes/README.md +30 -13
  65. mdb_engine/indexes/__init__.py +1 -0
  66. mdb_engine/indexes/helpers.py +11 -11
  67. mdb_engine/indexes/manager.py +59 -123
  68. mdb_engine/memory/README.md +95 -4
  69. mdb_engine/memory/__init__.py +1 -2
  70. mdb_engine/memory/service.py +363 -1168
  71. mdb_engine/observability/README.md +4 -2
  72. mdb_engine/observability/__init__.py +26 -9
  73. mdb_engine/observability/health.py +17 -17
  74. mdb_engine/observability/logging.py +10 -10
  75. mdb_engine/observability/metrics.py +40 -19
  76. mdb_engine/repositories/__init__.py +34 -0
  77. mdb_engine/repositories/base.py +325 -0
  78. mdb_engine/repositories/mongo.py +233 -0
  79. mdb_engine/repositories/unit_of_work.py +166 -0
  80. mdb_engine/routing/README.md +1 -1
  81. mdb_engine/routing/__init__.py +1 -3
  82. mdb_engine/routing/websockets.py +41 -75
  83. mdb_engine/utils/__init__.py +3 -1
  84. mdb_engine/utils/mongo.py +117 -0
  85. mdb_engine-0.4.12.dist-info/METADATA +492 -0
  86. mdb_engine-0.4.12.dist-info/RECORD +97 -0
  87. {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/WHEEL +1 -1
  88. mdb_engine-0.1.6.dist-info/METADATA +0 -213
  89. mdb_engine-0.1.6.dist-info/RECORD +0 -75
  90. {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/entry_points.txt +0 -0
  91. {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/licenses/LICENSE +0 -0
  92. {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,483 @@
1
+ """
2
+ CSRF Protection Middleware
3
+
4
+ Implements the Double-Submit Cookie pattern for Cross-Site Request Forgery protection.
5
+ Auto-enabled for shared auth mode, with manifest-configurable options.
6
+
7
+ This module is part of MDB_ENGINE - MongoDB Engine.
8
+
9
+ Security Features:
10
+ - Double-submit cookie pattern (industry standard)
11
+ - Cryptographically secure token generation
12
+ - Configurable exempt routes for APIs
13
+ - SameSite cookie attribute for additional protection
14
+ - Token rotation on each request (optional)
15
+
16
+ Usage:
17
+ # Auto-enabled for shared auth mode in engine.create_app()
18
+
19
+ # Or manual usage:
20
+ from mdb_engine.auth.csrf import CSRFMiddleware
21
+ app.add_middleware(CSRFMiddleware, exempt_routes=["/api/*"])
22
+
23
+ # In templates, include the token:
24
+ <input type="hidden" name="csrf_token" value="{{ csrf_token }}">
25
+
26
+ # Or in JavaScript:
27
+ fetch('/endpoint', {
28
+ headers: {'X-CSRF-Token': getCookie('csrf_token')}
29
+ })
30
+ """
31
+
32
+ import fnmatch
33
+ import hashlib
34
+ import hmac
35
+ import logging
36
+ import os
37
+ import secrets
38
+ import time
39
+ from collections.abc import Awaitable, Callable
40
+ from typing import Any
41
+
42
+ from fastapi import Request, Response, status
43
+ from fastapi.responses import JSONResponse
44
+ from starlette.middleware.base import BaseHTTPMiddleware
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+ # Token settings
49
+ CSRF_TOKEN_LENGTH = 32 # 256 bits
50
+ CSRF_COOKIE_NAME = "csrf_token"
51
+ CSRF_HEADER_NAME = "X-CSRF-Token"
52
+ CSRF_FORM_FIELD = "csrf_token"
53
+ DEFAULT_TOKEN_TTL = 3600 # 1 hour
54
+
55
+ # Methods that require CSRF validation
56
+ UNSAFE_METHODS = {"POST", "PUT", "DELETE", "PATCH"}
57
+
58
+
59
+ def generate_csrf_token(secret: str | None = None) -> str:
60
+ """
61
+ Generate a cryptographically secure CSRF token.
62
+
63
+ Args:
64
+ secret: Optional secret for HMAC signing (adds tamper detection)
65
+
66
+ Returns:
67
+ URL-safe base64 encoded token
68
+ """
69
+ raw_token = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
70
+
71
+ if secret:
72
+ # Add HMAC signature for tamper detection
73
+ timestamp = str(int(time.time()))
74
+ message = f"{raw_token}:{timestamp}"
75
+ signature = hmac.new(secret.encode(), message.encode(), hashlib.sha256).hexdigest()[:16]
76
+ return f"{raw_token}:{timestamp}:{signature}"
77
+
78
+ return raw_token
79
+
80
+
81
+ def validate_csrf_token(
82
+ token: str,
83
+ secret: str | None = None,
84
+ max_age: int = DEFAULT_TOKEN_TTL,
85
+ ) -> bool:
86
+ """
87
+ Validate a CSRF token.
88
+
89
+ Args:
90
+ token: The token to validate
91
+ secret: Optional secret for HMAC verification
92
+ max_age: Maximum token age in seconds
93
+
94
+ Returns:
95
+ True if valid, False otherwise
96
+ """
97
+ if not token:
98
+ return False
99
+
100
+ if secret and ":" in token:
101
+ # Verify HMAC-signed token
102
+ try:
103
+ parts = token.split(":")
104
+ if len(parts) != 3:
105
+ return False
106
+
107
+ raw_token, timestamp_str, signature = parts
108
+ timestamp = int(timestamp_str)
109
+
110
+ # Check age
111
+ if time.time() - timestamp > max_age:
112
+ logger.debug("CSRF token expired")
113
+ return False
114
+
115
+ # Verify signature
116
+ message = f"{raw_token}:{timestamp_str}"
117
+ expected_sig = hmac.new(secret.encode(), message.encode(), hashlib.sha256).hexdigest()[
118
+ :16
119
+ ]
120
+
121
+ if not hmac.compare_digest(signature, expected_sig):
122
+ logger.warning("CSRF token signature mismatch")
123
+ return False
124
+
125
+ return True
126
+ except (ValueError, IndexError) as e:
127
+ logger.warning(f"CSRF token validation error: {e}")
128
+ return False
129
+
130
+ # Simple token validation (just check it exists and has reasonable length)
131
+ return len(token) >= CSRF_TOKEN_LENGTH
132
+
133
+
134
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    CSRF Protection Middleware using Double-Submit Cookie pattern.

    The double-submit cookie pattern works by:
    1. Setting a CSRF token in a cookie (with HttpOnly=False so JS can read it)
    2. Requiring the same token in the request header (``header_name``)
    3. Since attackers can't read cookies from other domains, they can't forge requests

    When a ``secret`` is configured, tokens are HMAC-signed and time-limited, and
    unsafe requests are additionally checked with ``validate_csrf_token``.

    The cookie is set with SameSite=Lax, so browsers withhold it from most
    cross-site subrequests (Lax still sends it on top-level GET navigations,
    which this middleware treats as safe).

    Note: ``form_field`` is stored but form bodies are not inspected; unsafe
    requests must carry the token in the header.
    """

    def __init__(
        self,
        app,
        secret: str | None = None,
        exempt_routes: list[str] | None = None,
        exempt_methods: set[str] | None = None,
        cookie_name: str = CSRF_COOKIE_NAME,
        header_name: str = CSRF_HEADER_NAME,
        form_field: str = CSRF_FORM_FIELD,
        token_ttl: int = DEFAULT_TOKEN_TTL,
        rotate_tokens: bool = False,
        secure_cookies: bool = True,
    ):
        """
        Initialize CSRF middleware.

        Args:
            app: FastAPI application
            secret: Secret for HMAC token signing (recommended for production).
                Falls back to the MDB_ENGINE_CSRF_SECRET environment variable.
            exempt_routes: Routes exempt from CSRF (fnmatch wildcards, e.g. /api/*)
            exempt_methods: HTTP methods exempt from CSRF
                (default: GET, HEAD, OPTIONS, TRACE)
            cookie_name: Name of the CSRF cookie
            header_name: Name of the CSRF header
            form_field: Name of the CSRF form field (stored; form bodies are not
                currently inspected)
            token_ttl: Token time-to-live in seconds (also used as cookie max-age)
            rotate_tokens: Issue a fresh token after each successful unsafe request
            secure_cookies: Allow the Secure cookie flag; it is applied only when the
                request is HTTPS or ENVIRONMENT=production
        """
        super().__init__(app)
        self.secret = secret or os.getenv("MDB_ENGINE_CSRF_SECRET")
        self.exempt_routes = exempt_routes or []
        self.exempt_methods = exempt_methods or {"GET", "HEAD", "OPTIONS", "TRACE"}
        self.cookie_name = cookie_name
        self.header_name = header_name
        self.form_field = form_field
        self.token_ttl = token_ttl
        self.rotate_tokens = rotate_tokens
        self.secure_cookies = secure_cookies

        logger.info(
            f"CSRFMiddleware initialized (exempt_routes={self.exempt_routes}, "
            f"rotate_tokens={rotate_tokens})"
        )

    def _is_exempt(self, path: str) -> bool:
        """Check if a path matches any configured exempt-route pattern."""
        # WebSocket upgrade requests are handled separately in dispatch()
        # and are not exempted here - they need origin validation.
        return any(fnmatch.fnmatch(path, pattern) for pattern in self.exempt_routes)

    def _is_websocket_upgrade(self, request: Request) -> bool:
        """Check if request carries an ``Upgrade: websocket`` header."""
        upgrade_header = request.headers.get("upgrade", "").lower()
        return upgrade_header == "websocket"

    def _get_allowed_origins(self, request: Request) -> list[str]:
        """
        Return the origins a WebSocket upgrade may come from.

        Uses ``allow_origins`` from ``request.app.state.cors_config`` when present
        (a single string is wrapped in a list). For multi-app setups the parent
        app's merged CORS config is expected to be on ``request.app``. Otherwise
        falls back to the origin of the request URL itself; returns an empty list
        (so the upgrade is rejected) if that cannot be determined either.
        """
        try:
            cors_config = getattr(request.app.state, "cors_config", None)
            if cors_config and cors_config.get("allow_origins"):
                origins = cors_config["allow_origins"]
                return origins if isinstance(origins, list) else [origins]
        except (AttributeError, TypeError, KeyError) as e:
            logger.debug(f"Could not read CORS config from app.state: {e}")

        # Fallback: the request's own origin.
        try:
            host = request.url.hostname
            scheme = request.url.scheme
            port = request.url.port
            if port and port not in (80, 443):
                return [f"{scheme}://{host}:{port}"]
            return [f"{scheme}://{host}"]
        except (AttributeError, TypeError):
            return []

    def _validate_websocket_origin(self, request: Request) -> bool:
        """
        Validate the Origin header of a WebSocket upgrade request.

        Intended as the defense against Cross-Site WebSocket Hijacking (CSWSH).
        Returns True if Origin matches an allowed origin (trailing slashes are
        ignored; ``"*"`` allows everything), False otherwise.
        """
        origin = request.headers.get("origin")
        if not origin:
            logger.warning(f"WebSocket upgrade missing Origin header: {request.url.path}")
            return False

        allowed_origins = self._get_allowed_origins(request)
        normalized_origin = origin.rstrip("/")
        for allowed in allowed_origins:
            if allowed == "*":
                logger.warning(
                    "WebSocket Origin validation using wildcard '*' - "
                    "not recommended for production"
                )
                return True
            if normalized_origin == allowed.rstrip("/"):
                return True

        cors_config = getattr(request.app.state, "cors_config", None)
        # Guarded: a non-dict cors_config must not turn a rejection into a crash.
        cors_enabled = (
            cors_config.get("enabled", False) if isinstance(cors_config, dict) else False
        )
        logger.warning(
            f"WebSocket upgrade rejected - invalid Origin: {origin} "
            f"(allowed: {allowed_origins}, app: {getattr(request.app, 'title', 'unknown')}, "
            f"path: {request.url.path}, CORS enabled: {cors_enabled}, "
            f"has_cors_config: {hasattr(request.app.state, 'cors_config')})"
        )
        return False

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """
        Process request through CSRF middleware.

        Order of checks:
        1. WebSocket upgrade requests: only the Origin header is validated.
        2. Exempt paths: passed through untouched.
        3. Exempt (safe) methods: a token is exposed on ``request.state.csrf_token``
           before the handler runs and, if the client has no CSRF cookie yet, the
           same token is set as a cookie on the response.
        4. Everything else: the cookie token must equal the header token and, when
           a secret is configured, carry a valid unexpired signature.
        """
        path = request.url.path
        method = request.method

        # Handle WebSocket upgrade requests BEFORE other CSRF checks: they don't use
        # CSRF tokens but need origin validation.
        # NOTE(review): this branch only sees HTTP requests with an Upgrade header.
        # If the base middleware forwards ASGI "websocket" scopes without calling
        # dispatch() (Starlette's BaseHTTPMiddleware appears to), real WebSocket
        # handshakes skip this check - confirm and consider an ASGI-level guard.
        if self._is_websocket_upgrade(request):
            if not self._validate_websocket_origin(request):
                logger.warning(
                    f"WebSocket origin validation failed for {path}: "
                    f"origin={request.headers.get('origin')}, "
                    f"allowed={self._get_allowed_origins(request)}"
                )
                return JSONResponse(
                    status_code=status.HTTP_403_FORBIDDEN,
                    content={"detail": "Invalid origin for WebSocket connection"},
                )
            return await call_next(request)

        if self._is_exempt(path):
            return await call_next(request)

        if method in self.exempt_methods:
            existing_token = request.cookies.get(self.cookie_name)
            token = existing_token or generate_csrf_token(self.secret)
            # Expose the token BEFORE the handler runs so templates rendered during
            # this request embed the same value that ends up in the cookie.
            request.state.csrf_token = token
            response = await call_next(request)
            if not existing_token:
                self._set_csrf_cookie(request, response, token)
            return response

        # Validate CSRF token for unsafe methods
        cookie_token = request.cookies.get(self.cookie_name)
        if not cookie_token:
            logger.warning(f"CSRF cookie missing for {method} {path}")
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token missing"},
            )

        # Only the header is checked. Form-based extraction (self.form_field) is
        # not implemented.
        # TODO: Implement request.form() based extraction if needed.
        submitted_token = request.headers.get(self.header_name)
        if not submitted_token:
            logger.warning(f"CSRF token not submitted for {method} {path}")
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token not provided in header or form"},
            )

        # Constant-time comparison on bytes: hmac.compare_digest raises TypeError
        # for non-ASCII str arguments, and both values are client-controlled.
        if not hmac.compare_digest(cookie_token.encode(), submitted_token.encode()):
            logger.warning(f"CSRF token mismatch for {method} {path}")
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token invalid"},
            )

        if self.secret and not validate_csrf_token(cookie_token, self.secret, self.token_ttl):
            logger.warning(f"CSRF token validation failed for {method} {path}")
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token expired or invalid"},
            )

        new_token = generate_csrf_token(self.secret) if self.rotate_tokens else None
        # Handlers rendering a follow-up form should embed the token the client
        # will hold after this response.
        request.state.csrf_token = new_token or cookie_token

        response = await call_next(request)

        if new_token:
            self._set_csrf_cookie(request, response, new_token)

        return response

    def _set_csrf_cookie(
        self,
        request: Request,
        response: Response,
        token: str,
    ) -> None:
        """Set the CSRF token cookie (JS-readable, SameSite=Lax, max-age=token_ttl)."""
        is_https = request.url.scheme == "https"
        is_production = os.getenv("ENVIRONMENT", "").lower() == "production"

        response.set_cookie(
            key=self.cookie_name,
            value=token,
            httponly=False,  # Must be readable by JavaScript
            secure=self.secure_cookies and (is_https or is_production),
            samesite="lax",  # Provides CSRF protection + allows top-level navigation
            max_age=self.token_ttl,
            path="/",
        )
410
+
411
+
412
+ def create_csrf_middleware(
413
+ manifest_auth: dict[str, Any],
414
+ secret: str | None = None,
415
+ ) -> type:
416
+ """
417
+ Create CSRF middleware from manifest configuration.
418
+
419
+ Args:
420
+ manifest_auth: Auth section from manifest
421
+ secret: Optional CSRF secret (defaults to env var)
422
+
423
+ Returns:
424
+ Configured CSRFMiddleware class
425
+ """
426
+ csrf_config = manifest_auth.get("csrf_protection", True)
427
+
428
+ # Handle boolean or object config
429
+ if isinstance(csrf_config, bool):
430
+ if not csrf_config:
431
+ # Return a no-op middleware
432
+ class NoOpMiddleware(BaseHTTPMiddleware):
433
+ async def dispatch(self, request, call_next):
434
+ return await call_next(request)
435
+
436
+ return NoOpMiddleware
437
+
438
+ # Use defaults
439
+ exempt_routes = manifest_auth.get("public_routes", [])
440
+ rotate_tokens = False
441
+ token_ttl = DEFAULT_TOKEN_TTL
442
+ else:
443
+ # Object configuration
444
+ exempt_routes = csrf_config.get("exempt_routes", manifest_auth.get("public_routes", []))
445
+ rotate_tokens = csrf_config.get("rotate_tokens", False)
446
+ token_ttl = csrf_config.get("token_ttl", DEFAULT_TOKEN_TTL)
447
+
448
+ # Create configured middleware class
449
+ class ConfiguredCSRFMiddleware(CSRFMiddleware):
450
+ def __init__(self, app):
451
+ super().__init__(
452
+ app,
453
+ secret=secret or os.getenv("MDB_ENGINE_CSRF_SECRET"),
454
+ exempt_routes=exempt_routes,
455
+ rotate_tokens=rotate_tokens,
456
+ token_ttl=token_ttl,
457
+ )
458
+
459
+ return ConfiguredCSRFMiddleware
460
+
461
+
462
# Dependency for getting CSRF token in routes
def get_csrf_token(request: Request) -> str:
    """
    Return the CSRF token for the current request, for use in templates.

    Lookup order: the value stored on ``request.state.csrf_token`` (by the
    middleware), then the ``csrf_token`` cookie, and finally a freshly generated
    token (signed when MDB_ENGINE_CSRF_SECRET is set). A freshly generated token
    is not stored anywhere by this function.

    Usage in FastAPI route:
        @app.get("/form")
        def form_page(csrf_token: str = Depends(get_csrf_token)):
            return templates.TemplateResponse("form.html", {"csrf_token": csrf_token})
    """
    request_state = request.state
    if hasattr(request_state, "csrf_token"):
        return request_state.csrf_token

    cookie_value = request.cookies.get(CSRF_COOKIE_NAME)
    if cookie_value:
        return cookie_value

    return generate_csrf_token(os.getenv("MDB_ENGINE_CSRF_SECRET"))
@@ -9,8 +9,9 @@ This module is part of MDB_ENGINE - MongoDB Engine.
9
9
  import logging
10
10
  import time
11
11
  from collections import defaultdict
12
+ from collections.abc import Awaitable, Callable
12
13
  from functools import wraps
13
- from typing import Any, Awaitable, Callable, Dict, Optional
14
+ from typing import Any
14
15
 
15
16
  from fastapi import HTTPException, Request, status
16
17
  from fastapi.responses import RedirectResponse
@@ -20,7 +21,7 @@ from .dependencies import get_current_user_from_request
20
21
  logger = logging.getLogger(__name__)
21
22
 
22
23
  # Rate limiting storage (in-memory, can be replaced with Redis for distributed systems)
23
- _rate_limit_storage: Dict[str, Dict[str, Any]] = defaultdict(dict)
24
+ _rate_limit_storage: dict[str, dict[str, Any]] = defaultdict(dict)
24
25
 
25
26
 
26
27
  def require_auth(redirect_to: str = "/login"):
@@ -70,10 +71,7 @@ def _is_production_environment() -> bool:
70
71
  """Check if running in production environment."""
71
72
  import os
72
73
 
73
- return (
74
- os.getenv("G_NOME_ENV") == "production"
75
- or os.getenv("ENVIRONMENT") == "production"
76
- )
74
+ return os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
77
75
 
78
76
 
79
77
  def _validate_https(request: Request) -> None:
@@ -85,7 +83,7 @@ def _validate_https(request: Request) -> None:
85
83
  )
86
84
 
87
85
 
88
- async def _get_csrf_token(request: Request) -> Optional[str]:
86
+ async def _get_csrf_token(request: Request) -> str | None:
89
87
  """Extract CSRF token from request headers or form data."""
90
88
  csrf_token = request.headers.get("X-CSRF-Token")
91
89
  if csrf_token:
@@ -154,8 +152,8 @@ def token_security(enforce_https: bool = True, check_csrf: bool = True):
154
152
 
155
153
  def rate_limit_auth(
156
154
  endpoint: str = "login",
157
- max_attempts: Optional[int] = None,
158
- window_seconds: Optional[int] = None,
155
+ max_attempts: int | None = None,
156
+ window_seconds: int | None = None,
159
157
  ):
160
158
  """
161
159
  Rate limiting decorator for auth endpoints.
@@ -189,17 +187,13 @@ def rate_limit_auth(
189
187
 
190
188
  # Use provided values or config values or defaults
191
189
  if max_attempts is None:
192
- max_attempts_val = (
193
- rate_limit_config.get("max_attempts") if rate_limit_config else 5
194
- )
190
+ max_attempts_val = rate_limit_config.get("max_attempts") if rate_limit_config else 5
195
191
  else:
196
192
  max_attempts_val = max_attempts
197
193
 
198
194
  if window_seconds is None:
199
195
  window_seconds_val = (
200
- rate_limit_config.get("window_seconds")
201
- if rate_limit_config
202
- else 300
196
+ rate_limit_config.get("window_seconds") if rate_limit_config else 300
203
197
  )
204
198
  else:
205
199
  window_seconds_val = window_seconds
@@ -249,7 +243,7 @@ def rate_limit_auth(
249
243
  return decorator
250
244
 
251
245
 
252
- def auto_token_setup(func: Optional[Callable[..., Awaitable[Any]]] = None):
246
+ def auto_token_setup(func: Callable[..., Awaitable[Any]] | None = None):
253
247
  """
254
248
  Decorator to automatically set up tokens on successful login/register.
255
249