mdb-engine 0.1.6__py3-none-any.whl → 0.4.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +116 -11
- mdb_engine/auth/ARCHITECTURE.md +112 -0
- mdb_engine/auth/README.md +654 -11
- mdb_engine/auth/__init__.py +136 -29
- mdb_engine/auth/audit.py +592 -0
- mdb_engine/auth/base.py +252 -0
- mdb_engine/auth/casbin_factory.py +265 -70
- mdb_engine/auth/config_defaults.py +5 -5
- mdb_engine/auth/config_helpers.py +19 -18
- mdb_engine/auth/cookie_utils.py +12 -16
- mdb_engine/auth/csrf.py +483 -0
- mdb_engine/auth/decorators.py +10 -16
- mdb_engine/auth/dependencies.py +69 -71
- mdb_engine/auth/helpers.py +3 -3
- mdb_engine/auth/integration.py +61 -88
- mdb_engine/auth/jwt.py +11 -15
- mdb_engine/auth/middleware.py +79 -35
- mdb_engine/auth/oso_factory.py +21 -41
- mdb_engine/auth/provider.py +270 -171
- mdb_engine/auth/rate_limiter.py +505 -0
- mdb_engine/auth/restrictions.py +21 -36
- mdb_engine/auth/session_manager.py +24 -41
- mdb_engine/auth/shared_middleware.py +977 -0
- mdb_engine/auth/shared_users.py +775 -0
- mdb_engine/auth/token_lifecycle.py +10 -12
- mdb_engine/auth/token_store.py +17 -32
- mdb_engine/auth/users.py +99 -159
- mdb_engine/auth/utils.py +236 -42
- mdb_engine/cli/commands/generate.py +546 -10
- mdb_engine/cli/commands/validate.py +3 -7
- mdb_engine/cli/utils.py +7 -7
- mdb_engine/config.py +13 -28
- mdb_engine/constants.py +65 -0
- mdb_engine/core/README.md +117 -6
- mdb_engine/core/__init__.py +39 -7
- mdb_engine/core/app_registration.py +31 -50
- mdb_engine/core/app_secrets.py +289 -0
- mdb_engine/core/connection.py +20 -12
- mdb_engine/core/encryption.py +222 -0
- mdb_engine/core/engine.py +2862 -115
- mdb_engine/core/index_management.py +12 -16
- mdb_engine/core/manifest.py +628 -204
- mdb_engine/core/ray_integration.py +436 -0
- mdb_engine/core/seeding.py +13 -21
- mdb_engine/core/service_initialization.py +20 -30
- mdb_engine/core/types.py +40 -43
- mdb_engine/database/README.md +140 -17
- mdb_engine/database/__init__.py +17 -6
- mdb_engine/database/abstraction.py +37 -50
- mdb_engine/database/connection.py +51 -30
- mdb_engine/database/query_validator.py +367 -0
- mdb_engine/database/resource_limiter.py +204 -0
- mdb_engine/database/scoped_wrapper.py +747 -237
- mdb_engine/dependencies.py +427 -0
- mdb_engine/di/__init__.py +34 -0
- mdb_engine/di/container.py +247 -0
- mdb_engine/di/providers.py +206 -0
- mdb_engine/di/scopes.py +139 -0
- mdb_engine/embeddings/README.md +54 -24
- mdb_engine/embeddings/__init__.py +31 -24
- mdb_engine/embeddings/dependencies.py +38 -155
- mdb_engine/embeddings/service.py +78 -75
- mdb_engine/exceptions.py +104 -12
- mdb_engine/indexes/README.md +30 -13
- mdb_engine/indexes/__init__.py +1 -0
- mdb_engine/indexes/helpers.py +11 -11
- mdb_engine/indexes/manager.py +59 -123
- mdb_engine/memory/README.md +95 -4
- mdb_engine/memory/__init__.py +1 -2
- mdb_engine/memory/service.py +363 -1168
- mdb_engine/observability/README.md +4 -2
- mdb_engine/observability/__init__.py +26 -9
- mdb_engine/observability/health.py +17 -17
- mdb_engine/observability/logging.py +10 -10
- mdb_engine/observability/metrics.py +40 -19
- mdb_engine/repositories/__init__.py +34 -0
- mdb_engine/repositories/base.py +325 -0
- mdb_engine/repositories/mongo.py +233 -0
- mdb_engine/repositories/unit_of_work.py +166 -0
- mdb_engine/routing/README.md +1 -1
- mdb_engine/routing/__init__.py +1 -3
- mdb_engine/routing/websockets.py +41 -75
- mdb_engine/utils/__init__.py +3 -1
- mdb_engine/utils/mongo.py +117 -0
- mdb_engine-0.4.12.dist-info/METADATA +492 -0
- mdb_engine-0.4.12.dist-info/RECORD +97 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/WHEEL +1 -1
- mdb_engine-0.1.6.dist-info/METADATA +0 -213
- mdb_engine-0.1.6.dist-info/RECORD +0 -75
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.1.6.dist-info → mdb_engine-0.4.12.dist-info}/top_level.txt +0 -0
mdb_engine/auth/csrf.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CSRF Protection Middleware
|
|
3
|
+
|
|
4
|
+
Implements the Double-Submit Cookie pattern for Cross-Site Request Forgery protection.
|
|
5
|
+
Auto-enabled for shared auth mode, with manifest-configurable options.
|
|
6
|
+
|
|
7
|
+
This module is part of MDB_ENGINE - MongoDB Engine.
|
|
8
|
+
|
|
9
|
+
Security Features:
|
|
10
|
+
- Double-submit cookie pattern (industry standard)
|
|
11
|
+
- Cryptographically secure token generation
|
|
12
|
+
- Configurable exempt routes for APIs
|
|
13
|
+
- SameSite cookie attribute for additional protection
|
|
14
|
+
- Token rotation on each request (optional)
|
|
15
|
+
|
|
16
|
+
Usage:
|
|
17
|
+
# Auto-enabled for shared auth mode in engine.create_app()
|
|
18
|
+
|
|
19
|
+
# Or manual usage:
|
|
20
|
+
from mdb_engine.auth.csrf import CSRFMiddleware
|
|
21
|
+
app.add_middleware(CSRFMiddleware, exempt_routes=["/api/*"])
|
|
22
|
+
|
|
23
|
+
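# In templates, include the token (note: CSRFMiddleware currently validates only the
# X-CSRF-Token header; the csrf_token form field below is not yet read by the middleware):
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">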
|
|
25
|
+
|
|
26
|
+
# Or in JavaScript:
|
|
27
|
+
fetch('/endpoint', {
|
|
28
|
+
headers: {'X-CSRF-Token': getCookie('csrf_token')}
|
|
29
|
+
})
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
import fnmatch
|
|
33
|
+
import hashlib
|
|
34
|
+
import hmac
|
|
35
|
+
import logging
|
|
36
|
+
import os
|
|
37
|
+
import secrets
|
|
38
|
+
import time
|
|
39
|
+
from collections.abc import Awaitable, Callable
|
|
40
|
+
from typing import Any
|
|
41
|
+
|
|
42
|
+
from fastapi import Request, Response, status
|
|
43
|
+
from fastapi.responses import JSONResponse
|
|
44
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
45
|
+
|
|
46
|
+
logger = logging.getLogger(__name__)
|
|
47
|
+
|
|
48
|
+
# Token settings
|
|
49
|
+
CSRF_TOKEN_LENGTH = 32 # 256 bits
|
|
50
|
+
CSRF_COOKIE_NAME = "csrf_token"
|
|
51
|
+
CSRF_HEADER_NAME = "X-CSRF-Token"
|
|
52
|
+
CSRF_FORM_FIELD = "csrf_token"
|
|
53
|
+
DEFAULT_TOKEN_TTL = 3600 # 1 hour
|
|
54
|
+
|
|
55
|
+
# Methods that require CSRF validation
|
|
56
|
+
UNSAFE_METHODS = {"POST", "PUT", "DELETE", "PATCH"}
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def generate_csrf_token(secret: str | None = None) -> str:
|
|
60
|
+
"""
|
|
61
|
+
Generate a cryptographically secure CSRF token.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
secret: Optional secret for HMAC signing (adds tamper detection)
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
URL-safe base64 encoded token
|
|
68
|
+
"""
|
|
69
|
+
raw_token = secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
|
70
|
+
|
|
71
|
+
if secret:
|
|
72
|
+
# Add HMAC signature for tamper detection
|
|
73
|
+
timestamp = str(int(time.time()))
|
|
74
|
+
message = f"{raw_token}:{timestamp}"
|
|
75
|
+
signature = hmac.new(secret.encode(), message.encode(), hashlib.sha256).hexdigest()[:16]
|
|
76
|
+
return f"{raw_token}:{timestamp}:{signature}"
|
|
77
|
+
|
|
78
|
+
return raw_token
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def validate_csrf_token(
|
|
82
|
+
token: str,
|
|
83
|
+
secret: str | None = None,
|
|
84
|
+
max_age: int = DEFAULT_TOKEN_TTL,
|
|
85
|
+
) -> bool:
|
|
86
|
+
"""
|
|
87
|
+
Validate a CSRF token.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
token: The token to validate
|
|
91
|
+
secret: Optional secret for HMAC verification
|
|
92
|
+
max_age: Maximum token age in seconds
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
True if valid, False otherwise
|
|
96
|
+
"""
|
|
97
|
+
if not token:
|
|
98
|
+
return False
|
|
99
|
+
|
|
100
|
+
if secret and ":" in token:
|
|
101
|
+
# Verify HMAC-signed token
|
|
102
|
+
try:
|
|
103
|
+
parts = token.split(":")
|
|
104
|
+
if len(parts) != 3:
|
|
105
|
+
return False
|
|
106
|
+
|
|
107
|
+
raw_token, timestamp_str, signature = parts
|
|
108
|
+
timestamp = int(timestamp_str)
|
|
109
|
+
|
|
110
|
+
# Check age
|
|
111
|
+
if time.time() - timestamp > max_age:
|
|
112
|
+
logger.debug("CSRF token expired")
|
|
113
|
+
return False
|
|
114
|
+
|
|
115
|
+
# Verify signature
|
|
116
|
+
message = f"{raw_token}:{timestamp_str}"
|
|
117
|
+
expected_sig = hmac.new(secret.encode(), message.encode(), hashlib.sha256).hexdigest()[
|
|
118
|
+
:16
|
|
119
|
+
]
|
|
120
|
+
|
|
121
|
+
if not hmac.compare_digest(signature, expected_sig):
|
|
122
|
+
logger.warning("CSRF token signature mismatch")
|
|
123
|
+
return False
|
|
124
|
+
|
|
125
|
+
return True
|
|
126
|
+
except (ValueError, IndexError) as e:
|
|
127
|
+
logger.warning(f"CSRF token validation error: {e}")
|
|
128
|
+
return False
|
|
129
|
+
|
|
130
|
+
# Simple token validation (just check it exists and has reasonable length)
|
|
131
|
+
return len(token) >= CSRF_TOKEN_LENGTH
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    CSRF Protection Middleware using Double-Submit Cookie pattern.

    The double-submit cookie pattern works by:
    1. Setting a CSRF token in a cookie (with HttpOnly=False so JS can read it)
    2. Requiring the same token in the CSRF header (``X-CSRF-Token`` by default)
    3. Since attackers can't read cookies from other domains, they can't forge requests

    Additional protection from SameSite=Lax cookies prevents the browser from
    sending cookies on cross-site requests.

    WebSocket handshakes reach ASGI middleware as ``websocket`` scopes, which
    ``BaseHTTPMiddleware`` forwards without calling ``dispatch``. ``__call__``
    therefore intercepts them so the Origin check (CSWSH protection) actually runs.
    """

    def __init__(
        self,
        app,
        secret: str | None = None,
        exempt_routes: list[str] | None = None,
        exempt_methods: set[str] | None = None,
        cookie_name: str = CSRF_COOKIE_NAME,
        header_name: str = CSRF_HEADER_NAME,
        form_field: str = CSRF_FORM_FIELD,
        token_ttl: int = DEFAULT_TOKEN_TTL,
        rotate_tokens: bool = False,
        secure_cookies: bool = True,
    ):
        """
        Initialize CSRF middleware.

        Args:
            app: ASGI application being wrapped
            secret: Secret for HMAC token signing (recommended for production);
                falls back to the MDB_ENGINE_CSRF_SECRET environment variable
            exempt_routes: Routes exempt from CSRF (fnmatch wildcards, e.g. /api/*)
            exempt_methods: HTTP methods exempt from CSRF
                (default: GET, HEAD, OPTIONS, TRACE); matched case-insensitively
            cookie_name: Name of the CSRF cookie
            header_name: Name of the CSRF header
            form_field: Name of the CSRF form field (stored only; form extraction is
                not implemented, so only the header is checked)
            token_ttl: Token time-to-live in seconds (cookie max-age and signed-token age)
            rotate_tokens: Issue a new token after every validated unsafe request
            secure_cookies: Allow the Secure cookie flag (applied on HTTPS or when
                ENVIRONMENT=production)
        """
        super().__init__(app)
        self.secret = secret or os.getenv("MDB_ENGINE_CSRF_SECRET")
        # Copy so later mutation of the caller's list cannot change behavior.
        self.exempt_routes = list(exempt_routes) if exempt_routes else []
        # ASGI method names are upper-case; normalize caller-supplied names to match.
        self.exempt_methods = (
            {m.upper() for m in exempt_methods}
            if exempt_methods
            else {"GET", "HEAD", "OPTIONS", "TRACE"}
        )
        self.cookie_name = cookie_name
        self.header_name = header_name
        self.form_field = form_field
        self.token_ttl = token_ttl
        self.rotate_tokens = rotate_tokens
        self.secure_cookies = secure_cookies

        logger.info(
            "CSRFMiddleware initialized (exempt_routes=%s, rotate_tokens=%s)",
            self.exempt_routes,
            rotate_tokens,
        )

    async def __call__(
        self,
        scope: dict[str, Any],
        receive: Callable[[], Awaitable[dict[str, Any]]],
        send: Callable[[dict[str, Any]], Awaitable[None]],
    ) -> None:
        """
        ASGI entry point.

        Validates the Origin of ``websocket`` scopes, which BaseHTTPMiddleware would
        otherwise pass through untouched. Everything else is delegated to
        BaseHTTPMiddleware, which calls :meth:`dispatch` for HTTP requests.
        """
        if scope["type"] == "websocket":
            # Local import: only needed for WebSocket handshakes. HTTPConnection (unlike
            # Request) accepts websocket scopes and exposes headers/url/app.
            from starlette.requests import HTTPConnection

            connection = HTTPConnection(scope)
            if not self._validate_websocket_origin(connection):
                # Closing before accept makes the ASGI server reject the handshake (HTTP 403).
                await send({"type": "websocket.close", "code": 1008})
                return
        await super().__call__(scope, receive, send)

    def _is_exempt(self, path: str) -> bool:
        """Check if a path is exempt from CSRF validation."""
        # WebSocket handshakes are never exempted here: they always need origin
        # validation, which happens in __call__/dispatch before this check.
        return any(fnmatch.fnmatch(path, pattern) for pattern in self.exempt_routes)

    def _is_websocket_upgrade(self, request: Request) -> bool:
        """Check if an HTTP request carries a WebSocket ``Upgrade`` header."""
        return request.headers.get("upgrade", "").lower() == "websocket"

    @staticmethod
    def _app_state(request: Request) -> Any:
        """Return ``request.app.state``, or None when the scope carries no app."""
        try:
            return request.app.state
        except (AttributeError, KeyError):
            return None

    def _get_allowed_origins(self, request: Request) -> list[str]:
        """
        Get allowed origins from app state (CORS config) or fall back to the request host.

        In multi-app setups WebSocket routes are registered on the parent app, whose
        ``cors_config`` is expected to hold the CORS config merged from mounted child
        apps during mounting.

        Returns:
            ``allow_origins`` from the CORS config (normalized to a list) when present,
            otherwise the same-host origins; an empty list (reject everything) when
            neither can be determined.
        """
        cors_config = getattr(self._app_state(request), "cors_config", None)
        try:
            origins = cors_config.get("allow_origins") if cors_config else None
        except (AttributeError, TypeError, KeyError) as e:
            logger.debug("Could not read CORS config from app.state: %s", e)
            origins = None

        if origins:
            # Accept any common collection type, not only list (a tuple used to be
            # wrapped as one element and then never matched).
            if isinstance(origins, (list, tuple, set, frozenset)):
                return list(origins)
            return [origins]

        return self._same_host_origins(request)

    @staticmethod
    def _same_host_origins(request: Request) -> list[str]:
        """
        Build the origins corresponding to the host the request was sent to.

        ``ws``/``wss`` are mapped to ``http``/``https`` because the browser's Origin
        header carries the page origin, never a ws:// scheme. Both http and https
        variants are returned so a TLS-terminating proxy forwarding plain HTTP does not
        cause same-host connections to be rejected. A port is included only when it is
        not the default for that scheme.
        """
        try:
            host = request.url.hostname
            port = request.url.port
            scheme = request.url.scheme
        except (AttributeError, TypeError):
            # Can't determine the origin: return nothing so validation rejects.
            return []
        if not host:
            return []
        if ":" in host:
            # IPv6 literal: serialized origins carry it in brackets.
            host = f"[{host}]"

        scheme = {"ws": "http", "wss": "https"}.get(scheme, scheme)
        alternate = "https" if scheme == "http" else "http"
        default_ports = {"http": 80, "https": 443}
        origins = []
        for candidate in (scheme, alternate):
            if port is None or port == default_ports.get(candidate):
                origins.append(f"{candidate}://{host}")
            else:
                origins.append(f"{candidate}://{host}:{port}")
        return origins

    def _validate_websocket_origin(self, request: Request) -> bool:
        """
        Validate the Origin header of a WebSocket handshake.

        This is the primary defense against Cross-Site WebSocket Hijacking (CSWSH).
        It accepts any Starlette ``HTTPConnection``: a ``Request`` for HTTP scopes, or a
        plain connection for ``websocket`` scopes.

        Returns:
            True if the Origin is allowed, False otherwise.
        """
        origin = request.headers.get("origin")
        if not origin:
            logger.warning("WebSocket upgrade missing Origin header: %s", request.url.path)
            return False

        allowed_origins = self._get_allowed_origins(request)
        normalized_origin = origin.rstrip("/")
        for allowed in allowed_origins:
            if allowed == "*":
                logger.warning(
                    "WebSocket Origin validation using wildcard '*' - "
                    "not recommended for production"
                )
                return True
            if isinstance(allowed, str) and normalized_origin == allowed.rstrip("/"):
                return True

        # Rejection diagnostics; every lookup is guarded so logging cannot raise.
        app_state = self._app_state(request)
        cors_config = getattr(app_state, "cors_config", None)
        try:
            cors_enabled = cors_config.get("enabled", False) if cors_config else False
        except AttributeError:
            cors_enabled = False
        try:
            app_title = getattr(request.app, "title", "unknown")
        except (AttributeError, KeyError):
            app_title = "unknown"
        logger.warning(
            "WebSocket upgrade rejected - invalid Origin: %s "
            "(allowed: %s, app: %s, path: %s, CORS enabled: %s, has_cors_config: %s)",
            origin,
            allowed_origins,
            app_title,
            request.url.path,
            cors_enabled,
            hasattr(app_state, "cors_config"),
        )
        return False

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """
        Process an HTTP request through the CSRF checks.

        Order: WebSocket-upgrade origin check, exempt routes, safe methods
        (token issuance), then token validation for every other method.
        """
        path = request.url.path
        method = request.method

        # HTTP requests carrying an Upgrade: websocket header need origin validation
        # instead of a CSRF token.
        if self._is_websocket_upgrade(request):
            if not self._validate_websocket_origin(request):
                logger.warning(
                    "WebSocket origin validation failed for %s: origin=%s, allowed=%s",
                    path,
                    request.headers.get("origin"),
                    self._get_allowed_origins(request),
                )
                return JSONResponse(
                    status_code=status.HTTP_403_FORBIDDEN,
                    content={"detail": "Invalid origin for WebSocket connection"},
                )
            return await call_next(request)

        if self._is_exempt(path):
            return await call_next(request)

        if method in self.exempt_methods:
            return await self._handle_safe_request(request, call_next)

        rejection = self._check_submitted_token(request, method, path)
        if rejection is not None:
            return rejection

        response = await call_next(request)

        if self.rotate_tokens:
            self._set_csrf_cookie(request, response, generate_csrf_token(self.secret))

        return response

    async def _handle_safe_request(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Serve a safe-method request, issuing a CSRF cookie when none usable exists."""
        existing = request.cookies.get(self.cookie_name)
        # With a secret, also replace cookies that are expired or carry a bad signature;
        # otherwise the client would be stuck with 403s until the cookie expires.
        needs_new_cookie = not existing or (
            bool(self.secret)
            and not validate_csrf_token(existing, self.secret, self.token_ttl)
        )
        token = generate_csrf_token(self.secret) if needs_new_cookie else existing

        # Publish BEFORE call_next so handlers/templates (and get_csrf_token) see the
        # same token that ends up in the cookie.
        request.state.csrf_token = token

        response = await call_next(request)

        if needs_new_cookie:
            self._set_csrf_cookie(request, response, token)

        return response

    def _check_submitted_token(
        self,
        request: Request,
        method: str,
        path: str,
    ) -> JSONResponse | None:
        """Return a 403 response if the request's CSRF token is unacceptable, else None."""
        cookie_token = request.cookies.get(self.cookie_name)
        if not cookie_token:
            logger.warning("CSRF cookie missing for %s %s", method, path)
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token missing"},
            )

        # Note: Form-based CSRF token extraction (self.form_field) is not implemented;
        # only the header is read.
        # TODO: Implement request.form() based extraction if needed.
        submitted_token = request.headers.get(self.header_name)
        if not submitted_token:
            logger.warning("CSRF token not submitted for %s %s", method, path)
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token not provided in header or form"},
            )

        # Constant-time comparison on bytes: compare_digest raises TypeError for non-ASCII
        # str inputs, and both values are attacker-controlled.
        if not hmac.compare_digest(cookie_token.encode(), submitted_token.encode()):
            logger.warning("CSRF token mismatch for %s %s", method, path)
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token invalid"},
            )

        # Check signature and age when signing is enabled.
        if self.secret and not validate_csrf_token(cookie_token, self.secret, self.token_ttl):
            logger.warning("CSRF token validation failed for %s %s", method, path)
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token expired or invalid"},
            )

        return None

    def _set_csrf_cookie(
        self,
        request: Request,
        response: Response,
        token: str,
    ) -> None:
        """Set the CSRF token cookie (JS-readable, SameSite=Lax, path=/)."""
        is_https = request.url.scheme == "https"
        is_production = os.getenv("ENVIRONMENT", "").lower() == "production"

        response.set_cookie(
            key=self.cookie_name,
            value=token,
            httponly=False,  # Must be readable by JavaScript
            secure=self.secure_cookies and (is_https or is_production),
            samesite="lax",  # Provides CSRF protection + allows top-level navigation
            max_age=self.token_ttl,
            path="/",
        )
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def create_csrf_middleware(
    manifest_auth: dict[str, Any],
    secret: str | None = None,
) -> type:
    """
    Build a CSRF middleware class from the manifest's auth section.

    ``manifest_auth["csrf_protection"]`` may be:
      * missing or ``True``: enable with defaults (``public_routes`` are exempt);
      * ``False``: return a pass-through middleware;
      * a dict: read ``exempt_routes`` (defaulting to ``public_routes``),
        ``rotate_tokens`` and ``token_ttl`` from it.

    Args:
        manifest_auth: Auth section from manifest
        secret: Optional CSRF secret. When falsy, MDB_ENGINE_CSRF_SECRET is read each
            time the returned class is instantiated.

    Returns:
        A middleware class whose constructor takes only ``app``.
    """
    csrf_config = manifest_auth.get("csrf_protection", True)

    if csrf_config is False:
        # Protection explicitly disabled: hand back a middleware that does nothing.
        class NoOpMiddleware(BaseHTTPMiddleware):
            async def dispatch(self, request, call_next):
                return await call_next(request)

        return NoOpMiddleware

    if csrf_config is True:
        options = {
            "exempt_routes": manifest_auth.get("public_routes", []),
            "rotate_tokens": False,
            "token_ttl": DEFAULT_TOKEN_TTL,
        }
    else:
        options = {
            "exempt_routes": csrf_config.get(
                "exempt_routes", manifest_auth.get("public_routes", [])
            ),
            "rotate_tokens": csrf_config.get("rotate_tokens", False),
            "token_ttl": csrf_config.get("token_ttl", DEFAULT_TOKEN_TTL),
        }

    class ConfiguredCSRFMiddleware(CSRFMiddleware):
        def __init__(self, app):
            super().__init__(
                app,
                secret=secret or os.getenv("MDB_ENGINE_CSRF_SECRET"),
                **options,
            )

    return ConfiguredCSRFMiddleware
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
# Dependency for getting CSRF token in routes
|
|
463
|
+
def get_csrf_token(request: Request) -> str:
    """
    Return the CSRF token for the current request, for use in templates.

    Lookup order:
      1. ``request.state.csrf_token`` (whatever value is stored there is returned);
      2. the ``csrf_token`` cookie (the module's default cookie name, not a custom
         ``cookie_name`` given to the middleware), if non-empty;
      3. a freshly generated token, signed when MDB_ENGINE_CSRF_SECRET is set. It is
         not stored in a cookie here.

    Usage in FastAPI route:
        @app.get("/form")
        def form_page(csrf_token: str = Depends(get_csrf_token)):
            return templates.TemplateResponse("form.html", {"csrf_token": csrf_token})
    """
    missing = object()
    state_token = getattr(request.state, "csrf_token", missing)
    if state_token is not missing:
        return state_token

    cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
    if cookie_token:
        return cookie_token

    return generate_csrf_token(os.getenv("MDB_ENGINE_CSRF_SECRET"))
|
mdb_engine/auth/decorators.py
CHANGED
|
@@ -9,8 +9,9 @@ This module is part of MDB_ENGINE - MongoDB Engine.
|
|
|
9
9
|
import logging
|
|
10
10
|
import time
|
|
11
11
|
from collections import defaultdict
|
|
12
|
+
from collections.abc import Awaitable, Callable
|
|
12
13
|
from functools import wraps
|
|
13
|
-
from typing import Any
|
|
14
|
+
from typing import Any
|
|
14
15
|
|
|
15
16
|
from fastapi import HTTPException, Request, status
|
|
16
17
|
from fastapi.responses import RedirectResponse
|
|
@@ -20,7 +21,7 @@ from .dependencies import get_current_user_from_request
|
|
|
20
21
|
logger = logging.getLogger(__name__)
|
|
21
22
|
|
|
22
23
|
# Rate limiting storage (in-memory, can be replaced with Redis for distributed systems)
|
|
23
|
-
_rate_limit_storage:
|
|
24
|
+
_rate_limit_storage: dict[str, dict[str, Any]] = defaultdict(dict)
|
|
24
25
|
|
|
25
26
|
|
|
26
27
|
def require_auth(redirect_to: str = "/login"):
|
|
@@ -70,10 +71,7 @@ def _is_production_environment() -> bool:
|
|
|
70
71
|
"""Check if running in production environment."""
|
|
71
72
|
import os
|
|
72
73
|
|
|
73
|
-
return (
|
|
74
|
-
os.getenv("G_NOME_ENV") == "production"
|
|
75
|
-
or os.getenv("ENVIRONMENT") == "production"
|
|
76
|
-
)
|
|
74
|
+
return os.getenv("G_NOME_ENV") == "production" or os.getenv("ENVIRONMENT") == "production"
|
|
77
75
|
|
|
78
76
|
|
|
79
77
|
def _validate_https(request: Request) -> None:
|
|
@@ -85,7 +83,7 @@ def _validate_https(request: Request) -> None:
|
|
|
85
83
|
)
|
|
86
84
|
|
|
87
85
|
|
|
88
|
-
async def _get_csrf_token(request: Request) ->
|
|
86
|
+
async def _get_csrf_token(request: Request) -> str | None:
|
|
89
87
|
"""Extract CSRF token from request headers or form data."""
|
|
90
88
|
csrf_token = request.headers.get("X-CSRF-Token")
|
|
91
89
|
if csrf_token:
|
|
@@ -154,8 +152,8 @@ def token_security(enforce_https: bool = True, check_csrf: bool = True):
|
|
|
154
152
|
|
|
155
153
|
def rate_limit_auth(
|
|
156
154
|
endpoint: str = "login",
|
|
157
|
-
max_attempts:
|
|
158
|
-
window_seconds:
|
|
155
|
+
max_attempts: int | None = None,
|
|
156
|
+
window_seconds: int | None = None,
|
|
159
157
|
):
|
|
160
158
|
"""
|
|
161
159
|
Rate limiting decorator for auth endpoints.
|
|
@@ -189,17 +187,13 @@ def rate_limit_auth(
|
|
|
189
187
|
|
|
190
188
|
# Use provided values or config values or defaults
|
|
191
189
|
if max_attempts is None:
|
|
192
|
-
max_attempts_val = (
|
|
193
|
-
rate_limit_config.get("max_attempts") if rate_limit_config else 5
|
|
194
|
-
)
|
|
190
|
+
max_attempts_val = rate_limit_config.get("max_attempts") if rate_limit_config else 5
|
|
195
191
|
else:
|
|
196
192
|
max_attempts_val = max_attempts
|
|
197
193
|
|
|
198
194
|
if window_seconds is None:
|
|
199
195
|
window_seconds_val = (
|
|
200
|
-
rate_limit_config.get("window_seconds")
|
|
201
|
-
if rate_limit_config
|
|
202
|
-
else 300
|
|
196
|
+
rate_limit_config.get("window_seconds") if rate_limit_config else 300
|
|
203
197
|
)
|
|
204
198
|
else:
|
|
205
199
|
window_seconds_val = window_seconds
|
|
@@ -249,7 +243,7 @@ def rate_limit_auth(
|
|
|
249
243
|
return decorator
|
|
250
244
|
|
|
251
245
|
|
|
252
|
-
def auto_token_setup(func:
|
|
246
|
+
def auto_token_setup(func: Callable[..., Awaitable[Any]] | None = None):
|
|
253
247
|
"""
|
|
254
248
|
Decorator to automatically set up tokens on successful login/register.
|
|
255
249
|
|