mdb-engine 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/auth/__init__.py +9 -0
- mdb_engine/auth/csrf.py +493 -144
- mdb_engine/auth/provider.py +10 -0
- mdb_engine/auth/shared_users.py +41 -0
- mdb_engine/auth/users.py +2 -1
- mdb_engine/auth/websocket_tickets.py +307 -0
- mdb_engine/core/app_registration.py +10 -0
- mdb_engine/core/engine.py +632 -37
- mdb_engine/core/manifest.py +14 -0
- mdb_engine/core/ray_integration.py +4 -4
- mdb_engine/core/types.py +1 -0
- mdb_engine/database/connection.py +6 -3
- mdb_engine/database/scoped_wrapper.py +3 -3
- mdb_engine/indexes/manager.py +3 -3
- mdb_engine/observability/health.py +7 -7
- mdb_engine/routing/README.md +9 -2
- mdb_engine/routing/websockets.py +453 -74
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/METADATA +128 -4
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/RECORD +23 -22
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/WHEEL +0 -0
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.0.dist-info}/top_level.txt +0 -0
mdb_engine/auth/csrf.py
CHANGED
|
@@ -205,6 +205,19 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|
|
205
205
|
return True
|
|
206
206
|
return False
|
|
207
207
|
|
|
208
|
+
def _validate_csrf_token(self, token: str, request: Request) -> bool:
|
|
209
|
+
"""
|
|
210
|
+
Validate a CSRF token using the middleware's secret and TTL.
|
|
211
|
+
|
|
212
|
+
Args:
|
|
213
|
+
token: CSRF token to validate
|
|
214
|
+
request: FastAPI request (unused, kept for API consistency)
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
True if token is valid, False otherwise
|
|
218
|
+
"""
|
|
219
|
+
return validate_csrf_token(token, self.secret, self.token_ttl)
|
|
220
|
+
|
|
208
221
|
def _websocket_requires_csrf(self, request: Request, path: str) -> bool:
|
|
209
222
|
"""
|
|
210
223
|
Check if WebSocket endpoint requires CSRF validation.
|
|
@@ -214,63 +227,133 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|
|
214
227
|
|
|
215
228
|
Args:
|
|
216
229
|
request: FastAPI request
|
|
217
|
-
path: WebSocket path (e.g., "/app
|
|
230
|
+
path: WebSocket path (e.g., "/chat-app/ws")
|
|
218
231
|
|
|
219
232
|
Returns:
|
|
220
233
|
True if CSRF validation is required, False otherwise
|
|
221
234
|
"""
|
|
222
|
-
#
|
|
235
|
+
# Try parent app first (where websocket_configs should be stored)
|
|
223
236
|
websocket_configs = getattr(request.app.state, "websocket_configs", None)
|
|
237
|
+
|
|
238
|
+
# If not found, try to traverse up to find parent app
|
|
239
|
+
if not websocket_configs:
|
|
240
|
+
logger.debug(f"No websocket_configs found on request.app.state for path '{path}'")
|
|
241
|
+
app = request.app
|
|
242
|
+
apps_checked = []
|
|
243
|
+
while app:
|
|
244
|
+
app_title = getattr(app, "title", "unknown")
|
|
245
|
+
apps_checked.append(app_title)
|
|
246
|
+
websocket_configs = getattr(app.state, "websocket_configs", None)
|
|
247
|
+
if websocket_configs:
|
|
248
|
+
logger.debug(
|
|
249
|
+
f"Found websocket_configs on app '{app_title}' "
|
|
250
|
+
f"(checked: {apps_checked})"
|
|
251
|
+
)
|
|
252
|
+
break
|
|
253
|
+
parent_app = getattr(app, "app", None)
|
|
254
|
+
if parent_app is app: # Prevent infinite loop
|
|
255
|
+
break
|
|
256
|
+
app = parent_app
|
|
257
|
+
|
|
224
258
|
if not websocket_configs:
|
|
225
259
|
# No WebSocket configs found - use default (CSRF required for security by default)
|
|
260
|
+
logger.debug(
|
|
261
|
+
f"No websocket_configs found anywhere for path '{path}' - "
|
|
262
|
+
f"using default csrf_required=true"
|
|
263
|
+
)
|
|
226
264
|
return True
|
|
227
265
|
|
|
228
|
-
# Normalize path for matching
|
|
266
|
+
# Normalize path for matching (handle trailing slashes)
|
|
229
267
|
normalized_path = path.rstrip("/")
|
|
268
|
+
logger.debug(
|
|
269
|
+
f"Checking CSRF requirement for path '{normalized_path}' "
|
|
270
|
+
f"against {len(websocket_configs)} app config(s)"
|
|
271
|
+
)
|
|
230
272
|
|
|
231
273
|
# Try to find matching app config
|
|
232
274
|
# WebSocket paths are registered as /app-slug/endpoint-path
|
|
233
|
-
# e.g., /app
|
|
275
|
+
# e.g., /chat-app/ws where app_slug="chat-app" and endpoint_path="/ws"
|
|
234
276
|
for app_slug, config in websocket_configs.items():
|
|
277
|
+
logger.debug(f"Checking app '{app_slug}' config with {len(config)} endpoint(s)")
|
|
235
278
|
# Check each endpoint in this app's config
|
|
236
279
|
for endpoint_name, endpoint_config in config.items():
|
|
237
280
|
endpoint_path = endpoint_config.get("path", "")
|
|
238
281
|
# Normalize endpoint path
|
|
239
282
|
normalized_endpoint = endpoint_path.rstrip("/")
|
|
240
283
|
|
|
284
|
+
# Build expected full path: /app-slug/endpoint-path
|
|
285
|
+
if normalized_endpoint.startswith("/"):
|
|
286
|
+
expected_full_path = f"/{app_slug}{normalized_endpoint}"
|
|
287
|
+
else:
|
|
288
|
+
expected_full_path = f"/{app_slug}/{normalized_endpoint}"
|
|
289
|
+
|
|
241
290
|
# Match patterns:
|
|
242
|
-
# 1. Full path match: /app-
|
|
243
|
-
# 2. Endpoint-only match: /
|
|
244
|
-
|
|
245
|
-
|
|
291
|
+
# 1. Full path match: /chat-app/ws == /chat-app/ws
|
|
292
|
+
# 2. Endpoint-only match: /ws (if path ends with endpoint)
|
|
293
|
+
# 3. Path contains endpoint: /chat-app/ws contains /ws
|
|
294
|
+
matches = (
|
|
246
295
|
normalized_path == expected_full_path
|
|
247
|
-
or normalized_path.endswith(normalized_endpoint)
|
|
248
296
|
or normalized_path == normalized_endpoint
|
|
249
|
-
|
|
297
|
+
or normalized_path.endswith(normalized_endpoint)
|
|
298
|
+
or normalized_path.endswith(f"/{app_slug}{normalized_endpoint}")
|
|
299
|
+
)
|
|
300
|
+
|
|
301
|
+
if matches:
|
|
250
302
|
auth_config = endpoint_config.get("auth", {})
|
|
251
303
|
if isinstance(auth_config, dict):
|
|
252
304
|
# Return csrf_required setting (defaults to True - security by default)
|
|
253
305
|
csrf_required = auth_config.get("csrf_required", True)
|
|
254
|
-
logger.
|
|
255
|
-
f"WebSocket {
|
|
256
|
-
f"(from app={app_slug}, endpoint={endpoint_name}
|
|
306
|
+
logger.info(
|
|
307
|
+
f"✅ WebSocket '{normalized_path}' csrf_required={csrf_required} "
|
|
308
|
+
f"(from app='{app_slug}', endpoint='{endpoint_name}', "
|
|
309
|
+
f"endpoint_path='{normalized_endpoint}')"
|
|
257
310
|
)
|
|
258
311
|
return csrf_required
|
|
312
|
+
else:
|
|
313
|
+
logger.debug(
|
|
314
|
+
f"WebSocket '{normalized_path}' auth_config is not a dict: "
|
|
315
|
+
f"{type(auth_config)}"
|
|
316
|
+
)
|
|
259
317
|
|
|
260
318
|
# No matching config found - use default (CSRF required for security by default)
|
|
261
|
-
logger.debug(
|
|
319
|
+
logger.debug(
|
|
320
|
+
f"❌ No WebSocket config match for '{normalized_path}' "
|
|
321
|
+
f"(checked {len(websocket_configs)} app(s)) - using default csrf_required=true"
|
|
322
|
+
)
|
|
262
323
|
return True
|
|
263
324
|
|
|
264
325
|
def _is_websocket_upgrade(self, request: Request) -> bool:
|
|
265
326
|
"""Check if request is a WebSocket upgrade request."""
|
|
266
327
|
upgrade_header = request.headers.get("upgrade", "").lower()
|
|
267
328
|
connection_header = request.headers.get("connection", "").lower()
|
|
329
|
+
path = request.url.path
|
|
330
|
+
|
|
331
|
+
# CRITICAL: Enhanced logging for WebSocket detection
|
|
332
|
+
import sys
|
|
333
|
+
|
|
334
|
+
print(
|
|
335
|
+
f"🔍 [_is_websocket_upgrade] Path: {path}, "
|
|
336
|
+
f"upgrade='{upgrade_header}', connection='{connection_header}'",
|
|
337
|
+
file=sys.stderr,
|
|
338
|
+
flush=True,
|
|
339
|
+
)
|
|
340
|
+
logger.info(
|
|
341
|
+
f"_is_websocket_upgrade check: upgrade='{upgrade_header}', "
|
|
342
|
+
f"connection='{connection_header}', path='{path}'"
|
|
343
|
+
)
|
|
268
344
|
|
|
269
345
|
# Primary check: WebSocket upgrade requires both Upgrade: websocket
|
|
270
346
|
# and Connection: Upgrade headers
|
|
271
347
|
has_upgrade_header = upgrade_header == "websocket"
|
|
272
348
|
has_connection_upgrade = "upgrade" in connection_header or "websocket" in connection_header
|
|
273
349
|
|
|
350
|
+
print(
|
|
351
|
+
f"🔍 [_is_websocket_upgrade] has_upgrade={has_upgrade_header}, "
|
|
352
|
+
f"has_connection={has_connection_upgrade}",
|
|
353
|
+
file=sys.stderr,
|
|
354
|
+
flush=True,
|
|
355
|
+
)
|
|
356
|
+
|
|
274
357
|
# Secondary check: If upgrade header is present but connection is
|
|
275
358
|
# overridden (e.g., by TestClient), check if path matches a known
|
|
276
359
|
# WebSocket route pattern
|
|
@@ -308,12 +391,31 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|
|
308
391
|
is_websocket = has_upgrade_header and (
|
|
309
392
|
has_connection_upgrade or path_matches_websocket_route
|
|
310
393
|
)
|
|
394
|
+
|
|
395
|
+
# CRITICAL: Enhanced logging
|
|
396
|
+
import sys
|
|
397
|
+
|
|
311
398
|
if is_websocket:
|
|
312
|
-
|
|
313
|
-
f"WebSocket
|
|
399
|
+
print(
|
|
400
|
+
f"✅ [_is_websocket_upgrade] WebSocket detected: path={path}, "
|
|
401
|
+
f"upgrade={upgrade_header}, connection={connection_header}, "
|
|
402
|
+
f"path_match={path_matches_websocket_route}, result={is_websocket}",
|
|
403
|
+
file=sys.stderr,
|
|
404
|
+
flush=True,
|
|
405
|
+
)
|
|
406
|
+
logger.info(
|
|
407
|
+
f"WebSocket upgrade detected: path={path}, "
|
|
314
408
|
f"upgrade={upgrade_header}, connection={connection_header}, "
|
|
315
409
|
f"path_match={path_matches_websocket_route}"
|
|
316
410
|
)
|
|
411
|
+
else:
|
|
412
|
+
print(
|
|
413
|
+
f"❌ [_is_websocket_upgrade] NOT a WebSocket: path={path}, "
|
|
414
|
+
f"upgrade={upgrade_header}, connection={connection_header}, "
|
|
415
|
+
f"has_upgrade={has_upgrade_header}, has_connection={has_connection_upgrade}",
|
|
416
|
+
file=sys.stderr,
|
|
417
|
+
flush=True,
|
|
418
|
+
)
|
|
317
419
|
return is_websocket
|
|
318
420
|
|
|
319
421
|
def _get_allowed_origins(self, request: Request) -> list[str]:
|
|
@@ -353,33 +455,100 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|
|
353
455
|
except (AttributeError, TypeError, KeyError):
|
|
354
456
|
pass
|
|
355
457
|
|
|
356
|
-
# Final fallback: Use request host
|
|
458
|
+
# Final fallback: Use request host (normalize localhost variants)
|
|
357
459
|
try:
|
|
358
460
|
host = request.url.hostname
|
|
359
461
|
scheme = request.url.scheme
|
|
360
462
|
port = request.url.port
|
|
463
|
+
|
|
464
|
+
# Normalize localhost variants - return all common variants for development
|
|
465
|
+
# This handles cases where server binds to 0.0.0.0 but browser sends localhost
|
|
466
|
+
if host in ["localhost", "0.0.0.0", "127.0.0.1", "::1"]:
|
|
467
|
+
origins = []
|
|
468
|
+
for localhost_variant in ["localhost", "127.0.0.1"]:
|
|
469
|
+
if port and port not in [80, 443]:
|
|
470
|
+
origins.append(f"{scheme}://{localhost_variant}:{port}")
|
|
471
|
+
else:
|
|
472
|
+
origins.append(f"{scheme}://{localhost_variant}")
|
|
473
|
+
logger.debug(f"Generated localhost variant origins for host '{host}': {origins}")
|
|
474
|
+
return origins
|
|
475
|
+
|
|
476
|
+
# For other hosts, use the actual hostname
|
|
361
477
|
if port and port not in [80, 443]:
|
|
362
478
|
origin = f"{scheme}://{host}:{port}"
|
|
363
479
|
else:
|
|
364
480
|
origin = f"{scheme}://{host}"
|
|
365
481
|
return [origin]
|
|
366
|
-
except (AttributeError, TypeError):
|
|
482
|
+
except (AttributeError, TypeError) as e:
|
|
483
|
+
logger.debug(f"Could not determine origin from request: {e}")
|
|
367
484
|
# Return empty list if we can't determine origin (will reject)
|
|
368
485
|
return []
|
|
369
486
|
|
|
487
|
+
def _normalize_origin(self, origin: str) -> str:
|
|
488
|
+
"""
|
|
489
|
+
Normalize origin for comparison (handles localhost/0.0.0.0/127.0.0.1/::1 equivalency).
|
|
490
|
+
|
|
491
|
+
In development, localhost, 0.0.0.0, 127.0.0.1, and ::1 should be treated as equivalent.
|
|
492
|
+
Also normalizes ports (80/443 vs explicit ports).
|
|
493
|
+
"""
|
|
494
|
+
if not origin:
|
|
495
|
+
return origin
|
|
496
|
+
|
|
497
|
+
import re
|
|
498
|
+
|
|
499
|
+
# Normalize localhost variants - replace all variants with localhost
|
|
500
|
+
# Handle IPv4: 0.0.0.0, 127.0.0.1
|
|
501
|
+
# Handle IPv6: ::1
|
|
502
|
+
# Handle hostname: localhost
|
|
503
|
+
normalized = re.sub(
|
|
504
|
+
r"://(0\.0\.0\.0|127\.0\.0\.1|localhost|::1)",
|
|
505
|
+
"://localhost",
|
|
506
|
+
origin.lower(),
|
|
507
|
+
flags=re.IGNORECASE,
|
|
508
|
+
)
|
|
509
|
+
|
|
510
|
+
# Normalize ports: remove :80 for http and :443 for https
|
|
511
|
+
normalized = re.sub(r":80$", "", normalized)
|
|
512
|
+
normalized = re.sub(r":443$", "", normalized)
|
|
513
|
+
|
|
514
|
+
return normalized.rstrip("/")
|
|
515
|
+
|
|
516
|
+
def _is_development_mode(self) -> bool:
|
|
517
|
+
"""Check if running in development mode."""
|
|
518
|
+
import os
|
|
519
|
+
|
|
520
|
+
env = os.getenv("ENVIRONMENT", "").lower()
|
|
521
|
+
g_nome_env = os.getenv("G_NOME_ENV", "").lower()
|
|
522
|
+
return env in ["development", "dev"] or g_nome_env in ["development", "dev"]
|
|
523
|
+
|
|
370
524
|
def _validate_websocket_origin(self, request: Request) -> bool:
|
|
371
525
|
"""
|
|
372
526
|
Validate Origin header for WebSocket upgrade requests.
|
|
373
527
|
|
|
374
528
|
Primary defense against Cross-Site WebSocket Hijacking (CSWSH).
|
|
375
529
|
Returns True if Origin is valid, False otherwise.
|
|
530
|
+
|
|
531
|
+
In development mode, allows connections without Origin header (with warning).
|
|
376
532
|
"""
|
|
377
533
|
origin = request.headers.get("origin")
|
|
378
534
|
if not origin:
|
|
379
|
-
|
|
380
|
-
|
|
535
|
+
if self._is_development_mode():
|
|
536
|
+
logger.warning(
|
|
537
|
+
f"WebSocket upgrade missing Origin header in development mode: "
|
|
538
|
+
f"{request.url.path} - allowing connection"
|
|
539
|
+
)
|
|
540
|
+
return True
|
|
541
|
+
else:
|
|
542
|
+
logger.warning(f"WebSocket upgrade missing Origin header: {request.url.path}")
|
|
543
|
+
return False
|
|
381
544
|
|
|
382
545
|
allowed_origins = self._get_allowed_origins(request)
|
|
546
|
+
normalized_origin = self._normalize_origin(origin)
|
|
547
|
+
|
|
548
|
+
logger.debug(
|
|
549
|
+
f"Validating WebSocket origin: {origin} (normalized: {normalized_origin}) "
|
|
550
|
+
f"against allowed: {allowed_origins}"
|
|
551
|
+
)
|
|
383
552
|
|
|
384
553
|
for allowed in allowed_origins:
|
|
385
554
|
if allowed == "*":
|
|
@@ -388,14 +557,23 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|
|
388
557
|
"not recommended for production"
|
|
389
558
|
)
|
|
390
559
|
return True
|
|
391
|
-
|
|
560
|
+
|
|
561
|
+
normalized_allowed = self._normalize_origin(allowed)
|
|
562
|
+
if normalized_origin == normalized_allowed:
|
|
563
|
+
logger.debug(
|
|
564
|
+
f"✅ WebSocket origin validated: {origin} matches {allowed} "
|
|
565
|
+
f"(normalized: {normalized_origin} == {normalized_allowed})"
|
|
566
|
+
)
|
|
392
567
|
return True
|
|
393
568
|
|
|
394
569
|
cors_config = getattr(request.app.state, "cors_config", None)
|
|
395
570
|
cors_enabled = cors_config.get("enabled", False) if cors_config else False
|
|
571
|
+
normalized_allowed_list = [self._normalize_origin(a) for a in allowed_origins]
|
|
396
572
|
logger.warning(
|
|
397
|
-
f"WebSocket upgrade rejected - invalid Origin: {origin} "
|
|
398
|
-
f"(
|
|
573
|
+
f"❌ WebSocket upgrade rejected - invalid Origin: {origin} "
|
|
574
|
+
f"(normalized: {normalized_origin}, allowed: {allowed_origins}, "
|
|
575
|
+
f"normalized_allowed: {normalized_allowed_list}, "
|
|
576
|
+
f"app: {getattr(request.app, 'title', 'unknown')}, "
|
|
399
577
|
f"path: {request.url.path}, CORS enabled: {cors_enabled}, "
|
|
400
578
|
f"has_cors_config: {hasattr(request.app.state, 'cors_config')})"
|
|
401
579
|
)
|
|
@@ -409,198 +587,369 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|
|
409
587
|
"""
|
|
410
588
|
Process request through CSRF middleware.
|
|
411
589
|
"""
|
|
590
|
+
# CRITICAL: Log EVERY request immediately to catch WebSocket upgrades
|
|
412
591
|
path = request.url.path
|
|
413
592
|
method = request.method
|
|
414
|
-
|
|
415
|
-
# Debug: Log all requests to see what's happening
|
|
416
593
|
upgrade_header = request.headers.get("upgrade", "").lower()
|
|
417
594
|
connection_header = request.headers.get("connection", "").lower()
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
595
|
+
origin_header = request.headers.get("origin")
|
|
596
|
+
|
|
597
|
+
# Log ALL WebSocket-related requests IMMEDIATELY (before any processing)
|
|
598
|
+
if upgrade_header or "websocket" in path.lower() or connection_header == "upgrade":
|
|
599
|
+
import sys
|
|
423
600
|
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
601
|
+
print(
|
|
602
|
+
f"🚨 [CSRF MIDDLEWARE ENTRY] {method} {path}, "
|
|
603
|
+
f"upgrade={upgrade_header}, connection={connection_header}, "
|
|
604
|
+
f"origin={origin_header}",
|
|
605
|
+
file=sys.stderr,
|
|
606
|
+
flush=True,
|
|
607
|
+
)
|
|
427
608
|
logger.info(
|
|
428
|
-
f"
|
|
429
|
-
f"
|
|
609
|
+
f"🚨 [CSRF MIDDLEWARE ENTRY] {method} {path}, "
|
|
610
|
+
f"upgrade={upgrade_header}, connection={connection_header}, "
|
|
611
|
+
f"origin={origin_header}"
|
|
430
612
|
)
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
613
|
+
|
|
614
|
+
try:
|
|
615
|
+
# CRITICAL: Log ALL requests to verify middleware is running
|
|
616
|
+
# Always log WebSocket-related requests
|
|
617
|
+
if upgrade_header or "websocket" in path.lower() or connection_header == "upgrade":
|
|
618
|
+
logger.info(
|
|
619
|
+
f"🔍 CSRF middleware INTERCEPTED: {method} {path}, "
|
|
620
|
+
f"upgrade={upgrade_header}, connection={connection_header}, "
|
|
621
|
+
f"origin={origin_header}"
|
|
622
|
+
)
|
|
623
|
+
import sys
|
|
624
|
+
|
|
625
|
+
print(
|
|
626
|
+
f"🔍 [CSRF MIDDLEWARE] {method} {path}, "
|
|
627
|
+
f"upgrade={upgrade_header}, origin={origin_header}",
|
|
628
|
+
file=sys.stderr,
|
|
629
|
+
flush=True,
|
|
630
|
+
)
|
|
631
|
+
|
|
632
|
+
# CRITICAL: Handle WebSocket upgrade requests BEFORE other CSRF checks
|
|
633
|
+
# WebSocket upgrades use cookie-based authentication and require CSRF validation
|
|
634
|
+
is_ws_upgrade = self._is_websocket_upgrade(request)
|
|
635
|
+
logger.info(f"🔍 WebSocket upgrade detection for {path}: is_websocket={is_ws_upgrade}")
|
|
636
|
+
|
|
637
|
+
if is_ws_upgrade:
|
|
638
|
+
logger.info(
|
|
639
|
+
f"🔌 CSRF middleware processing WebSocket upgrade: {path}, "
|
|
640
|
+
f"origin: {request.headers.get('origin')}"
|
|
641
|
+
)
|
|
642
|
+
# Always validate origin for WebSocket connections (CSWSH protection)
|
|
643
|
+
origin_valid = self._validate_websocket_origin(request)
|
|
644
|
+
logger.info(
|
|
645
|
+
f"🔍 WebSocket origin validation for {path}: "
|
|
435
646
|
f"origin={request.headers.get('origin')}, "
|
|
436
|
-
f"allowed={self._get_allowed_origins(request)}"
|
|
647
|
+
f"allowed={self._get_allowed_origins(request)}, "
|
|
648
|
+
f"valid={origin_valid}"
|
|
437
649
|
)
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
650
|
+
if not origin_valid:
|
|
651
|
+
logger.warning(
|
|
652
|
+
f"❌ WebSocket origin validation failed for {path}: "
|
|
653
|
+
f"origin={request.headers.get('origin')}, "
|
|
654
|
+
f"allowed={self._get_allowed_origins(request)}"
|
|
655
|
+
)
|
|
656
|
+
return JSONResponse(
|
|
657
|
+
status_code=status.HTTP_403_FORBIDDEN,
|
|
658
|
+
content={"detail": "Invalid origin for WebSocket connection"},
|
|
659
|
+
)
|
|
660
|
+
|
|
661
|
+
# Cookie-based authentication requires CSRF protection
|
|
662
|
+
# Check if authentication token cookie is present
|
|
663
|
+
# Use same cookie name as SharedAuthMiddleware for consistency
|
|
664
|
+
from .shared_middleware import AUTH_COOKIE_NAME
|
|
665
|
+
|
|
666
|
+
auth_token_cookie = request.cookies.get(AUTH_COOKIE_NAME)
|
|
667
|
+
logger.info(
|
|
668
|
+
f"🔍 WebSocket auth check for {path}: "
|
|
669
|
+
f"auth_cookie={'present' if auth_token_cookie else 'missing'}"
|
|
441
670
|
)
|
|
442
671
|
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
from .shared_middleware import AUTH_COOKIE_NAME
|
|
447
|
-
|
|
448
|
-
auth_token_cookie = request.cookies.get(AUTH_COOKIE_NAME)
|
|
449
|
-
if auth_token_cookie:
|
|
450
|
-
# SECURITY BY DEFAULT: WebSocket CSRF protection uses encrypted session keys
|
|
451
|
-
# stored in private collection via envelope encryption.
|
|
452
|
-
#
|
|
453
|
-
# Security Model:
|
|
454
|
-
# 1. Origin validation (already done above) - primary defense
|
|
455
|
-
# 2. Encrypted session key validation - CSRF protection via database
|
|
456
|
-
# 3. SameSite cookies - prevents cross-site cookie sending
|
|
457
|
-
#
|
|
458
|
-
# Session keys are:
|
|
459
|
-
# - Generated on authentication
|
|
460
|
-
# - Encrypted using envelope encryption (same as app secrets)
|
|
461
|
-
# - Stored in _mdb_engine_websocket_sessions private collection
|
|
462
|
-
# - Validated during WebSocket upgrade
|
|
463
|
-
|
|
464
|
-
# Check if this WebSocket endpoint requires CSRF validation
|
|
672
|
+
# Check if ticket/session key authentication is required
|
|
673
|
+
# (csrf_required flag controls whether WebSocket needs ticket/session key)
|
|
674
|
+
# If csrf_required=false, we skip ticket validation entirely
|
|
465
675
|
csrf_required = self._websocket_requires_csrf(request, path)
|
|
676
|
+
logger.info(
|
|
677
|
+
f"🔍 WebSocket auth check for {path}: "
|
|
678
|
+
f"ticket/session_key_required={csrf_required}"
|
|
679
|
+
)
|
|
466
680
|
|
|
467
|
-
if
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
681
|
+
# Only validate ticket/session key if:
|
|
682
|
+
# 1. Auth cookie is present (user is authenticated)
|
|
683
|
+
# 2. Ticket/session key is required for this endpoint
|
|
684
|
+
if auth_token_cookie and csrf_required:
|
|
685
|
+
# WebSocket Authentication (NOT CSRF):
|
|
686
|
+
# WebSockets use JWT → Ticket → WebSocket flow for authentication.
|
|
687
|
+
# CSRF protection comes from Origin validation + SameSite cookies.
|
|
688
|
+
#
|
|
689
|
+
# Authentication Methods (in order of preference):
|
|
690
|
+
# 1. Ticket (JWT → Ticket exchange) - preferred for single-app
|
|
691
|
+
# - Client: POST /auth/ticket (sends JWT cookie)
|
|
692
|
+
# - Server: Validates JWT, generates ticket (UUID)
|
|
693
|
+
# - Client: ws://host/app/ws?ticket=<uuid>
|
|
694
|
+
# - Server: Validates & consumes ticket (single-use)
|
|
695
|
+
# 2. Session key - preferred for multi-app SSO
|
|
696
|
+
# - Generated via /auth/websocket-session endpoint
|
|
697
|
+
# - Encrypted, database-backed, long TTL (24h)
|
|
698
|
+
# 3. CSRF cookie - backward compatibility only
|
|
699
|
+
#
|
|
700
|
+
# CSRF Protection (separate from authentication):
|
|
701
|
+
# - Origin validation (already done above) - primary CSRF defense
|
|
702
|
+
# - SameSite cookies - prevents cross-site cookie sending
|
|
703
|
+
#
|
|
704
|
+
# Ticket flow: JWT (httpOnly cookie) → POST /auth/ticket → Ticket (UUID)
|
|
705
|
+
# → WebSocket connection with ticket → Validated & consumed
|
|
706
|
+
|
|
707
|
+
# Check for session key first (preferred for multi-app setups)
|
|
708
|
+
session_key = request.query_params.get("session_key") or request.headers.get(
|
|
709
|
+
"X-WebSocket-Session-Key"
|
|
471
710
|
)
|
|
472
711
|
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
) or request.headers.get("X-WebSocket-Session-Key")
|
|
712
|
+
# Check for ticket (preferred for single-app setups)
|
|
713
|
+
ticket = request.query_params.get("ticket") or request.headers.get(
|
|
714
|
+
"X-WebSocket-Ticket"
|
|
715
|
+
)
|
|
478
716
|
|
|
479
|
-
|
|
717
|
+
if session_key:
|
|
718
|
+
# Session key authentication (bypasses CSRF - encrypted and secure)
|
|
719
|
+
# For WebSocket upgrades, let the handler validate session keys
|
|
720
|
+
# This allows TestClient to catch WebSocketDisconnect exceptions properly
|
|
721
|
+
# The handler will validate and raise WebSocketDisconnect if invalid
|
|
722
|
+
websocket_session_manager = None
|
|
723
|
+
app = request.app
|
|
724
|
+
apps_checked = []
|
|
725
|
+
while app:
|
|
726
|
+
app_title = getattr(app, "title", "unknown")
|
|
727
|
+
apps_checked.append(app_title)
|
|
728
|
+
websocket_session_manager = getattr(
|
|
729
|
+
app.state, "websocket_session_manager", None
|
|
730
|
+
)
|
|
731
|
+
if websocket_session_manager:
|
|
732
|
+
logger.debug(
|
|
733
|
+
f"Found websocket_session_manager on app '{app_title}' "
|
|
734
|
+
f"for WebSocket path '{path}' (checked: {apps_checked})"
|
|
735
|
+
)
|
|
736
|
+
break
|
|
737
|
+
parent_app = getattr(app, "app", None)
|
|
738
|
+
if parent_app is app: # Prevent infinite loop
|
|
739
|
+
break
|
|
740
|
+
app = parent_app
|
|
741
|
+
|
|
742
|
+
if not websocket_session_manager:
|
|
480
743
|
logger.error(
|
|
481
|
-
f"❌ WebSocket
|
|
482
|
-
|
|
483
|
-
f"Tip: Generate session key via /auth/websocket-session endpoint."
|
|
744
|
+
f"❌ WebSocket session key provided for {path} but "
|
|
745
|
+
"websocket_session_manager not found"
|
|
484
746
|
)
|
|
485
747
|
return JSONResponse(
|
|
486
|
-
status_code=status.
|
|
748
|
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
749
|
+
content={
|
|
750
|
+
"detail": (
|
|
751
|
+
"WebSocket session manager not available. "
|
|
752
|
+
"Server configuration error."
|
|
753
|
+
)
|
|
754
|
+
},
|
|
755
|
+
)
|
|
756
|
+
|
|
757
|
+
# For WebSocket upgrades, let the handler validate the session key
|
|
758
|
+
# This ensures TestClient can catch WebSocketDisconnect exceptions
|
|
759
|
+
# The handler will validate and raise WebSocketDisconnect if invalid
|
|
760
|
+
logger.info(
|
|
761
|
+
f"✅ WebSocket session key provided for {path} - "
|
|
762
|
+
"CSRF validation bypassed (session key will be validated in handler)"
|
|
763
|
+
)
|
|
764
|
+
elif ticket:
|
|
765
|
+
# Ticket-based authentication (preferred)
|
|
766
|
+
# Get WebSocket ticket store
|
|
767
|
+
from ..routing.websockets import _global_websocket_ticket_store
|
|
768
|
+
|
|
769
|
+
websocket_ticket_store = _global_websocket_ticket_store
|
|
770
|
+
|
|
771
|
+
# Fallback: Try to get from app state (for backward compatibility)
|
|
772
|
+
if not websocket_ticket_store:
|
|
773
|
+
app = request.app
|
|
774
|
+
apps_checked = []
|
|
775
|
+
while app:
|
|
776
|
+
app_title = getattr(app, "title", "unknown")
|
|
777
|
+
apps_checked.append(app_title)
|
|
778
|
+
|
|
779
|
+
# Get ticket store
|
|
780
|
+
websocket_ticket_store = getattr(
|
|
781
|
+
app.state, "websocket_ticket_store", None
|
|
782
|
+
)
|
|
783
|
+
if websocket_ticket_store:
|
|
784
|
+
logger.debug(
|
|
785
|
+
f"Found websocket_ticket_store on app '{app_title}' "
|
|
786
|
+
f"for WebSocket path '{path}' (checked: {apps_checked})"
|
|
787
|
+
)
|
|
788
|
+
break
|
|
789
|
+
|
|
790
|
+
# Try to get parent app
|
|
791
|
+
parent_app = getattr(app, "app", None)
|
|
792
|
+
if parent_app is app: # Prevent infinite loop
|
|
793
|
+
break
|
|
794
|
+
app = parent_app
|
|
795
|
+
|
|
796
|
+
if not websocket_ticket_store:
|
|
797
|
+
logger.error(
|
|
798
|
+
f"❌ WebSocket ticket store not available for {path}. "
|
|
799
|
+
"Ticket authentication requires websocket_ticket_store "
|
|
800
|
+
"to be initialized."
|
|
801
|
+
)
|
|
802
|
+
return JSONResponse(
|
|
803
|
+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
487
804
|
content={
|
|
488
805
|
"detail": (
|
|
489
|
-
"WebSocket
|
|
490
|
-
"
|
|
806
|
+
"WebSocket ticket store not available. "
|
|
807
|
+
"Server configuration error."
|
|
491
808
|
)
|
|
492
809
|
},
|
|
493
810
|
)
|
|
494
811
|
|
|
495
|
-
# Validate
|
|
812
|
+
# Validate and consume ticket (atomic operation - single-use)
|
|
496
813
|
try:
|
|
497
|
-
|
|
498
|
-
|
|
814
|
+
logger.info(
|
|
815
|
+
f"🔍 Validating WebSocket ticket for {path}: "
|
|
816
|
+
f"ticket={ticket[:16]}... (truncated)"
|
|
817
|
+
)
|
|
818
|
+
ticket_data = await websocket_ticket_store.validate_and_consume_ticket(
|
|
819
|
+
ticket
|
|
499
820
|
)
|
|
500
|
-
if not
|
|
821
|
+
if not ticket_data:
|
|
501
822
|
logger.error(
|
|
502
|
-
f"❌ WebSocket
|
|
503
|
-
f"
|
|
823
|
+
f"❌ WebSocket ticket validation failed for {path}. "
|
|
824
|
+
f"Ticket: {ticket[:16]}... "
|
|
825
|
+
"Ticket may be expired, invalid, or already used."
|
|
504
826
|
)
|
|
505
827
|
return JSONResponse(
|
|
506
828
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
507
829
|
content={
|
|
508
830
|
"detail": (
|
|
509
|
-
"WebSocket
|
|
510
|
-
"Generate a new
|
|
831
|
+
"WebSocket ticket expired or invalid. "
|
|
832
|
+
"Generate a new ticket via /auth/ticket endpoint."
|
|
511
833
|
)
|
|
512
834
|
},
|
|
513
835
|
)
|
|
514
836
|
|
|
515
|
-
# Store
|
|
516
|
-
request.state.websocket_session =
|
|
517
|
-
logger.
|
|
518
|
-
f"✅ WebSocket
|
|
519
|
-
f"(
|
|
837
|
+
# Store ticket data in request state for WebSocket handler
|
|
838
|
+
request.state.websocket_session = ticket_data
|
|
839
|
+
logger.info(
|
|
840
|
+
f"✅ WebSocket ticket validated for {path} "
|
|
841
|
+
f"(user_id: {ticket_data.get('user_id')}, "
|
|
842
|
+
f"user_email: {ticket_data.get('user_email')})"
|
|
520
843
|
)
|
|
521
844
|
except (
|
|
522
845
|
ValueError,
|
|
523
846
|
TypeError,
|
|
524
847
|
AttributeError,
|
|
525
848
|
RuntimeError,
|
|
526
|
-
):
|
|
527
|
-
logger.exception("Error validating WebSocket session key")
|
|
528
|
-
return JSONResponse(
|
|
529
|
-
status_code=status.HTTP_403_FORBIDDEN,
|
|
530
|
-
content={"detail": "WebSocket session validation error"},
|
|
531
|
-
)
|
|
532
|
-
else:
|
|
533
|
-
# Fallback to cookie-based CSRF (backward compatibility)
|
|
534
|
-
csrf_cookie_token = request.cookies.get(self.cookie_name)
|
|
535
|
-
if not csrf_cookie_token:
|
|
849
|
+
) as e:
|
|
536
850
|
logger.error(
|
|
537
|
-
f"❌
|
|
538
|
-
|
|
539
|
-
f"CSRF cookie name: {self.cookie_name}, "
|
|
540
|
-
f"Available cookies: {list(request.cookies.keys())}. "
|
|
541
|
-
f"Tip: Make a GET request first to receive CSRF cookie."
|
|
851
|
+
f"❌ Error validating WebSocket ticket for {path}: {e}",
|
|
852
|
+
exc_info=True,
|
|
542
853
|
)
|
|
543
854
|
return JSONResponse(
|
|
544
855
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
545
856
|
content={
|
|
546
|
-
"detail":
|
|
547
|
-
|
|
548
|
-
"Make a GET request first to receive the CSRF cookie."
|
|
549
|
-
)
|
|
857
|
+
"detail": "WebSocket ticket validation error. "
|
|
858
|
+
"Generate a new ticket."
|
|
550
859
|
},
|
|
551
860
|
)
|
|
861
|
+
else:
|
|
862
|
+
# Fallback to CSRF cookie validation (backward compatibility)
|
|
863
|
+
# For WebSocket, CSRF header is optional (JS can't set headers on upgrade)
|
|
864
|
+
# but if provided, it must match the cookie
|
|
865
|
+
cookie_token = request.cookies.get(self.cookie_name)
|
|
866
|
+
header_token = request.headers.get(self.header_name)
|
|
552
867
|
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
868
|
+
if not cookie_token:
|
|
869
|
+
logger.error(
|
|
870
|
+
f"❌ WebSocket upgrade missing CSRF cookie for {path}. "
|
|
871
|
+
"CSRF protection is required. "
|
|
872
|
+
"Generate ticket via /auth/ticket endpoint or include CSRF cookie."
|
|
873
|
+
)
|
|
558
874
|
return JSONResponse(
|
|
559
875
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
560
876
|
content={
|
|
561
877
|
"detail": (
|
|
562
|
-
"CSRF token
|
|
878
|
+
"CSRF token missing. "
|
|
879
|
+
"Generate ticket via /auth/ticket endpoint "
|
|
880
|
+
"or include CSRF cookie/token."
|
|
563
881
|
)
|
|
564
882
|
},
|
|
565
883
|
)
|
|
566
884
|
|
|
567
|
-
# If
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
if csrf_header_token:
|
|
571
|
-
if not hmac.compare_digest(csrf_cookie_token, csrf_header_token):
|
|
885
|
+
# If header is provided, validate it matches the cookie
|
|
886
|
+
if header_token:
|
|
887
|
+
if not hmac.compare_digest(cookie_token, header_token):
|
|
572
888
|
logger.error(
|
|
573
|
-
f"❌ WebSocket CSRF
|
|
574
|
-
|
|
889
|
+
f"❌ WebSocket CSRF token mismatch for {path}. "
|
|
890
|
+
"Header token does not match cookie token."
|
|
575
891
|
)
|
|
576
892
|
return JSONResponse(
|
|
577
893
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
578
|
-
content={
|
|
579
|
-
"detail": (
|
|
580
|
-
"CSRF token mismatch: header token does not "
|
|
581
|
-
"match cookie token"
|
|
582
|
-
)
|
|
583
|
-
},
|
|
894
|
+
content={"detail": "CSRF token invalid."},
|
|
584
895
|
)
|
|
585
|
-
|
|
586
|
-
|
|
896
|
+
|
|
897
|
+
# Validate CSRF token (check signature if secret is used)
|
|
898
|
+
if not self._validate_csrf_token(cookie_token, request):
|
|
899
|
+
logger.error(f"❌ WebSocket CSRF token validation failed for {path}")
|
|
900
|
+
return JSONResponse(
|
|
901
|
+
status_code=status.HTTP_403_FORBIDDEN,
|
|
902
|
+
content={"detail": "CSRF token validation failed."},
|
|
587
903
|
)
|
|
588
904
|
|
|
589
|
-
logger.
|
|
590
|
-
|
|
591
|
-
|
|
905
|
+
logger.info(
|
|
906
|
+
f"✅ WebSocket CSRF cookie validated for {path} "
|
|
907
|
+
"(backward compatibility mode)"
|
|
908
|
+
)
|
|
909
|
+
elif auth_token_cookie and not csrf_required:
|
|
910
|
+
logger.info(
|
|
592
911
|
f"✅ WebSocket CSRF validation skipped for {path} "
|
|
593
|
-
f"(csrf_required=false
|
|
912
|
+
f"(csrf_required=false) - only origin validation performed"
|
|
913
|
+
)
|
|
914
|
+
elif not auth_token_cookie:
|
|
915
|
+
logger.info(
|
|
916
|
+
f"✅ WebSocket connection allowed for {path} "
|
|
917
|
+
f"(no auth cookie - WebSocket handler will authenticate)"
|
|
594
918
|
)
|
|
595
919
|
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
920
|
+
validation_status = (
|
|
921
|
+
"CSRF/ticket validated"
|
|
922
|
+
if auth_token_cookie and csrf_required
|
|
923
|
+
else "CSRF skipped"
|
|
924
|
+
)
|
|
925
|
+
logger.info(
|
|
926
|
+
f"✅ WebSocket upgrade CSRF validation passed for {path} "
|
|
927
|
+
f"(Origin validated, {validation_status})"
|
|
599
928
|
)
|
|
600
929
|
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
930
|
+
# Origin validated (and CSRF/ticket validated if authenticated
|
|
931
|
+
# and csrf_required=true)
|
|
932
|
+
# Allow WebSocket upgrade to proceed to WebSocket handler
|
|
933
|
+
logger.debug(f"✅ WebSocket upgrade request allowed to proceed: {path}")
|
|
934
|
+
return await call_next(request)
|
|
935
|
+
|
|
936
|
+
except (
|
|
937
|
+
AttributeError,
|
|
938
|
+
KeyError,
|
|
939
|
+
RuntimeError,
|
|
940
|
+
ValueError,
|
|
941
|
+
TypeError,
|
|
942
|
+
ConnectionError,
|
|
943
|
+
) as e:
|
|
944
|
+
# Catch exceptions in WebSocket handling to see what's failing
|
|
945
|
+
logger.error(
|
|
946
|
+
f"❌ CRITICAL: Exception in CSRF middleware WebSocket handling: {e}", exc_info=True
|
|
947
|
+
)
|
|
948
|
+
import sys
|
|
949
|
+
|
|
950
|
+
print(f"❌ [CSRF MIDDLEWARE EXCEPTION] {e}", file=sys.stderr, flush=True)
|
|
951
|
+
# Re-raise to see the full error
|
|
952
|
+
raise
|
|
604
953
|
|
|
605
954
|
if self._is_exempt(path):
|
|
606
955
|
return await call_next(request)
|