mdb-engine 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +13 -9
- mdb_engine/auth/__init__.py +18 -0
- mdb_engine/auth/csrf.py +651 -69
- mdb_engine/auth/provider.py +10 -0
- mdb_engine/auth/shared_users.py +73 -2
- mdb_engine/auth/users.py +2 -1
- mdb_engine/auth/utils.py +31 -6
- mdb_engine/auth/websocket_sessions.py +433 -0
- mdb_engine/auth/websocket_tickets.py +307 -0
- mdb_engine/core/app_registration.py +10 -0
- mdb_engine/core/engine.py +656 -21
- mdb_engine/core/manifest.py +26 -0
- mdb_engine/core/ray_integration.py +4 -4
- mdb_engine/core/types.py +2 -0
- mdb_engine/database/connection.py +6 -3
- mdb_engine/database/scoped_wrapper.py +3 -3
- mdb_engine/indexes/manager.py +3 -3
- mdb_engine/observability/health.py +7 -7
- mdb_engine/routing/README.md +9 -2
- mdb_engine/routing/websockets.py +479 -56
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/METADATA +128 -4
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/RECORD +26 -24
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/WHEEL +0 -0
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/top_level.txt +0 -0
mdb_engine/auth/csrf.py
CHANGED
|
@@ -102,14 +102,16 @@ def validate_csrf_token(
|
|
|
102
102
|
try:
|
|
103
103
|
parts = token.split(":")
|
|
104
104
|
if len(parts) != 3:
|
|
105
|
+
logger.debug("CSRF token has wrong format (expected 3 parts)")
|
|
105
106
|
return False
|
|
106
107
|
|
|
107
108
|
raw_token, timestamp_str, signature = parts
|
|
108
109
|
timestamp = int(timestamp_str)
|
|
109
110
|
|
|
110
111
|
# Check age
|
|
111
|
-
|
|
112
|
-
|
|
112
|
+
age = time.time() - timestamp
|
|
113
|
+
if age > max_age:
|
|
114
|
+
logger.debug(f"CSRF token expired (age: {age:.0f}s, max: {max_age}s)")
|
|
113
115
|
return False
|
|
114
116
|
|
|
115
117
|
# Verify signature
|
|
@@ -119,7 +121,10 @@ def validate_csrf_token(
|
|
|
119
121
|
]
|
|
120
122
|
|
|
121
123
|
if not hmac.compare_digest(signature, expected_sig):
|
|
122
|
-
logger.warning(
|
|
124
|
+
logger.warning(
|
|
125
|
+
f"CSRF token signature mismatch. "
|
|
126
|
+
f"Token format: signed, Has secret: {bool(secret)}"
|
|
127
|
+
)
|
|
123
128
|
return False
|
|
124
129
|
|
|
125
130
|
return True
|
|
@@ -128,7 +133,10 @@ def validate_csrf_token(
|
|
|
128
133
|
return False
|
|
129
134
|
|
|
130
135
|
# Simple token validation (just check it exists and has reasonable length)
|
|
131
|
-
|
|
136
|
+
is_valid = len(token) >= CSRF_TOKEN_LENGTH
|
|
137
|
+
if not is_valid:
|
|
138
|
+
logger.debug(f"CSRF token too short (length: {len(token)}, required: {CSRF_TOKEN_LENGTH})")
|
|
139
|
+
return is_valid
|
|
132
140
|
|
|
133
141
|
|
|
134
142
|
class CSRFMiddleware(BaseHTTPMiddleware):
|
|
@@ -197,10 +205,218 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|
|
197
205
|
return True
|
|
198
206
|
return False
|
|
199
207
|
|
|
208
|
+
def _validate_csrf_token(self, token: str, request: Request) -> bool:
    """
    Check a CSRF token against this middleware's signing secret and token TTL.

    Args:
        token: CSRF token string to check
        request: Incoming request; not consulted, accepted only so the signature
            lines up with the middleware's other per-request helpers

    Returns:
        Whatever ``validate_csrf_token`` reports for this token: True when it is
        accepted, False otherwise
    """
    signing_secret = self.secret
    max_age_seconds = self.token_ttl
    return validate_csrf_token(token, signing_secret, max_age_seconds)
|
|
220
|
+
|
|
221
|
+
def _websocket_requires_csrf(self, request: Request, path: str) -> bool:
    """
    Check if WebSocket endpoint requires CSRF validation.

    Defaults to True (security by default). Can be disabled per-endpoint via manifest.json:
        websockets.{endpoint}.auth.csrf_required = false

    ``websocket_configs`` is read from ``request.app.state``. If it is absent there, the
    chain of ``.app`` attributes is walked and the first ``state.websocket_configs`` found
    is used. It is iterated as app slug -> {endpoint name -> endpoint config}.

    Matching picks the MOST SPECIFIC candidate, so one app's setting cannot leak onto
    another app's endpoint just because it happens to be iterated first:
        0. exact full path:        /chat-app/ws == /chat-app/ws
        1. app-qualified suffix:   /prefix/chat-app/ws ends with /chat-app/ws
        2. exact endpoint path:    /ws == /ws
        3. bare endpoint suffix:   /anything/ws ends with /ws
    Ties at the same rank go to the first candidate seen. Suffix matches only happen on
    a path-segment boundary. An empty or "/" endpoint path is the app root; it is not a
    suffix of every path.

    Args:
        request: FastAPI request
        path: WebSocket path (e.g., "/chat-app/ws")

    Returns:
        True if CSRF validation is required, False otherwise
    """
    from collections.abc import Mapping

    # Try the request's own app first (where websocket_configs should be stored).
    # getattr on "state" avoids AttributeError for objects without a state attribute.
    websocket_configs = getattr(getattr(request.app, "state", None), "websocket_configs", None)

    # If not found, walk the chain of ``.app`` attributes looking for it
    if not websocket_configs:
        logger.debug(f"No websocket_configs found on request.app.state for path '{path}'")
        app = request.app
        apps_checked = []
        visited = set()  # guards against reference cycles of any length, not just self-loops
        while app is not None and id(app) not in visited:
            visited.add(id(app))
            app_title = getattr(app, "title", "unknown")
            apps_checked.append(app_title)
            websocket_configs = getattr(getattr(app, "state", None), "websocket_configs", None)
            if websocket_configs:
                logger.debug(
                    f"Found websocket_configs on app '{app_title}' "
                    f"(checked: {apps_checked})"
                )
                break
            app = getattr(app, "app", None)

    if not websocket_configs:
        # No WebSocket configs found - use default (CSRF required for security by default)
        logger.debug(
            f"No websocket_configs found anywhere for path '{path}' - "
            "using default csrf_required=true"
        )
        return True

    if not isinstance(websocket_configs, Mapping):
        logger.warning(
            f"websocket_configs has unexpected type {type(websocket_configs)} for path "
            f"'{path}' - using default csrf_required=true"
        )
        return True

    # Normalize path for matching (handle trailing slashes)
    normalized_path = path.rstrip("/")
    logger.debug(
        f"Checking CSRF requirement for path '{normalized_path}' "
        f"against {len(websocket_configs)} app config(s)"
    )

    best_rank = None
    best_match = None  # (csrf_required, app_slug, endpoint_name, normalized_endpoint)

    for app_slug, config in websocket_configs.items():
        if not isinstance(config, Mapping):
            logger.debug(f"Skipping app '{app_slug}': config is not a mapping ({type(config)})")
            continue
        logger.debug(f"Checking app '{app_slug}' config with {len(config)} endpoint(s)")

        for endpoint_name, endpoint_config in config.items():
            if not isinstance(endpoint_config, Mapping):
                continue

            # Normalize to "" (app root) or "/segment[/...]". Every non-root suffix tested
            # below then starts with "/", so endswith() only matches on a segment boundary.
            endpoint_segment = str(endpoint_config.get("path") or "").strip("/")
            normalized_endpoint = f"/{endpoint_segment}" if endpoint_segment else ""
            expected_full_path = f"/{app_slug}{normalized_endpoint}"

            if normalized_path == expected_full_path:
                rank = 0
            elif normalized_path.endswith(expected_full_path):
                rank = 1
            elif normalized_path == normalized_endpoint:
                rank = 2
            elif normalized_endpoint and normalized_path.endswith(normalized_endpoint):
                rank = 3
            else:
                continue

            auth_config = endpoint_config.get("auth", {})
            if not isinstance(auth_config, dict):
                logger.debug(
                    f"WebSocket '{normalized_path}' auth_config is not a dict: "
                    f"{type(auth_config)}"
                )
                continue

            # Keep the most specific match; on ties the first one seen wins
            if best_rank is None or rank < best_rank:
                best_rank = rank
                best_match = (
                    # csrf_required defaults to True - security by default
                    auth_config.get("csrf_required", True),
                    app_slug,
                    endpoint_name,
                    normalized_endpoint,
                )

    if best_match is not None:
        csrf_required, app_slug, endpoint_name, normalized_endpoint = best_match
        logger.info(
            f"✅ WebSocket '{normalized_path}' csrf_required={csrf_required} "
            f"(from app='{app_slug}', endpoint='{endpoint_name}', "
            f"endpoint_path='{normalized_endpoint}')"
        )
        return csrf_required

    # No matching config found - use default (CSRF required for security by default)
    logger.debug(
        f"❌ No WebSocket config match for '{normalized_path}' "
        f"(checked {len(websocket_configs)} app(s)) - using default csrf_required=true"
    )
    return True
|
|
324
|
+
|
|
200
325
|
def _is_websocket_upgrade(self, request: Request) -> bool:
    """
    Check if request is a WebSocket upgrade request.

    A request counts as an upgrade when it carries ``Upgrade: websocket`` and either:
      * its ``Connection`` header mentions "upgrade" or "websocket", or
      * ``Connection`` was overridden (e.g. by TestClient) and its path matches an
        endpoint listed in ``request.app.state.websocket_configs``.

    Args:
        request: Incoming request.

    Returns:
        True if the request should be treated as a WebSocket upgrade.
    """
    from collections.abc import Mapping

    upgrade_header = request.headers.get("upgrade", "").lower()
    connection_header = request.headers.get("connection", "").lower()
    path = request.url.path

    # Primary check: WebSocket upgrade requires both Upgrade: websocket
    # and Connection: Upgrade headers
    has_upgrade_header = upgrade_header == "websocket"
    has_connection_upgrade = "upgrade" in connection_header or "websocket" in connection_header

    # Secondary check: if the upgrade header is present but Connection is
    # overridden (e.g., by TestClient), check if the path matches a known
    # WebSocket route
    path_matches_websocket_route = False
    if has_upgrade_header and not has_connection_upgrade:
        websocket_configs = getattr(
            getattr(request.app, "state", None), "websocket_configs", None
        )
        if websocket_configs and isinstance(websocket_configs, Mapping):
            normalized_path = path.rstrip("/") or "/"

            def _route_matches(app_slug, endpoint_config) -> bool:
                """Return True if normalized_path addresses this configured endpoint."""
                segment = str(endpoint_config.get("path") or "").strip("/")
                if not segment:
                    # Root endpoint: "/" itself or the bare app prefix
                    return normalized_path in ("/", f"/{app_slug}")
                # The suffix starts with "/", so this only matches on a path-segment
                # boundary. It covers both /app-slug/endpoint and a bare /endpoint.
                return normalized_path.endswith(f"/{segment}")

            path_matches_websocket_route = any(
                _route_matches(app_slug, endpoint_config)
                for app_slug, config in websocket_configs.items()
                if isinstance(config, Mapping)
                for endpoint_config in config.values()
                if isinstance(endpoint_config, Mapping)
            )

    is_websocket = has_upgrade_header and (
        has_connection_upgrade or path_matches_websocket_route
    )

    # This runs for every request, so only a real detection is logged above DEBUG.
    if is_websocket:
        logger.info(
            f"WebSocket upgrade detected: path={path}, "
            f"upgrade={upgrade_header}, connection={connection_header}, "
            f"path_match={path_matches_websocket_route}"
        )
    else:
        logger.debug(
            f"_is_websocket_upgrade check: upgrade='{upgrade_header}', "
            f"connection='{connection_header}', path='{path}'"
        )
    return is_websocket
|
|
204
420
|
|
|
205
421
|
def _get_allowed_origins(self, request: Request) -> list[str]:
|
|
206
422
|
"""
|
|
@@ -239,33 +455,100 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|
|
239
455
|
except (AttributeError, TypeError, KeyError):
|
|
240
456
|
pass
|
|
241
457
|
|
|
242
|
-
# Final fallback: Use request host
|
|
458
|
+
# Final fallback: Use request host (normalize localhost variants)
|
|
243
459
|
try:
|
|
244
460
|
host = request.url.hostname
|
|
245
461
|
scheme = request.url.scheme
|
|
246
462
|
port = request.url.port
|
|
463
|
+
|
|
464
|
+
# Normalize localhost variants - return all common variants for development
|
|
465
|
+
# This handles cases where server binds to 0.0.0.0 but browser sends localhost
|
|
466
|
+
if host in ["localhost", "0.0.0.0", "127.0.0.1", "::1"]:
|
|
467
|
+
origins = []
|
|
468
|
+
for localhost_variant in ["localhost", "127.0.0.1"]:
|
|
469
|
+
if port and port not in [80, 443]:
|
|
470
|
+
origins.append(f"{scheme}://{localhost_variant}:{port}")
|
|
471
|
+
else:
|
|
472
|
+
origins.append(f"{scheme}://{localhost_variant}")
|
|
473
|
+
logger.debug(f"Generated localhost variant origins for host '{host}': {origins}")
|
|
474
|
+
return origins
|
|
475
|
+
|
|
476
|
+
# For other hosts, use the actual hostname
|
|
247
477
|
if port and port not in [80, 443]:
|
|
248
478
|
origin = f"{scheme}://{host}:{port}"
|
|
249
479
|
else:
|
|
250
480
|
origin = f"{scheme}://{host}"
|
|
251
481
|
return [origin]
|
|
252
|
-
except (AttributeError, TypeError):
|
|
482
|
+
except (AttributeError, TypeError) as e:
|
|
483
|
+
logger.debug(f"Could not determine origin from request: {e}")
|
|
253
484
|
# Return empty list if we can't determine origin (will reject)
|
|
254
485
|
return []
|
|
255
486
|
|
|
487
|
+
def _normalize_origin(self, origin: str) -> str:
|
|
488
|
+
"""
|
|
489
|
+
Normalize origin for comparison (handles localhost/0.0.0.0/127.0.0.1/::1 equivalency).
|
|
490
|
+
|
|
491
|
+
In development, localhost, 0.0.0.0, 127.0.0.1, and ::1 should be treated as equivalent.
|
|
492
|
+
Also normalizes ports (80/443 vs explicit ports).
|
|
493
|
+
"""
|
|
494
|
+
if not origin:
|
|
495
|
+
return origin
|
|
496
|
+
|
|
497
|
+
import re
|
|
498
|
+
|
|
499
|
+
# Normalize localhost variants - replace all variants with localhost
|
|
500
|
+
# Handle IPv4: 0.0.0.0, 127.0.0.1
|
|
501
|
+
# Handle IPv6: ::1
|
|
502
|
+
# Handle hostname: localhost
|
|
503
|
+
normalized = re.sub(
|
|
504
|
+
r"://(0\.0\.0\.0|127\.0\.0\.1|localhost|::1)",
|
|
505
|
+
"://localhost",
|
|
506
|
+
origin.lower(),
|
|
507
|
+
flags=re.IGNORECASE,
|
|
508
|
+
)
|
|
509
|
+
|
|
510
|
+
# Normalize ports: remove :80 for http and :443 for https
|
|
511
|
+
normalized = re.sub(r":80$", "", normalized)
|
|
512
|
+
normalized = re.sub(r":443$", "", normalized)
|
|
513
|
+
|
|
514
|
+
return normalized.rstrip("/")
|
|
515
|
+
|
|
516
|
+
def _is_development_mode(self) -> bool:
|
|
517
|
+
"""Check if running in development mode."""
|
|
518
|
+
import os
|
|
519
|
+
|
|
520
|
+
env = os.getenv("ENVIRONMENT", "").lower()
|
|
521
|
+
g_nome_env = os.getenv("G_NOME_ENV", "").lower()
|
|
522
|
+
return env in ["development", "dev"] or g_nome_env in ["development", "dev"]
|
|
523
|
+
|
|
256
524
|
def _validate_websocket_origin(self, request: Request) -> bool:
|
|
257
525
|
"""
|
|
258
526
|
Validate Origin header for WebSocket upgrade requests.
|
|
259
527
|
|
|
260
528
|
Primary defense against Cross-Site WebSocket Hijacking (CSWSH).
|
|
261
529
|
Returns True if Origin is valid, False otherwise.
|
|
530
|
+
|
|
531
|
+
In development mode, allows connections without Origin header (with warning).
|
|
262
532
|
"""
|
|
263
533
|
origin = request.headers.get("origin")
|
|
264
534
|
if not origin:
|
|
265
|
-
|
|
266
|
-
|
|
535
|
+
if self._is_development_mode():
|
|
536
|
+
logger.warning(
|
|
537
|
+
f"WebSocket upgrade missing Origin header in development mode: "
|
|
538
|
+
f"{request.url.path} - allowing connection"
|
|
539
|
+
)
|
|
540
|
+
return True
|
|
541
|
+
else:
|
|
542
|
+
logger.warning(f"WebSocket upgrade missing Origin header: {request.url.path}")
|
|
543
|
+
return False
|
|
267
544
|
|
|
268
545
|
allowed_origins = self._get_allowed_origins(request)
|
|
546
|
+
normalized_origin = self._normalize_origin(origin)
|
|
547
|
+
|
|
548
|
+
logger.debug(
|
|
549
|
+
f"Validating WebSocket origin: {origin} (normalized: {normalized_origin}) "
|
|
550
|
+
f"against allowed: {allowed_origins}"
|
|
551
|
+
)
|
|
269
552
|
|
|
270
553
|
for allowed in allowed_origins:
|
|
271
554
|
if allowed == "*":
|
|
@@ -274,14 +557,23 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|
|
274
557
|
"not recommended for production"
|
|
275
558
|
)
|
|
276
559
|
return True
|
|
277
|
-
|
|
560
|
+
|
|
561
|
+
normalized_allowed = self._normalize_origin(allowed)
|
|
562
|
+
if normalized_origin == normalized_allowed:
|
|
563
|
+
logger.debug(
|
|
564
|
+
f"✅ WebSocket origin validated: {origin} matches {allowed} "
|
|
565
|
+
f"(normalized: {normalized_origin} == {normalized_allowed})"
|
|
566
|
+
)
|
|
278
567
|
return True
|
|
279
568
|
|
|
280
569
|
cors_config = getattr(request.app.state, "cors_config", None)
|
|
281
570
|
cors_enabled = cors_config.get("enabled", False) if cors_config else False
|
|
571
|
+
normalized_allowed_list = [self._normalize_origin(a) for a in allowed_origins]
|
|
282
572
|
logger.warning(
|
|
283
|
-
f"WebSocket upgrade rejected - invalid Origin: {origin} "
|
|
284
|
-
f"(
|
|
573
|
+
f"❌ WebSocket upgrade rejected - invalid Origin: {origin} "
|
|
574
|
+
f"(normalized: {normalized_origin}, allowed: {allowed_origins}, "
|
|
575
|
+
f"normalized_allowed: {normalized_allowed_list}, "
|
|
576
|
+
f"app: {getattr(request.app, 'title', 'unknown')}, "
|
|
285
577
|
f"path: {request.url.path}, CORS enabled: {cors_enabled}, "
|
|
286
578
|
f"has_cors_config: {hasattr(request.app.state, 'cors_config')})"
|
|
287
579
|
)
|
|
@@ -295,79 +587,369 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|
|
295
587
|
"""
|
|
296
588
|
Process request through CSRF middleware.
|
|
297
589
|
"""
|
|
590
|
+
# CRITICAL: Log EVERY request immediately to catch WebSocket upgrades
|
|
298
591
|
path = request.url.path
|
|
299
592
|
method = request.method
|
|
593
|
+
upgrade_header = request.headers.get("upgrade", "").lower()
|
|
594
|
+
connection_header = request.headers.get("connection", "").lower()
|
|
595
|
+
origin_header = request.headers.get("origin")
|
|
596
|
+
|
|
597
|
+
# Log ALL WebSocket-related requests IMMEDIATELY (before any processing)
|
|
598
|
+
if upgrade_header or "websocket" in path.lower() or connection_header == "upgrade":
|
|
599
|
+
import sys
|
|
600
|
+
|
|
601
|
+
print(
|
|
602
|
+
f"🚨 [CSRF MIDDLEWARE ENTRY] {method} {path}, "
|
|
603
|
+
f"upgrade={upgrade_header}, connection={connection_header}, "
|
|
604
|
+
f"origin={origin_header}",
|
|
605
|
+
file=sys.stderr,
|
|
606
|
+
flush=True,
|
|
607
|
+
)
|
|
608
|
+
logger.info(
|
|
609
|
+
f"🚨 [CSRF MIDDLEWARE ENTRY] {method} {path}, "
|
|
610
|
+
f"upgrade={upgrade_header}, connection={connection_header}, "
|
|
611
|
+
f"origin={origin_header}"
|
|
612
|
+
)
|
|
300
613
|
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
f"
|
|
308
|
-
f"origin={
|
|
309
|
-
f"allowed={self._get_allowed_origins(request)}"
|
|
614
|
+
try:
|
|
615
|
+
# CRITICAL: Log ALL requests to verify middleware is running
|
|
616
|
+
# Always log WebSocket-related requests
|
|
617
|
+
if upgrade_header or "websocket" in path.lower() or connection_header == "upgrade":
|
|
618
|
+
logger.info(
|
|
619
|
+
f"🔍 CSRF middleware INTERCEPTED: {method} {path}, "
|
|
620
|
+
f"upgrade={upgrade_header}, connection={connection_header}, "
|
|
621
|
+
f"origin={origin_header}"
|
|
310
622
|
)
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
623
|
+
import sys
|
|
624
|
+
|
|
625
|
+
print(
|
|
626
|
+
f"🔍 [CSRF MIDDLEWARE] {method} {path}, "
|
|
627
|
+
f"upgrade={upgrade_header}, origin={origin_header}",
|
|
628
|
+
file=sys.stderr,
|
|
629
|
+
flush=True,
|
|
314
630
|
)
|
|
315
631
|
|
|
316
|
-
#
|
|
317
|
-
#
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
#
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
632
|
+
# CRITICAL: Handle WebSocket upgrade requests BEFORE other CSRF checks
|
|
633
|
+
# WebSocket upgrades use cookie-based authentication and require CSRF validation
|
|
634
|
+
is_ws_upgrade = self._is_websocket_upgrade(request)
|
|
635
|
+
logger.info(f"🔍 WebSocket upgrade detection for {path}: is_websocket={is_ws_upgrade}")
|
|
636
|
+
|
|
637
|
+
if is_ws_upgrade:
|
|
638
|
+
logger.info(
|
|
639
|
+
f"🔌 CSRF middleware processing WebSocket upgrade: {path}, "
|
|
640
|
+
f"origin: {request.headers.get('origin')}"
|
|
641
|
+
)
|
|
642
|
+
# Always validate origin for WebSocket connections (CSWSH protection)
|
|
643
|
+
origin_valid = self._validate_websocket_origin(request)
|
|
644
|
+
logger.info(
|
|
645
|
+
f"🔍 WebSocket origin validation for {path}: "
|
|
646
|
+
f"origin={request.headers.get('origin')}, "
|
|
647
|
+
f"allowed={self._get_allowed_origins(request)}, "
|
|
648
|
+
f"valid={origin_valid}"
|
|
649
|
+
)
|
|
650
|
+
if not origin_valid:
|
|
651
|
+
logger.warning(
|
|
652
|
+
f"❌ WebSocket origin validation failed for {path}: "
|
|
653
|
+
f"origin={request.headers.get('origin')}, "
|
|
654
|
+
f"allowed={self._get_allowed_origins(request)}"
|
|
655
|
+
)
|
|
334
656
|
return JSONResponse(
|
|
335
657
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
336
|
-
content={"detail": "
|
|
658
|
+
content={"detail": "Invalid origin for WebSocket connection"},
|
|
337
659
|
)
|
|
338
660
|
|
|
339
|
-
#
|
|
340
|
-
if
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
661
|
+
# Cookie-based authentication requires CSRF protection
|
|
662
|
+
# Check if authentication token cookie is present
|
|
663
|
+
# Use same cookie name as SharedAuthMiddleware for consistency
|
|
664
|
+
from .shared_middleware import AUTH_COOKIE_NAME
|
|
665
|
+
|
|
666
|
+
auth_token_cookie = request.cookies.get(AUTH_COOKIE_NAME)
|
|
667
|
+
logger.info(
|
|
668
|
+
f"🔍 WebSocket auth check for {path}: "
|
|
669
|
+
f"auth_cookie={'present' if auth_token_cookie else 'missing'}"
|
|
670
|
+
)
|
|
671
|
+
|
|
672
|
+
# Check if ticket/session key authentication is required
|
|
673
|
+
# (csrf_required flag controls whether WebSocket needs ticket/session key)
|
|
674
|
+
# If csrf_required=false, we skip ticket validation entirely
|
|
675
|
+
csrf_required = self._websocket_requires_csrf(request, path)
|
|
676
|
+
logger.info(
|
|
677
|
+
f"🔍 WebSocket auth check for {path}: "
|
|
678
|
+
f"ticket/session_key_required={csrf_required}"
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
# Only validate ticket/session key if:
|
|
682
|
+
# 1. Auth cookie is present (user is authenticated)
|
|
683
|
+
# 2. Ticket/session key is required for this endpoint
|
|
684
|
+
if auth_token_cookie and csrf_required:
|
|
685
|
+
# WebSocket Authentication (NOT CSRF):
|
|
686
|
+
# WebSockets use JWT → Ticket → WebSocket flow for authentication.
|
|
687
|
+
# CSRF protection comes from Origin validation + SameSite cookies.
|
|
688
|
+
#
|
|
689
|
+
# Authentication Methods (in order of preference):
|
|
690
|
+
# 1. Ticket (JWT → Ticket exchange) - preferred for single-app
|
|
691
|
+
# - Client: POST /auth/ticket (sends JWT cookie)
|
|
692
|
+
# - Server: Validates JWT, generates ticket (UUID)
|
|
693
|
+
# - Client: ws://host/app/ws?ticket=<uuid>
|
|
694
|
+
# - Server: Validates & consumes ticket (single-use)
|
|
695
|
+
# 2. Session key - preferred for multi-app SSO
|
|
696
|
+
# - Generated via /auth/websocket-session endpoint
|
|
697
|
+
# - Encrypted, database-backed, long TTL (24h)
|
|
698
|
+
# 3. CSRF cookie - backward compatibility only
|
|
699
|
+
#
|
|
700
|
+
# CSRF Protection (separate from authentication):
|
|
701
|
+
# - Origin validation (already done above) - primary CSRF defense
|
|
702
|
+
# - SameSite cookies - prevents cross-site cookie sending
|
|
703
|
+
#
|
|
704
|
+
# Ticket flow: JWT (httpOnly cookie) → POST /auth/ticket → Ticket (UUID)
|
|
705
|
+
# → WebSocket connection with ticket → Validated & consumed
|
|
706
|
+
|
|
707
|
+
# Check for session key first (preferred for multi-app setups)
|
|
708
|
+
session_key = request.query_params.get("session_key") or request.headers.get(
|
|
709
|
+
"X-WebSocket-Session-Key"
|
|
710
|
+
)
|
|
711
|
+
|
|
712
|
+
# Check for ticket (preferred for single-app setups)
|
|
713
|
+
ticket = request.query_params.get("ticket") or request.headers.get(
|
|
714
|
+
"X-WebSocket-Ticket"
|
|
349
715
|
)
|
|
350
716
|
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
717
|
+
if session_key:
|
|
718
|
+
# Session key authentication (bypasses CSRF - encrypted and secure)
|
|
719
|
+
# For WebSocket upgrades, let the handler validate session keys
|
|
720
|
+
# This allows TestClient to catch WebSocketDisconnect exceptions properly
|
|
721
|
+
# The handler will validate and raise WebSocketDisconnect if invalid
|
|
722
|
+
websocket_session_manager = None
|
|
723
|
+
app = request.app
|
|
724
|
+
apps_checked = []
|
|
725
|
+
while app:
|
|
726
|
+
app_title = getattr(app, "title", "unknown")
|
|
727
|
+
apps_checked.append(app_title)
|
|
728
|
+
websocket_session_manager = getattr(
|
|
729
|
+
app.state, "websocket_session_manager", None
|
|
730
|
+
)
|
|
731
|
+
if websocket_session_manager:
|
|
732
|
+
logger.debug(
|
|
733
|
+
f"Found websocket_session_manager on app '{app_title}' "
|
|
734
|
+
f"for WebSocket path '{path}' (checked: {apps_checked})"
|
|
735
|
+
)
|
|
736
|
+
break
|
|
737
|
+
parent_app = getattr(app, "app", None)
|
|
738
|
+
if parent_app is app: # Prevent infinite loop
|
|
739
|
+
break
|
|
740
|
+
app = parent_app
|
|
741
|
+
|
|
742
|
+
if not websocket_session_manager:
|
|
743
|
+
logger.error(
|
|
744
|
+
f"❌ WebSocket session key provided for {path} but "
|
|
745
|
+
"websocket_session_manager not found"
|
|
746
|
+
)
|
|
747
|
+
return JSONResponse(
|
|
748
|
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
749
|
+
content={
|
|
750
|
+
"detail": (
|
|
751
|
+
"WebSocket session manager not available. "
|
|
752
|
+
"Server configuration error."
|
|
753
|
+
)
|
|
754
|
+
},
|
|
755
|
+
)
|
|
756
|
+
|
|
757
|
+
# For WebSocket upgrades, let the handler validate the session key
|
|
758
|
+
# This ensures TestClient can catch WebSocketDisconnect exceptions
|
|
759
|
+
# The handler will validate and raise WebSocketDisconnect if invalid
|
|
760
|
+
logger.info(
|
|
761
|
+
f"✅ WebSocket session key provided for {path} - "
|
|
762
|
+
"CSRF validation bypassed (session key will be validated in handler)"
|
|
763
|
+
)
|
|
764
|
+
elif ticket:
|
|
765
|
+
# Ticket-based authentication (preferred)
|
|
766
|
+
# Get WebSocket ticket store
|
|
767
|
+
from ..routing.websockets import _global_websocket_ticket_store
|
|
768
|
+
|
|
769
|
+
websocket_ticket_store = _global_websocket_ticket_store
|
|
770
|
+
|
|
771
|
+
# Fallback: Try to get from app state (for backward compatibility)
|
|
772
|
+
if not websocket_ticket_store:
|
|
773
|
+
app = request.app
|
|
774
|
+
apps_checked = []
|
|
775
|
+
while app:
|
|
776
|
+
app_title = getattr(app, "title", "unknown")
|
|
777
|
+
apps_checked.append(app_title)
|
|
778
|
+
|
|
779
|
+
# Get ticket store
|
|
780
|
+
websocket_ticket_store = getattr(
|
|
781
|
+
app.state, "websocket_ticket_store", None
|
|
782
|
+
)
|
|
783
|
+
if websocket_ticket_store:
|
|
784
|
+
logger.debug(
|
|
785
|
+
f"Found websocket_ticket_store on app '{app_title}' "
|
|
786
|
+
f"for WebSocket path '{path}' (checked: {apps_checked})"
|
|
787
|
+
)
|
|
788
|
+
break
|
|
789
|
+
|
|
790
|
+
# Try to get parent app
|
|
791
|
+
parent_app = getattr(app, "app", None)
|
|
792
|
+
if parent_app is app: # Prevent infinite loop
|
|
793
|
+
break
|
|
794
|
+
app = parent_app
|
|
795
|
+
|
|
796
|
+
if not websocket_ticket_store:
|
|
797
|
+
logger.error(
|
|
798
|
+
f"❌ WebSocket ticket store not available for {path}. "
|
|
799
|
+
"Ticket authentication requires websocket_ticket_store "
|
|
800
|
+
"to be initialized."
|
|
801
|
+
)
|
|
802
|
+
return JSONResponse(
|
|
803
|
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
804
|
+
content={
|
|
805
|
+
"detail": (
|
|
806
|
+
"WebSocket ticket store not available. "
|
|
807
|
+
"Server configuration error."
|
|
808
|
+
)
|
|
809
|
+
},
|
|
810
|
+
)
|
|
811
|
+
|
|
812
|
+
# Validate and consume ticket (atomic operation - single-use)
|
|
813
|
+
try:
|
|
814
|
+
logger.info(
|
|
815
|
+
f"🔍 Validating WebSocket ticket for {path}: "
|
|
816
|
+
f"ticket={ticket[:16]}... (truncated)"
|
|
817
|
+
)
|
|
818
|
+
ticket_data = await websocket_ticket_store.validate_and_consume_ticket(
|
|
819
|
+
ticket
|
|
820
|
+
)
|
|
821
|
+
if not ticket_data:
|
|
822
|
+
logger.error(
|
|
823
|
+
f"❌ WebSocket ticket validation failed for {path}. "
|
|
824
|
+
f"Ticket: {ticket[:16]}... "
|
|
825
|
+
"Ticket may be expired, invalid, or already used."
|
|
826
|
+
)
|
|
827
|
+
return JSONResponse(
|
|
828
|
+
status_code=status.HTTP_403_FORBIDDEN,
|
|
829
|
+
content={
|
|
830
|
+
"detail": (
|
|
831
|
+
"WebSocket ticket expired or invalid. "
|
|
832
|
+
"Generate a new ticket via /auth/ticket endpoint."
|
|
833
|
+
)
|
|
834
|
+
},
|
|
835
|
+
)
|
|
836
|
+
|
|
837
|
+
# Store ticket data in request state for WebSocket handler
|
|
838
|
+
request.state.websocket_session = ticket_data
|
|
839
|
+
logger.info(
|
|
840
|
+
f"✅ WebSocket ticket validated for {path} "
|
|
841
|
+
f"(user_id: {ticket_data.get('user_id')}, "
|
|
842
|
+
f"user_email: {ticket_data.get('user_email')})"
|
|
843
|
+
)
|
|
844
|
+
except (
|
|
845
|
+
ValueError,
|
|
846
|
+
TypeError,
|
|
847
|
+
AttributeError,
|
|
848
|
+
RuntimeError,
|
|
849
|
+
) as e:
|
|
850
|
+
logger.error(
|
|
851
|
+
f"❌ Error validating WebSocket ticket for {path}: {e}",
|
|
852
|
+
exc_info=True,
|
|
853
|
+
)
|
|
854
|
+
return JSONResponse(
|
|
855
|
+
status_code=status.HTTP_403_FORBIDDEN,
|
|
856
|
+
content={
|
|
857
|
+
"detail": "WebSocket ticket validation error. "
|
|
858
|
+
"Generate a new ticket."
|
|
859
|
+
},
|
|
860
|
+
)
|
|
861
|
+
else:
|
|
862
|
+
# Fallback to CSRF cookie validation (backward compatibility)
|
|
863
|
+
# For WebSocket, CSRF header is optional (JS can't set headers on upgrade)
|
|
864
|
+
# but if provided, it must match the cookie
|
|
865
|
+
cookie_token = request.cookies.get(self.cookie_name)
|
|
866
|
+
header_token = request.headers.get(self.header_name)
|
|
867
|
+
|
|
868
|
+
if not cookie_token:
|
|
869
|
+
logger.error(
|
|
870
|
+
f"❌ WebSocket upgrade missing CSRF cookie for {path}. "
|
|
871
|
+
"CSRF protection is required. "
|
|
872
|
+
"Generate ticket via /auth/ticket endpoint or include CSRF cookie."
|
|
873
|
+
)
|
|
874
|
+
return JSONResponse(
|
|
875
|
+
status_code=status.HTTP_403_FORBIDDEN,
|
|
876
|
+
content={
|
|
877
|
+
"detail": (
|
|
878
|
+
"CSRF token missing. "
|
|
879
|
+
"Generate ticket via /auth/ticket endpoint "
|
|
880
|
+
"or include CSRF cookie/token."
|
|
881
|
+
)
|
|
882
|
+
},
|
|
883
|
+
)
|
|
884
|
+
|
|
885
|
+
# If header is provided, validate it matches the cookie
|
|
886
|
+
if header_token:
|
|
887
|
+
if not hmac.compare_digest(cookie_token, header_token):
|
|
888
|
+
logger.error(
|
|
889
|
+
f"❌ WebSocket CSRF token mismatch for {path}. "
|
|
890
|
+
"Header token does not match cookie token."
|
|
891
|
+
)
|
|
892
|
+
return JSONResponse(
|
|
893
|
+
status_code=status.HTTP_403_FORBIDDEN,
|
|
894
|
+
content={"detail": "CSRF token invalid."},
|
|
895
|
+
)
|
|
896
|
+
|
|
897
|
+
# Validate CSRF token (check signature if secret is used)
|
|
898
|
+
if not self._validate_csrf_token(cookie_token, request):
|
|
899
|
+
logger.error(f"❌ WebSocket CSRF token validation failed for {path}")
|
|
900
|
+
return JSONResponse(
|
|
901
|
+
status_code=status.HTTP_403_FORBIDDEN,
|
|
902
|
+
content={"detail": "CSRF token validation failed."},
|
|
903
|
+
)
|
|
904
|
+
|
|
905
|
+
logger.info(
|
|
906
|
+
f"✅ WebSocket CSRF cookie validated for {path} "
|
|
907
|
+
"(backward compatibility mode)"
|
|
361
908
|
)
|
|
909
|
+
elif auth_token_cookie and not csrf_required:
|
|
910
|
+
logger.info(
|
|
911
|
+
f"✅ WebSocket CSRF validation skipped for {path} "
|
|
912
|
+
f"(csrf_required=false) - only origin validation performed"
|
|
913
|
+
)
|
|
914
|
+
elif not auth_token_cookie:
|
|
915
|
+
logger.info(
|
|
916
|
+
f"✅ WebSocket connection allowed for {path} "
|
|
917
|
+
f"(no auth cookie - WebSocket handler will authenticate)"
|
|
918
|
+
)
|
|
362
919
|
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
920
|
+
validation_status = (
|
|
921
|
+
"CSRF/ticket validated"
|
|
922
|
+
if auth_token_cookie and csrf_required
|
|
923
|
+
else "CSRF skipped"
|
|
924
|
+
)
|
|
925
|
+
logger.info(
|
|
926
|
+
f"✅ WebSocket upgrade CSRF validation passed for {path} "
|
|
927
|
+
f"(Origin validated, {validation_status})"
|
|
366
928
|
)
|
|
367
929
|
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
930
|
+
# Origin validated (and CSRF/ticket validated if authenticated
|
|
931
|
+
# and csrf_required=true)
|
|
932
|
+
# Allow WebSocket upgrade to proceed to WebSocket handler
|
|
933
|
+
logger.debug(f"✅ WebSocket upgrade request allowed to proceed: {path}")
|
|
934
|
+
return await call_next(request)
|
|
935
|
+
|
|
936
|
+
except (
|
|
937
|
+
AttributeError,
|
|
938
|
+
KeyError,
|
|
939
|
+
RuntimeError,
|
|
940
|
+
ValueError,
|
|
941
|
+
TypeError,
|
|
942
|
+
ConnectionError,
|
|
943
|
+
) as e:
|
|
944
|
+
# Catch exceptions in WebSocket handling to see what's failing
|
|
945
|
+
logger.error(
|
|
946
|
+
f"❌ CRITICAL: Exception in CSRF middleware WebSocket handling: {e}", exc_info=True
|
|
947
|
+
)
|
|
948
|
+
import sys
|
|
949
|
+
|
|
950
|
+
print(f"❌ [CSRF MIDDLEWARE EXCEPTION] {e}", file=sys.stderr, flush=True)
|
|
951
|
+
# Re-raise to see the full error
|
|
952
|
+
raise
|
|
371
953
|
|
|
372
954
|
if self._is_exempt(path):
|
|
373
955
|
return await call_next(request)
|