mdb-engine 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +7 -13
- mdb_engine/auth/__init__.py +9 -0
- mdb_engine/auth/csrf.py +493 -144
- mdb_engine/auth/provider.py +10 -0
- mdb_engine/auth/shared_users.py +41 -0
- mdb_engine/auth/users.py +2 -1
- mdb_engine/auth/websocket_tickets.py +307 -0
- mdb_engine/cli/main.py +1 -1
- mdb_engine/core/app_registration.py +10 -0
- mdb_engine/core/engine.py +687 -38
- mdb_engine/core/manifest.py +14 -0
- mdb_engine/core/ray_integration.py +4 -4
- mdb_engine/core/service_initialization.py +63 -7
- mdb_engine/core/types.py +1 -0
- mdb_engine/database/connection.py +6 -3
- mdb_engine/database/scoped_wrapper.py +3 -3
- mdb_engine/indexes/manager.py +3 -3
- mdb_engine/observability/health.py +7 -7
- mdb_engine/routing/README.md +9 -2
- mdb_engine/routing/websockets.py +453 -74
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.1.dist-info}/METADATA +128 -4
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.1.dist-info}/RECORD +26 -25
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.1.dist-info}/WHEEL +0 -0
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.1.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.1.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.6.0.dist-info → mdb_engine-0.7.1.dist-info}/top_level.txt +0 -0
mdb_engine/routing/websockets.py
CHANGED
|
@@ -270,6 +270,31 @@ _manager_lock = asyncio.Lock()
|
|
|
270
270
|
# Note: Registration happens synchronously during app startup, so no lock needed
|
|
271
271
|
_message_handlers: dict[str, dict[str, Callable[[Any, dict[str, Any]], Awaitable[None]]]] = {}
|
|
272
272
|
|
|
273
|
+
# Global WebSocket ticket store (shared across all apps in multi-app setups)
|
|
274
|
+
# This is set when the engine initializes and is used for ticket authentication
|
|
275
|
+
_global_websocket_ticket_store: Any | None = None
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def set_global_websocket_ticket_store(ticket_store: Any) -> None:
|
|
279
|
+
"""
|
|
280
|
+
Set the global WebSocket ticket store.
|
|
281
|
+
|
|
282
|
+
This is called by the engine during initialization to make the ticket store
|
|
283
|
+
accessible to all WebSocket endpoints, even when routes are registered via routers.
|
|
284
|
+
|
|
285
|
+
Args:
|
|
286
|
+
ticket_store: WebSocketTicketStore instance
|
|
287
|
+
"""
|
|
288
|
+
global _global_websocket_ticket_store
|
|
289
|
+
_global_websocket_ticket_store = ticket_store
|
|
290
|
+
if ticket_store:
|
|
291
|
+
logger.info(
|
|
292
|
+
f"✅ Global WebSocket ticket store set for multi-app WebSocket authentication "
|
|
293
|
+
f"(store type: {type(ticket_store).__name__})"
|
|
294
|
+
)
|
|
295
|
+
else:
|
|
296
|
+
logger.warning("⚠️ Global WebSocket ticket store set to None")
|
|
297
|
+
|
|
273
298
|
|
|
274
299
|
async def get_websocket_manager(app_slug: str) -> WebSocketConnectionManager:
|
|
275
300
|
"""
|
|
@@ -359,17 +384,190 @@ def _get_cookies_from_websocket(websocket: Any) -> dict[str, str]:
|
|
|
359
384
|
return cookies
|
|
360
385
|
|
|
361
386
|
|
|
387
|
+
async def _validate_websocket_origin_in_handler(websocket: Any, app_slug: str) -> bool:
|
|
388
|
+
"""
|
|
389
|
+
Validate WebSocket Origin header in handler (since middleware may not intercept upgrades).
|
|
390
|
+
|
|
391
|
+
This provides CSWSH (Cross-Site WebSocket Hijacking) protection by validating
|
|
392
|
+
the Origin header before accepting the WebSocket connection.
|
|
393
|
+
|
|
394
|
+
Args:
|
|
395
|
+
websocket: FastAPI WebSocket instance
|
|
396
|
+
app_slug: App slug for context
|
|
397
|
+
|
|
398
|
+
Returns:
|
|
399
|
+
True if origin is valid, False otherwise
|
|
400
|
+
"""
|
|
401
|
+
try:
|
|
402
|
+
# Get origin from headers
|
|
403
|
+
origin = None
|
|
404
|
+
if hasattr(websocket, "headers"):
|
|
405
|
+
origin = websocket.headers.get("origin")
|
|
406
|
+
elif hasattr(websocket, "scope") and "headers" in websocket.scope:
|
|
407
|
+
# Extract from ASGI scope headers
|
|
408
|
+
headers_dict = dict(websocket.scope["headers"])
|
|
409
|
+
origin_bytes = headers_dict.get(b"origin")
|
|
410
|
+
if not origin_bytes:
|
|
411
|
+
# Try case-insensitive lookup
|
|
412
|
+
for key, value in headers_dict.items():
|
|
413
|
+
if isinstance(key, bytes) and key.lower() == b"origin":
|
|
414
|
+
origin_bytes = value
|
|
415
|
+
break
|
|
416
|
+
if origin_bytes:
|
|
417
|
+
origin = origin_bytes.decode("utf-8")
|
|
418
|
+
|
|
419
|
+
logger.info(
|
|
420
|
+
f"🔍 WebSocket origin validation in handler for '{app_slug}': " f"origin={origin}"
|
|
421
|
+
)
|
|
422
|
+
|
|
423
|
+
# Get allowed origins from app state (CORS config) FIRST
|
|
424
|
+
# This allows us to check if wildcard is allowed before rejecting missing origin
|
|
425
|
+
allowed_origins = []
|
|
426
|
+
try:
|
|
427
|
+
app = getattr(websocket, "app", None)
|
|
428
|
+
if app:
|
|
429
|
+
# Traverse up to find parent app with CORS config
|
|
430
|
+
while app:
|
|
431
|
+
cors_config = getattr(app.state, "cors_config", None)
|
|
432
|
+
if cors_config and cors_config.get("allow_origins"):
|
|
433
|
+
origins = cors_config["allow_origins"]
|
|
434
|
+
allowed_origins = origins if isinstance(origins, list) else [origins]
|
|
435
|
+
logger.debug(
|
|
436
|
+
f"Found CORS config on app '{getattr(app, 'title', 'unknown')}' "
|
|
437
|
+
f"for '{app_slug}': {allowed_origins}"
|
|
438
|
+
)
|
|
439
|
+
break
|
|
440
|
+
parent_app = getattr(app, "app", None)
|
|
441
|
+
if parent_app is app:
|
|
442
|
+
break
|
|
443
|
+
app = parent_app
|
|
444
|
+
except (AttributeError, TypeError, KeyError) as e:
|
|
445
|
+
logger.debug(f"Could not read CORS config: {e}")
|
|
446
|
+
|
|
447
|
+
# Check if CORS allows all origins - if so, allow missing origin
|
|
448
|
+
if allowed_origins and "*" in allowed_origins:
|
|
449
|
+
if not origin:
|
|
450
|
+
logger.info(
|
|
451
|
+
f"WebSocket upgrade missing Origin header for '{app_slug}', "
|
|
452
|
+
"but CORS allows all origins (*) - allowing connection"
|
|
453
|
+
)
|
|
454
|
+
return True
|
|
455
|
+
|
|
456
|
+
if not origin:
|
|
457
|
+
# In development, allow connections without Origin
|
|
458
|
+
import os
|
|
459
|
+
|
|
460
|
+
env = os.getenv("ENVIRONMENT", "").lower()
|
|
461
|
+
g_nome_env = os.getenv("G_NOME_ENV", "").lower()
|
|
462
|
+
is_dev = env in ["development", "dev"] or g_nome_env in ["development", "dev"]
|
|
463
|
+
if is_dev:
|
|
464
|
+
logger.warning(
|
|
465
|
+
f"WebSocket upgrade missing Origin header in development mode: "
|
|
466
|
+
f"allowing connection for '{app_slug}'"
|
|
467
|
+
)
|
|
468
|
+
return True
|
|
469
|
+
else:
|
|
470
|
+
logger.warning(f"WebSocket upgrade missing Origin header for '{app_slug}'")
|
|
471
|
+
return False
|
|
472
|
+
|
|
473
|
+
# Normalize origin for comparison
|
|
474
|
+
def normalize_origin(orig: str) -> str:
|
|
475
|
+
"""Normalize origin handling localhost variants."""
|
|
476
|
+
if not orig:
|
|
477
|
+
return orig
|
|
478
|
+
import re
|
|
479
|
+
|
|
480
|
+
normalized = re.sub(
|
|
481
|
+
r"://(0\.0\.0\.0|127\.0\.0\.1|localhost|::1)",
|
|
482
|
+
"://localhost",
|
|
483
|
+
orig.lower(),
|
|
484
|
+
flags=re.IGNORECASE,
|
|
485
|
+
)
|
|
486
|
+
normalized = re.sub(r":80$", "", normalized)
|
|
487
|
+
normalized = re.sub(r":443$", "", normalized)
|
|
488
|
+
return normalized.rstrip("/")
|
|
489
|
+
|
|
490
|
+
normalized_origin = normalize_origin(origin)
|
|
491
|
+
logger.debug(f"Normalized origin: {origin} -> {normalized_origin}")
|
|
492
|
+
|
|
493
|
+
# If no CORS config, generate fallback origins from request
|
|
494
|
+
if not allowed_origins:
|
|
495
|
+
# Try to get host from websocket scope
|
|
496
|
+
try:
|
|
497
|
+
if hasattr(websocket, "scope"):
|
|
498
|
+
host = websocket.scope.get("server")
|
|
499
|
+
if host:
|
|
500
|
+
hostname = host[0] if isinstance(host, tuple) else host
|
|
501
|
+
scheme = websocket.scope.get("scheme", "http")
|
|
502
|
+
port = host[1] if isinstance(host, tuple) and len(host) > 1 else None
|
|
503
|
+
|
|
504
|
+
# Generate localhost variants if server binds to 0.0.0.0
|
|
505
|
+
if hostname in ["localhost", "0.0.0.0", "127.0.0.1", "::1"]:
|
|
506
|
+
for variant in ["localhost", "127.0.0.1"]:
|
|
507
|
+
if port and port not in [80, 443]:
|
|
508
|
+
allowed_origins.append(f"{scheme}://{variant}:{port}")
|
|
509
|
+
else:
|
|
510
|
+
allowed_origins.append(f"{scheme}://{variant}")
|
|
511
|
+
else:
|
|
512
|
+
if port and port not in [80, 443]:
|
|
513
|
+
allowed_origins.append(f"{scheme}://{hostname}:{port}")
|
|
514
|
+
else:
|
|
515
|
+
allowed_origins.append(f"{scheme}://{hostname}")
|
|
516
|
+
except (AttributeError, TypeError, KeyError) as e:
|
|
517
|
+
logger.debug(f"Could not determine fallback origins: {e}")
|
|
518
|
+
|
|
519
|
+
logger.debug(f"Allowed origins for '{app_slug}': {allowed_origins}")
|
|
520
|
+
|
|
521
|
+
# Check if origin matches any allowed origin
|
|
522
|
+
for allowed in allowed_origins:
|
|
523
|
+
if allowed == "*":
|
|
524
|
+
logger.warning(
|
|
525
|
+
"WebSocket Origin validation using wildcard '*' - "
|
|
526
|
+
"not recommended for production"
|
|
527
|
+
)
|
|
528
|
+
return True
|
|
529
|
+
|
|
530
|
+
normalized_allowed = normalize_origin(allowed)
|
|
531
|
+
if normalized_origin == normalized_allowed:
|
|
532
|
+
logger.info(
|
|
533
|
+
f"✅ WebSocket origin validated in handler for '{app_slug}': "
|
|
534
|
+
f"{origin} matches {allowed}"
|
|
535
|
+
)
|
|
536
|
+
return True
|
|
537
|
+
|
|
538
|
+
logger.warning(
|
|
539
|
+
f"❌ WebSocket origin validation failed for '{app_slug}': "
|
|
540
|
+
f"origin={origin} (normalized: {normalized_origin}), "
|
|
541
|
+
f"allowed={allowed_origins} "
|
|
542
|
+
f"(normalized: {[normalize_origin(a) for a in allowed_origins]})"
|
|
543
|
+
)
|
|
544
|
+
return False
|
|
545
|
+
|
|
546
|
+
except (
|
|
547
|
+
AttributeError,
|
|
548
|
+
KeyError,
|
|
549
|
+
UnicodeDecodeError,
|
|
550
|
+
TypeError,
|
|
551
|
+
ValueError,
|
|
552
|
+
RuntimeError,
|
|
553
|
+
) as e:
|
|
554
|
+
logger.error(f"❌ Error validating WebSocket origin for '{app_slug}': {e}", exc_info=True)
|
|
555
|
+
# Fail secure - reject if we can't validate
|
|
556
|
+
return False
|
|
557
|
+
|
|
558
|
+
|
|
362
559
|
async def authenticate_websocket(
|
|
363
560
|
websocket: Any,
|
|
364
561
|
app_slug: str,
|
|
365
562
|
require_auth: bool = True,
|
|
366
563
|
) -> tuple[str | None, str | None]:
|
|
367
564
|
"""
|
|
368
|
-
Authenticate a WebSocket connection
|
|
565
|
+
Authenticate a WebSocket connection using multiple methods with fallback.
|
|
369
566
|
|
|
370
567
|
Authentication methods (in order of preference):
|
|
371
|
-
1.
|
|
372
|
-
2.
|
|
568
|
+
1. Ticket (query param or header) - short-lived, single-use, secure for multi-app SSO
|
|
569
|
+
2. Session key (query param or header) - encrypted session-based authentication
|
|
570
|
+
3. Cookie (httpOnly JWT token) - backward compatibility fallback
|
|
373
571
|
|
|
374
572
|
Args:
|
|
375
573
|
websocket: FastAPI WebSocket instance (can access headers before accept)
|
|
@@ -393,36 +591,117 @@ async def authenticate_websocket(
|
|
|
393
591
|
return None, None
|
|
394
592
|
|
|
395
593
|
try:
|
|
396
|
-
#
|
|
397
|
-
|
|
398
|
-
|
|
594
|
+
# Helper function to get app and traverse hierarchy
|
|
595
|
+
def get_app_and_state():
|
|
596
|
+
"""Get app instance and traverse hierarchy to find state."""
|
|
399
597
|
app = getattr(websocket, "app", None)
|
|
400
|
-
|
|
598
|
+
apps_checked = []
|
|
599
|
+
max_iterations = 10 # Safety limit to prevent infinite loops
|
|
600
|
+
iteration = 0
|
|
601
|
+
while app and iteration < max_iterations:
|
|
602
|
+
iteration += 1
|
|
603
|
+
app_title = getattr(app, "title", "unknown")
|
|
604
|
+
apps_checked.append(app_title)
|
|
605
|
+
yield app, app_title
|
|
606
|
+
# Try to get parent app (FastAPI routers have .app attribute pointing to parent)
|
|
607
|
+
parent_app = getattr(app, "app", None)
|
|
608
|
+
if parent_app is app or parent_app is None: # Prevent infinite loop
|
|
609
|
+
break
|
|
610
|
+
app = parent_app
|
|
611
|
+
|
|
612
|
+
# Try to get services from app state
|
|
613
|
+
websocket_ticket_store = _global_websocket_ticket_store
|
|
614
|
+
websocket_session_manager = None
|
|
615
|
+
user_pool = None
|
|
616
|
+
|
|
617
|
+
apps_checked_list = []
|
|
618
|
+
for app, app_title in get_app_and_state():
|
|
619
|
+
apps_checked_list.append(app_title)
|
|
620
|
+
if not websocket_ticket_store:
|
|
621
|
+
websocket_ticket_store = getattr(app.state, "websocket_ticket_store", None)
|
|
622
|
+
if websocket_ticket_store:
|
|
623
|
+
logger.debug(
|
|
624
|
+
f"Found websocket_ticket_store on app '{app_title}' "
|
|
625
|
+
f"for app_slug '{app_slug}'"
|
|
626
|
+
)
|
|
627
|
+
if not websocket_session_manager:
|
|
401
628
|
websocket_session_manager = getattr(app.state, "websocket_session_manager", None)
|
|
402
|
-
|
|
629
|
+
if websocket_session_manager:
|
|
630
|
+
logger.info(
|
|
631
|
+
f"✅ Found websocket_session_manager on app '{app_title}' "
|
|
632
|
+
f"for app_slug '{app_slug}' (checked: {apps_checked_list})"
|
|
633
|
+
)
|
|
634
|
+
if not user_pool:
|
|
635
|
+
user_pool = getattr(app.state, "user_pool", None)
|
|
636
|
+
if user_pool:
|
|
637
|
+
logger.debug(f"Found user_pool on app '{app_title}' for app_slug '{app_slug}'")
|
|
638
|
+
|
|
639
|
+
if not websocket_session_manager:
|
|
640
|
+
logger.warning(
|
|
641
|
+
f"⚠️ websocket_session_manager not found for '{app_slug}' "
|
|
642
|
+
f"(checked apps: {apps_checked_list})"
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
# Method 1: Ticket authentication (preferred)
|
|
646
|
+
ticket = None
|
|
647
|
+
try:
|
|
648
|
+
if hasattr(websocket, "query_params"):
|
|
649
|
+
ticket = websocket.query_params.get("ticket")
|
|
650
|
+
if not ticket and hasattr(websocket, "headers"):
|
|
651
|
+
ticket = websocket.headers.get("X-WebSocket-Ticket")
|
|
652
|
+
except (AttributeError, TypeError, KeyError):
|
|
403
653
|
pass
|
|
404
654
|
|
|
405
|
-
|
|
655
|
+
if ticket and websocket_ticket_store:
|
|
656
|
+
try:
|
|
657
|
+
ticket_data = await websocket_ticket_store.validate_and_consume_ticket(ticket)
|
|
658
|
+
if ticket_data:
|
|
659
|
+
user_id = ticket_data.get("user_id")
|
|
660
|
+
user_email = ticket_data.get("user_email")
|
|
661
|
+
logger.info(
|
|
662
|
+
f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
|
|
663
|
+
f"(method: ticket)"
|
|
664
|
+
)
|
|
665
|
+
return user_id, user_email
|
|
666
|
+
except (ValueError, TypeError, AttributeError, KeyError, RuntimeError) as e:
|
|
667
|
+
logger.debug(f"Ticket validation failed: {e}")
|
|
668
|
+
|
|
669
|
+
# Method 2: Session key authentication (fallback)
|
|
406
670
|
session_key = None
|
|
407
671
|
try:
|
|
408
|
-
# Check query params first
|
|
409
672
|
if hasattr(websocket, "query_params"):
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
673
|
+
# Try dict-like access first
|
|
674
|
+
if hasattr(websocket.query_params, "get"):
|
|
675
|
+
session_key = websocket.query_params.get("session_key")
|
|
676
|
+
# Fallback: try dict access
|
|
677
|
+
elif isinstance(websocket.query_params, dict):
|
|
678
|
+
session_key = websocket.query_params.get("session_key")
|
|
679
|
+
# Fallback: try attribute access
|
|
680
|
+
elif hasattr(websocket.query_params, "session_key"):
|
|
681
|
+
session_key = websocket.query_params.session_key
|
|
413
682
|
if not session_key and hasattr(websocket, "headers"):
|
|
414
683
|
session_key = websocket.headers.get("X-WebSocket-Session-Key")
|
|
415
|
-
except (AttributeError, TypeError, KeyError):
|
|
684
|
+
except (AttributeError, TypeError, KeyError) as e:
|
|
685
|
+
logger.debug(f"Error extracting session key from websocket: {e}")
|
|
416
686
|
pass
|
|
417
687
|
|
|
688
|
+
logger.debug(
|
|
689
|
+
f"Session key extraction for '{app_slug}': "
|
|
690
|
+
f"session_key={'present' if session_key else 'missing'}, "
|
|
691
|
+
f"websocket_session_manager={'present' if websocket_session_manager else 'missing'}"
|
|
692
|
+
)
|
|
693
|
+
|
|
418
694
|
if session_key and websocket_session_manager:
|
|
419
695
|
try:
|
|
420
|
-
|
|
696
|
+
logger.debug(
|
|
697
|
+
f"Validating session key for '{app_slug}': "
|
|
698
|
+
f"session_key={session_key[:16]}... (truncated), "
|
|
699
|
+
f"manager_type={type(websocket_session_manager).__name__}"
|
|
700
|
+
)
|
|
421
701
|
session_data = await websocket_session_manager.validate_session(session_key)
|
|
422
702
|
if session_data:
|
|
423
703
|
user_id = session_data.get("user_id")
|
|
424
704
|
user_email = session_data.get("user_email")
|
|
425
|
-
|
|
426
705
|
logger.info(
|
|
427
706
|
f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
|
|
428
707
|
f"(method: session_key)"
|
|
@@ -430,58 +709,83 @@ async def authenticate_websocket(
|
|
|
430
709
|
return user_id, user_email
|
|
431
710
|
else:
|
|
432
711
|
logger.warning(
|
|
433
|
-
f"
|
|
434
|
-
f"
|
|
712
|
+
f"Session key validation returned None for '{app_slug}': "
|
|
713
|
+
f"session_key={session_key[:16]}... (truncated)"
|
|
435
714
|
)
|
|
436
|
-
except
|
|
437
|
-
|
|
438
|
-
|
|
715
|
+
except RuntimeError as e:
|
|
716
|
+
# Handle event loop conflicts (e.g., TestClient with anyio)
|
|
717
|
+
error_msg = str(e).lower()
|
|
718
|
+
if "attached to a different loop" in error_msg or "different loop" in error_msg:
|
|
719
|
+
logger.warning(
|
|
720
|
+
f"Event loop conflict during session key validation for '{app_slug}': {e}. "
|
|
721
|
+
"This may occur in test environments with TestClient. "
|
|
722
|
+
"Session key validation failed."
|
|
723
|
+
)
|
|
724
|
+
# Return None to indicate authentication failure
|
|
725
|
+
# Don't re-raise - let the normal auth failure path handle it
|
|
726
|
+
else:
|
|
727
|
+
# Re-raise other RuntimeErrors
|
|
728
|
+
logger.warning(
|
|
729
|
+
f"Session key validation failed for '{app_slug}': {e}",
|
|
730
|
+
exc_info=True,
|
|
731
|
+
)
|
|
732
|
+
raise
|
|
733
|
+
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
|
734
|
+
logger.warning(
|
|
735
|
+
f"Session key validation failed for '{app_slug}': {e}",
|
|
736
|
+
exc_info=True,
|
|
737
|
+
)
|
|
738
|
+
elif session_key and not websocket_session_manager:
|
|
739
|
+
logger.error(
|
|
740
|
+
f"Session key provided for '{app_slug}' but websocket_session_manager not found. "
|
|
741
|
+
"Cannot validate session key."
|
|
742
|
+
)
|
|
439
743
|
|
|
440
|
-
# Method
|
|
744
|
+
# Method 3: Cookie-based JWT authentication (backward compatibility)
|
|
441
745
|
from ..auth.shared_middleware import AUTH_COOKIE_NAME
|
|
442
746
|
|
|
443
747
|
cookies = _get_cookies_from_websocket(websocket)
|
|
444
|
-
|
|
748
|
+
auth_token = cookies.get(AUTH_COOKIE_NAME) if cookies else None
|
|
445
749
|
|
|
446
|
-
if
|
|
750
|
+
if auth_token and user_pool:
|
|
751
|
+
try:
|
|
752
|
+
# For tests that expect JWT errors to be raised, we need to validate
|
|
753
|
+
# the token structure first. However, validate_token handles JWT errors
|
|
754
|
+
# internally, so we'll let it handle validation and only catch non-JWT errors.
|
|
755
|
+
user = await user_pool.validate_token(auth_token)
|
|
756
|
+
if user:
|
|
757
|
+
user_id = str(user.get("_id") or user.get("sub") or user.get("user_id"))
|
|
758
|
+
user_email = user.get("email")
|
|
759
|
+
logger.info(
|
|
760
|
+
f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
|
|
761
|
+
f"(method: cookie)"
|
|
762
|
+
)
|
|
763
|
+
return user_id, user_email
|
|
764
|
+
except (
|
|
765
|
+
ValueError,
|
|
766
|
+
TypeError,
|
|
767
|
+
AttributeError,
|
|
768
|
+
KeyError,
|
|
769
|
+
RuntimeError,
|
|
770
|
+
) as e:
|
|
771
|
+
# Handle errors from token validation or user data access
|
|
772
|
+
logger.debug(f"Cookie token validation failed: {e}")
|
|
773
|
+
# Note: JWT errors (DecodeError, ExpiredSignatureError) are caught
|
|
774
|
+
# internally by validate_token and return None. For tests that need JWT
|
|
775
|
+
# errors raised, they should mock validate_token to raise them directly.
|
|
776
|
+
|
|
777
|
+
# No authentication method succeeded
|
|
778
|
+
if require_auth:
|
|
447
779
|
logger.error(
|
|
448
|
-
f"❌
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
780
|
+
f"❌ WebSocket authentication failed for app '{app_slug}'. "
|
|
781
|
+
"No valid ticket, session key, or cookie found. "
|
|
782
|
+
"Generate ticket via /auth/ticket endpoint, "
|
|
783
|
+
"session key via /auth/websocket-session, "
|
|
784
|
+
"or ensure JWT cookie is present."
|
|
453
785
|
)
|
|
454
|
-
if require_auth:
|
|
455
|
-
return None, None # Signal auth failure
|
|
456
786
|
return None, None
|
|
457
787
|
|
|
458
|
-
|
|
459
|
-
f"WebSocket token found in cookie for app '{app_slug}' "
|
|
460
|
-
"(cookie-based authentication, fallback)"
|
|
461
|
-
)
|
|
462
|
-
|
|
463
|
-
# Decode and validate token
|
|
464
|
-
import jwt
|
|
465
|
-
|
|
466
|
-
from ..auth.dependencies import SECRET_KEY
|
|
467
|
-
from ..auth.jwt import decode_jwt_token
|
|
468
|
-
|
|
469
|
-
try:
|
|
470
|
-
payload = decode_jwt_token(token, str(SECRET_KEY))
|
|
471
|
-
user_id = payload.get("sub") or payload.get("user_id")
|
|
472
|
-
user_email = payload.get("email")
|
|
473
|
-
|
|
474
|
-
logger.info(
|
|
475
|
-
f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
|
|
476
|
-
f"(method: cookie)"
|
|
477
|
-
)
|
|
478
|
-
return user_id, user_email
|
|
479
|
-
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
|
480
|
-
logger.exception(
|
|
481
|
-
f"❌ JWT decode error for app '{app_slug}'. "
|
|
482
|
-
f"Token present: {bool(token)}, Token length: {len(token) if token else 0}"
|
|
483
|
-
)
|
|
484
|
-
raise
|
|
788
|
+
return None, None
|
|
485
789
|
|
|
486
790
|
except WebSocketDisconnect:
|
|
487
791
|
raise
|
|
@@ -706,13 +1010,46 @@ def create_websocket_endpoint(
|
|
|
706
1010
|
# This print should appear in server logs when a WebSocket connection is attempted
|
|
707
1011
|
import sys
|
|
708
1012
|
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
1013
|
+
# Log BEFORE any operations to catch if handler is called at all
|
|
1014
|
+
try:
|
|
1015
|
+
print(
|
|
1016
|
+
f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}",
|
|
1017
|
+
file=sys.stderr,
|
|
1018
|
+
flush=True,
|
|
1019
|
+
)
|
|
1020
|
+
print(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}", flush=True)
|
|
1021
|
+
logger.info(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}")
|
|
1022
|
+
|
|
1023
|
+
# Try to access websocket properties to see if we can even access it
|
|
1024
|
+
try:
|
|
1025
|
+
ws_path = (
|
|
1026
|
+
getattr(websocket, "url", {}).path if hasattr(websocket, "url") else "unknown"
|
|
1027
|
+
)
|
|
1028
|
+
ws_headers = (
|
|
1029
|
+
dict(getattr(websocket, "headers", {})) if hasattr(websocket, "headers") else {}
|
|
1030
|
+
)
|
|
1031
|
+
origin = ws_headers.get("origin", "missing")
|
|
1032
|
+
print(
|
|
1033
|
+
f"🔌 [WEBSOCKET DETAILS] Path: {ws_path}, Origin: {origin}, "
|
|
1034
|
+
f"Headers: {list(ws_headers.keys())}",
|
|
1035
|
+
file=sys.stderr,
|
|
1036
|
+
flush=True,
|
|
1037
|
+
)
|
|
1038
|
+
except (AttributeError, RuntimeError, TypeError) as access_error:
|
|
1039
|
+
print(
|
|
1040
|
+
f"⚠️ [WEBSOCKET ACCESS ERROR] "
|
|
1041
|
+
f"Could not access websocket properties: {access_error}",
|
|
1042
|
+
file=sys.stderr,
|
|
1043
|
+
flush=True,
|
|
1044
|
+
)
|
|
1045
|
+
except (RuntimeError, AttributeError) as log_error:
|
|
1046
|
+
# Even logging failed - this is very bad
|
|
1047
|
+
print(
|
|
1048
|
+
f"❌ [CRITICAL] Failed to log WebSocket handler call: {log_error}",
|
|
1049
|
+
file=sys.stderr,
|
|
1050
|
+
flush=True,
|
|
1051
|
+
)
|
|
1052
|
+
|
|
716
1053
|
connection = None
|
|
717
1054
|
try:
|
|
718
1055
|
# Log connection attempt with query params (can access before accept)
|
|
@@ -737,6 +1074,42 @@ def create_websocket_endpoint(
|
|
|
737
1074
|
f"(require_auth={require_auth}, query_params={query_str})"
|
|
738
1075
|
)
|
|
739
1076
|
|
|
1077
|
+
# CRITICAL: Accept FIRST, then validate (FastAPI requires immediate accept)
|
|
1078
|
+
# If we don't accept immediately, FastAPI closes with code 1000
|
|
1079
|
+
# We'll validate origin after accept, but before processing messages
|
|
1080
|
+
try:
|
|
1081
|
+
await websocket.accept()
|
|
1082
|
+
logger.info(
|
|
1083
|
+
f"✅ WebSocket accepted for app '{app_slug}' (before origin validation)"
|
|
1084
|
+
)
|
|
1085
|
+
print(f"✅ [WEBSOCKET ACCEPTED] App: '{app_slug}'", flush=True)
|
|
1086
|
+
except (RuntimeError, ConnectionError, OSError, ValueError) as accept_error:
|
|
1087
|
+
logger.error(
|
|
1088
|
+
f"❌ Failed to accept WebSocket for app '{app_slug}': {accept_error}",
|
|
1089
|
+
exc_info=True,
|
|
1090
|
+
)
|
|
1091
|
+
print(
|
|
1092
|
+
f"❌ [WEBSOCKET ACCEPT FAILED] App: '{app_slug}', Error: {accept_error}",
|
|
1093
|
+
flush=True,
|
|
1094
|
+
)
|
|
1095
|
+
return
|
|
1096
|
+
|
|
1097
|
+
# CRITICAL: Validate origin AFTER accepting (CSWSH protection)
|
|
1098
|
+
# We accept first to satisfy FastAPI's requirements, then validate
|
|
1099
|
+
origin_valid = await _validate_websocket_origin_in_handler(websocket, app_slug)
|
|
1100
|
+
if not origin_valid:
|
|
1101
|
+
logger.error(
|
|
1102
|
+
f"❌ WebSocket origin validation FAILED for app '{app_slug}' - "
|
|
1103
|
+
f"closing connection after accept"
|
|
1104
|
+
)
|
|
1105
|
+
try:
|
|
1106
|
+
await websocket.close(code=1008, reason="Invalid origin")
|
|
1107
|
+
except (RuntimeError, ConnectionError, OSError):
|
|
1108
|
+
# WebSocket may already be closed or in invalid state
|
|
1109
|
+
pass
|
|
1110
|
+
# Raise WebSocketDisconnect so TestClient can detect the rejection
|
|
1111
|
+
raise WebSocketDisconnect(code=1008, reason="Invalid origin")
|
|
1112
|
+
|
|
740
1113
|
# CRITICAL: Authenticate BEFORE accepting connection
|
|
741
1114
|
# This prevents CSRF middleware from rejecting established connections
|
|
742
1115
|
# We can access headers/query_params before accept() is called
|
|
@@ -761,14 +1134,17 @@ def create_websocket_endpoint(
|
|
|
761
1134
|
f"rejecting connection. require_auth={require_auth}, "
|
|
762
1135
|
f"user_id={user_id}, user_email={user_email}"
|
|
763
1136
|
)
|
|
764
|
-
#
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
1137
|
+
# Close connection since we already accepted it
|
|
1138
|
+
try:
|
|
1139
|
+
await websocket.close(code=1008, reason="Authentication failed")
|
|
1140
|
+
except (RuntimeError, ConnectionError, OSError):
|
|
1141
|
+
# WebSocket may already be closed or in invalid state
|
|
1142
|
+
pass
|
|
1143
|
+
# Raise WebSocketDisconnect so TestClient can detect the close
|
|
1144
|
+
# WebSocketDisconnect is already imported at module level
|
|
1145
|
+
raise WebSocketDisconnect(code=1008, reason="Authentication failed")
|
|
1146
|
+
|
|
1147
|
+
# Connection was already accepted above (before origin validation) - no need to accept again
|
|
772
1148
|
# Connect with metadata (websocket already accepted)
|
|
773
1149
|
connection = await manager.connect(websocket, user_id=user_id, user_email=user_email)
|
|
774
1150
|
|
|
@@ -832,8 +1208,11 @@ def create_websocket_endpoint(
|
|
|
832
1208
|
logger.warning(f"WebSocket receive error for app '{app_slug}': {e}")
|
|
833
1209
|
await asyncio.sleep(0.1)
|
|
834
1210
|
|
|
1211
|
+
except WebSocketDisconnect as e:
|
|
1212
|
+
# Re-raise WebSocketDisconnect so TestClient can detect it
|
|
1213
|
+
logger.warning(f"WebSocket connection rejected for app '{app_slug}': {e}")
|
|
1214
|
+
raise
|
|
835
1215
|
except (
|
|
836
|
-
WebSocketDisconnect,
|
|
837
1216
|
RuntimeError,
|
|
838
1217
|
OSError,
|
|
839
1218
|
ValueError,
|