mdb-engine 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/__init__.py +13 -9
- mdb_engine/auth/__init__.py +18 -0
- mdb_engine/auth/csrf.py +651 -69
- mdb_engine/auth/provider.py +10 -0
- mdb_engine/auth/shared_users.py +73 -2
- mdb_engine/auth/users.py +2 -1
- mdb_engine/auth/utils.py +31 -6
- mdb_engine/auth/websocket_sessions.py +433 -0
- mdb_engine/auth/websocket_tickets.py +307 -0
- mdb_engine/core/app_registration.py +10 -0
- mdb_engine/core/engine.py +656 -21
- mdb_engine/core/manifest.py +26 -0
- mdb_engine/core/ray_integration.py +4 -4
- mdb_engine/core/types.py +2 -0
- mdb_engine/database/connection.py +6 -3
- mdb_engine/database/scoped_wrapper.py +3 -3
- mdb_engine/indexes/manager.py +3 -3
- mdb_engine/observability/health.py +7 -7
- mdb_engine/routing/README.md +9 -2
- mdb_engine/routing/websockets.py +479 -56
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/METADATA +128 -4
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/RECORD +26 -24
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/WHEEL +0 -0
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/entry_points.txt +0 -0
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/licenses/LICENSE +0 -0
- {mdb_engine-0.5.1.dist-info → mdb_engine-0.7.0.dist-info}/top_level.txt +0 -0
mdb_engine/routing/websockets.py
CHANGED
|
@@ -270,6 +270,31 @@ _manager_lock = asyncio.Lock()
|
|
|
270
270
|
# Note: Registration happens synchronously during app startup, so no lock needed
|
|
271
271
|
_message_handlers: dict[str, dict[str, Callable[[Any, dict[str, Any]], Awaitable[None]]]] = {}
|
|
272
272
|
|
|
273
|
+
# Global WebSocket ticket store (shared across all apps in multi-app setups)
|
|
274
|
+
# This is set when the engine initializes and is used for ticket authentication
|
|
275
|
+
_global_websocket_ticket_store: Any | None = None
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def set_global_websocket_ticket_store(ticket_store: Any) -> None:
    """
    Set the global WebSocket ticket store.

    Called by the engine during initialization so the ticket store is reachable
    from every WebSocket endpoint, even when routes are registered via routers.

    Args:
        ticket_store: WebSocketTicketStore instance, or None to clear the store.
    """
    global _global_websocket_ticket_store
    _global_websocket_ticket_store = ticket_store

    # Compare with None explicitly: a store object that implements __len__ or
    # __bool__ (e.g. an empty ticket store at startup) is falsy but still valid,
    # and must not be reported as "set to None".
    if ticket_store is None:
        logger.warning("⚠️ Global WebSocket ticket store set to None")
        return

    logger.info(
        f"✅ Global WebSocket ticket store set for multi-app WebSocket authentication "
        f"(store type: {type(ticket_store).__name__})"
    )
|
|
297
|
+
|
|
273
298
|
|
|
274
299
|
async def get_websocket_manager(app_slug: str) -> WebSocketConnectionManager:
|
|
275
300
|
"""
|
|
@@ -359,18 +384,190 @@ def _get_cookies_from_websocket(websocket: Any) -> dict[str, str]:
|
|
|
359
384
|
return cookies
|
|
360
385
|
|
|
361
386
|
|
|
387
|
+
# Default port of each scheme that can appear in an Origin header. An origin that
# spells out its own scheme's default port is the same origin without it.
_WS_ORIGIN_DEFAULT_PORTS: dict[str, str] = {"http": "80", "https": "443", "ws": "80", "wss": "443"}

# Server bind addresses that are reached through loopback during local development.
_WS_LOOPBACK_BIND_HOSTS = ("localhost", "0.0.0.0", "127.0.0.1", "::1", "::")

# Upper bound on `.app` parent hops when searching for a CORS config (guards
# against cyclic parent links, matching the bound used for app-state lookups).
_WS_APP_TRAVERSAL_LIMIT = 10


def _normalize_websocket_origin(origin: str) -> str:
    """Return a canonical form of ``origin`` for equality comparison.

    Applies these steps in order:

    - lower-case the value and strip a trailing slash;
    - map loopback hosts (``0.0.0.0``, ``127.0.0.1``, ``localhost``, ``::1`` or
      ``[::1]``) to ``localhost``;
    - drop an explicit port only when it is the default port of the origin's
      own scheme (``:80`` for http/ws, ``:443`` for https/wss).
    """
    import re

    if not origin:
        return origin
    normalized = origin.strip().lower().rstrip("/")
    # The lookahead requires the loopback host to end at a port/path separator or
    # at end of string, so e.g. "127.0.0.10" is not partially rewritten.
    normalized = re.sub(
        r"://(?:0\.0\.0\.0|127\.0\.0\.1|localhost|::1|\[::1\])(?=[:/]|$)",
        "://localhost",
        normalized,
        count=1,
    )
    scheme, separator, _ = normalized.partition("://")
    default_port = _WS_ORIGIN_DEFAULT_PORTS.get(scheme)
    if separator and default_port:
        port_suffix = f":{default_port}"
        if normalized.endswith(port_suffix):
            normalized = normalized[: -len(port_suffix)]
    return normalized


def _get_websocket_origin_header(websocket: Any) -> str | None:
    """Return the raw ``Origin`` header of ``websocket``, or None when absent.

    Uses ``websocket.headers`` when available. Otherwise it scans the raw ASGI
    ``scope["headers"]`` pairs with a case-insensitive name match.
    """
    if hasattr(websocket, "headers"):
        return websocket.headers.get("origin")
    scope = getattr(websocket, "scope", None)
    if not scope or "headers" not in scope:
        return None
    for name, value in scope["headers"]:
        if isinstance(name, bytes) and name.lower() == b"origin":
            # ASGI header values are byte strings; decode as latin-1 (as Starlette
            # does) so unusual bytes cannot raise here.
            return value.decode("latin-1")
    return None


def _get_cors_allowed_origins(websocket: Any, app_slug: str) -> list[Any]:
    """Return ``allow_origins`` from the nearest app carrying ``state.cors_config``.

    Starts at ``websocket.app`` and follows ``.app`` parent links for at most
    ``_WS_APP_TRAVERSAL_LIMIT`` hops. Apps without a ``state`` are skipped rather
    than aborting the search. Always returns a new list (never the config's own
    object), or ``[]`` when no app in the chain has a non-empty setting.
    """
    app = getattr(websocket, "app", None)
    for _ in range(_WS_APP_TRAVERSAL_LIMIT):
        if not app:
            break
        cors_config = getattr(getattr(app, "state", None), "cors_config", None)
        origins = cors_config.get("allow_origins") if cors_config else None
        if origins:
            # Accept any common collection type; a bare value is a single origin.
            if isinstance(origins, (list, tuple, set, frozenset)):
                allowed = list(origins)
            else:
                allowed = [origins]
            logger.debug(
                f"Found CORS config on app '{getattr(app, 'title', 'unknown')}' "
                f"for '{app_slug}': {allowed}"
            )
            return allowed
        parent_app = getattr(app, "app", None)
        if parent_app is app:
            break
        app = parent_app
    return []


def _get_fallback_websocket_origins(websocket: Any) -> list[str]:
    """Build same-server candidate origins from the ASGI ``server`` address.

    Used only when no CORS ``allow_origins`` is configured. WebSocket scopes carry
    a ``ws``/``wss`` scheme, while browsers send ``http``/``https`` in the Origin
    header, so the scheme is mapped before building candidates. Loopback or
    all-interface binds produce both ``localhost`` and ``127.0.0.1`` variants.
    """
    scope = getattr(websocket, "scope", None)
    if scope is None:
        return []
    server = scope.get("server")
    if not server:
        return []
    if isinstance(server, (list, tuple)):
        hostname = server[0]
        port = server[1] if len(server) > 1 else None
    else:
        hostname, port = server, None

    raw_scheme = scope.get("scheme") or "http"
    scheme = {"ws": "http", "wss": "https"}.get(raw_scheme, raw_scheme)

    hosts = ["localhost", "127.0.0.1"] if hostname in _WS_LOOPBACK_BIND_HOSTS else [hostname]
    origins: list[str] = []
    for host in hosts:
        host = str(host)
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"  # IPv6 literals are bracketed inside URLs
        # Keep the port; _normalize_websocket_origin drops it only when it is the
        # scheme's default, so both sides of the comparison are treated alike.
        origins.append(f"{scheme}://{host}:{port}" if port else f"{scheme}://{host}")
    return origins


async def _validate_websocket_origin_in_handler(websocket: Any, app_slug: str) -> bool:
    """
    Validate the WebSocket Origin header in the handler.

    Middleware may not intercept WebSocket upgrades, so this provides CSWSH
    (Cross-Site WebSocket Hijacking) protection at the endpoint itself.

    Allowed origins come from the nearest ``state.cors_config["allow_origins"]``
    in the app chain. If none is configured, they are derived from the ASGI
    ``server`` address. A missing Origin header is accepted only when CORS allows
    ``*``, or when ``ENVIRONMENT`` / ``G_NOME_ENV`` is ``development``/``dev``.

    Args:
        websocket: FastAPI WebSocket instance
        app_slug: App slug, used for log context

    Returns:
        True if the origin is acceptable, False otherwise. Unexpected errors
        during validation also return False (fail closed).
    """
    import os

    try:
        origin = _get_websocket_origin_header(websocket)
        logger.info(
            f"🔍 WebSocket origin validation in handler for '{app_slug}': origin={origin}"
        )

        # Read the CORS allow-list first so a wildcard can admit a missing Origin.
        try:
            allowed_origins = _get_cors_allowed_origins(websocket, app_slug)
        except (AttributeError, TypeError, KeyError) as e:
            logger.debug(f"Could not read CORS config: {e}")
            allowed_origins = []

        if not origin:
            if "*" in allowed_origins:
                logger.info(
                    f"WebSocket upgrade missing Origin header for '{app_slug}', "
                    "but CORS allows all origins (*) - allowing connection"
                )
                return True
            dev_values = ("development", "dev")
            env = os.getenv("ENVIRONMENT", "").lower()
            g_nome_env = os.getenv("G_NOME_ENV", "").lower()
            if env in dev_values or g_nome_env in dev_values:
                logger.warning(
                    f"WebSocket upgrade missing Origin header in development mode: "
                    f"allowing connection for '{app_slug}'"
                )
                return True
            logger.warning(f"WebSocket upgrade missing Origin header for '{app_slug}'")
            return False

        normalized_origin = _normalize_websocket_origin(origin)
        logger.debug(f"Normalized origin: {origin} -> {normalized_origin}")

        if not allowed_origins:
            try:
                allowed_origins = _get_fallback_websocket_origins(websocket)
            except (AttributeError, TypeError, KeyError) as e:
                logger.debug(f"Could not determine fallback origins: {e}")

        logger.debug(f"Allowed origins for '{app_slug}': {allowed_origins}")

        for allowed in allowed_origins:
            if not isinstance(allowed, str):
                continue  # skip malformed config entries rather than aborting
            if allowed == "*":
                logger.warning(
                    "WebSocket Origin validation using wildcard '*' - "
                    "not recommended for production"
                )
                return True
            if normalized_origin == _normalize_websocket_origin(allowed):
                logger.info(
                    f"✅ WebSocket origin validated in handler for '{app_slug}': "
                    f"{origin} matches {allowed}"
                )
                return True

        normalized_allowed = [
            _normalize_websocket_origin(a) for a in allowed_origins if isinstance(a, str)
        ]
        logger.warning(
            f"❌ WebSocket origin validation failed for '{app_slug}': "
            f"origin={origin} (normalized: {normalized_origin}), "
            f"allowed={allowed_origins} "
            f"(normalized: {normalized_allowed})"
        )
        return False

    except (
        AttributeError,
        KeyError,
        UnicodeDecodeError,
        TypeError,
        ValueError,
        RuntimeError,
    ) as e:
        logger.error(f"❌ Error validating WebSocket origin for '{app_slug}': {e}", exc_info=True)
        # Fail secure - reject if we can't validate
        return False
|
|
557
|
+
|
|
558
|
+
|
|
362
559
|
async def authenticate_websocket(
|
|
363
560
|
websocket: Any,
|
|
364
561
|
app_slug: str,
|
|
365
562
|
require_auth: bool = True,
|
|
366
563
|
) -> tuple[str | None, str | None]:
|
|
367
564
|
"""
|
|
368
|
-
Authenticate a WebSocket connection
|
|
565
|
+
Authenticate a WebSocket connection using multiple methods with fallback.
|
|
369
566
|
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
567
|
+
Authentication methods (in order of preference):
|
|
568
|
+
1. Ticket (query param or header) - short-lived, single-use, secure for multi-app SSO
|
|
569
|
+
2. Session key (query param or header) - encrypted session-based authentication
|
|
570
|
+
3. Cookie (httpOnly JWT token) - backward compatibility fallback
|
|
374
571
|
|
|
375
572
|
Args:
|
|
376
573
|
websocket: FastAPI WebSocket instance (can access headers before accept)
|
|
@@ -394,50 +591,201 @@ async def authenticate_websocket(
|
|
|
394
591
|
return None, None
|
|
395
592
|
|
|
396
593
|
try:
|
|
397
|
-
#
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
594
|
+
# Helper function to get app and traverse hierarchy
|
|
595
|
+
def get_app_and_state():
|
|
596
|
+
"""Get app instance and traverse hierarchy to find state."""
|
|
597
|
+
app = getattr(websocket, "app", None)
|
|
598
|
+
apps_checked = []
|
|
599
|
+
max_iterations = 10 # Safety limit to prevent infinite loops
|
|
600
|
+
iteration = 0
|
|
601
|
+
while app and iteration < max_iterations:
|
|
602
|
+
iteration += 1
|
|
603
|
+
app_title = getattr(app, "title", "unknown")
|
|
604
|
+
apps_checked.append(app_title)
|
|
605
|
+
yield app, app_title
|
|
606
|
+
# Try to get parent app (FastAPI routers have .app attribute pointing to parent)
|
|
607
|
+
parent_app = getattr(app, "app", None)
|
|
608
|
+
if parent_app is app or parent_app is None: # Prevent infinite loop
|
|
609
|
+
break
|
|
610
|
+
app = parent_app
|
|
611
|
+
|
|
612
|
+
# Try to get services from app state
|
|
613
|
+
websocket_ticket_store = _global_websocket_ticket_store
|
|
614
|
+
websocket_session_manager = None
|
|
615
|
+
user_pool = None
|
|
616
|
+
|
|
617
|
+
apps_checked_list = []
|
|
618
|
+
for app, app_title in get_app_and_state():
|
|
619
|
+
apps_checked_list.append(app_title)
|
|
620
|
+
if not websocket_ticket_store:
|
|
621
|
+
websocket_ticket_store = getattr(app.state, "websocket_ticket_store", None)
|
|
622
|
+
if websocket_ticket_store:
|
|
623
|
+
logger.debug(
|
|
624
|
+
f"Found websocket_ticket_store on app '{app_title}' "
|
|
625
|
+
f"for app_slug '{app_slug}'"
|
|
626
|
+
)
|
|
627
|
+
if not websocket_session_manager:
|
|
628
|
+
websocket_session_manager = getattr(app.state, "websocket_session_manager", None)
|
|
629
|
+
if websocket_session_manager:
|
|
630
|
+
logger.info(
|
|
631
|
+
f"✅ Found websocket_session_manager on app '{app_title}' "
|
|
632
|
+
f"for app_slug '{app_slug}' (checked: {apps_checked_list})"
|
|
633
|
+
)
|
|
634
|
+
if not user_pool:
|
|
635
|
+
user_pool = getattr(app.state, "user_pool", None)
|
|
636
|
+
if user_pool:
|
|
637
|
+
logger.debug(f"Found user_pool on app '{app_title}' for app_slug '{app_slug}'")
|
|
403
638
|
|
|
404
|
-
if not
|
|
405
|
-
logger.
|
|
406
|
-
f"
|
|
407
|
-
f"(
|
|
408
|
-
f"Available cookies: {list(cookies.keys()) if cookies else 'none'}. "
|
|
409
|
-
f"Ensure httpOnly cookie is set during authentication."
|
|
639
|
+
if not websocket_session_manager:
|
|
640
|
+
logger.warning(
|
|
641
|
+
f"⚠️ websocket_session_manager not found for '{app_slug}' "
|
|
642
|
+
f"(checked apps: {apps_checked_list})"
|
|
410
643
|
)
|
|
411
|
-
if require_auth:
|
|
412
|
-
return None, None # Signal auth failure
|
|
413
|
-
return None, None
|
|
414
644
|
|
|
415
|
-
|
|
416
|
-
|
|
645
|
+
# Method 1: Ticket authentication (preferred)
|
|
646
|
+
ticket = None
|
|
647
|
+
try:
|
|
648
|
+
if hasattr(websocket, "query_params"):
|
|
649
|
+
ticket = websocket.query_params.get("ticket")
|
|
650
|
+
if not ticket and hasattr(websocket, "headers"):
|
|
651
|
+
ticket = websocket.headers.get("X-WebSocket-Ticket")
|
|
652
|
+
except (AttributeError, TypeError, KeyError):
|
|
653
|
+
pass
|
|
654
|
+
|
|
655
|
+
if ticket and websocket_ticket_store:
|
|
656
|
+
try:
|
|
657
|
+
ticket_data = await websocket_ticket_store.validate_and_consume_ticket(ticket)
|
|
658
|
+
if ticket_data:
|
|
659
|
+
user_id = ticket_data.get("user_id")
|
|
660
|
+
user_email = ticket_data.get("user_email")
|
|
661
|
+
logger.info(
|
|
662
|
+
f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
|
|
663
|
+
f"(method: ticket)"
|
|
664
|
+
)
|
|
665
|
+
return user_id, user_email
|
|
666
|
+
except (ValueError, TypeError, AttributeError, KeyError, RuntimeError) as e:
|
|
667
|
+
logger.debug(f"Ticket validation failed: {e}")
|
|
668
|
+
|
|
669
|
+
# Method 2: Session key authentication (fallback)
|
|
670
|
+
session_key = None
|
|
671
|
+
try:
|
|
672
|
+
if hasattr(websocket, "query_params"):
|
|
673
|
+
# Try dict-like access first
|
|
674
|
+
if hasattr(websocket.query_params, "get"):
|
|
675
|
+
session_key = websocket.query_params.get("session_key")
|
|
676
|
+
# Fallback: try dict access
|
|
677
|
+
elif isinstance(websocket.query_params, dict):
|
|
678
|
+
session_key = websocket.query_params.get("session_key")
|
|
679
|
+
# Fallback: try attribute access
|
|
680
|
+
elif hasattr(websocket.query_params, "session_key"):
|
|
681
|
+
session_key = websocket.query_params.session_key
|
|
682
|
+
if not session_key and hasattr(websocket, "headers"):
|
|
683
|
+
session_key = websocket.headers.get("X-WebSocket-Session-Key")
|
|
684
|
+
except (AttributeError, TypeError, KeyError) as e:
|
|
685
|
+
logger.debug(f"Error extracting session key from websocket: {e}")
|
|
686
|
+
pass
|
|
687
|
+
|
|
688
|
+
logger.debug(
|
|
689
|
+
f"Session key extraction for '{app_slug}': "
|
|
690
|
+
f"session_key={'present' if session_key else 'missing'}, "
|
|
691
|
+
f"websocket_session_manager={'present' if websocket_session_manager else 'missing'}"
|
|
417
692
|
)
|
|
418
693
|
|
|
419
|
-
|
|
420
|
-
|
|
694
|
+
if session_key and websocket_session_manager:
|
|
695
|
+
try:
|
|
696
|
+
logger.debug(
|
|
697
|
+
f"Validating session key for '{app_slug}': "
|
|
698
|
+
f"session_key={session_key[:16]}... (truncated), "
|
|
699
|
+
f"manager_type={type(websocket_session_manager).__name__}"
|
|
700
|
+
)
|
|
701
|
+
session_data = await websocket_session_manager.validate_session(session_key)
|
|
702
|
+
if session_data:
|
|
703
|
+
user_id = session_data.get("user_id")
|
|
704
|
+
user_email = session_data.get("user_email")
|
|
705
|
+
logger.info(
|
|
706
|
+
f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
|
|
707
|
+
f"(method: session_key)"
|
|
708
|
+
)
|
|
709
|
+
return user_id, user_email
|
|
710
|
+
else:
|
|
711
|
+
logger.warning(
|
|
712
|
+
f"Session key validation returned None for '{app_slug}': "
|
|
713
|
+
f"session_key={session_key[:16]}... (truncated)"
|
|
714
|
+
)
|
|
715
|
+
except RuntimeError as e:
|
|
716
|
+
# Handle event loop conflicts (e.g., TestClient with anyio)
|
|
717
|
+
error_msg = str(e).lower()
|
|
718
|
+
if "attached to a different loop" in error_msg or "different loop" in error_msg:
|
|
719
|
+
logger.warning(
|
|
720
|
+
f"Event loop conflict during session key validation for '{app_slug}': {e}. "
|
|
721
|
+
"This may occur in test environments with TestClient. "
|
|
722
|
+
"Session key validation failed."
|
|
723
|
+
)
|
|
724
|
+
# Return None to indicate authentication failure
|
|
725
|
+
# Don't re-raise - let the normal auth failure path handle it
|
|
726
|
+
else:
|
|
727
|
+
# Re-raise other RuntimeErrors
|
|
728
|
+
logger.warning(
|
|
729
|
+
f"Session key validation failed for '{app_slug}': {e}",
|
|
730
|
+
exc_info=True,
|
|
731
|
+
)
|
|
732
|
+
raise
|
|
733
|
+
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
|
734
|
+
logger.warning(
|
|
735
|
+
f"Session key validation failed for '{app_slug}': {e}",
|
|
736
|
+
exc_info=True,
|
|
737
|
+
)
|
|
738
|
+
elif session_key and not websocket_session_manager:
|
|
739
|
+
logger.error(
|
|
740
|
+
f"Session key provided for '{app_slug}' but websocket_session_manager not found. "
|
|
741
|
+
"Cannot validate session key."
|
|
742
|
+
)
|
|
421
743
|
|
|
422
|
-
|
|
423
|
-
from ..auth.
|
|
744
|
+
# Method 3: Cookie-based JWT authentication (backward compatibility)
|
|
745
|
+
from ..auth.shared_middleware import AUTH_COOKIE_NAME
|
|
424
746
|
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
user_id = payload.get("sub") or payload.get("user_id")
|
|
428
|
-
user_email = payload.get("email")
|
|
747
|
+
cookies = _get_cookies_from_websocket(websocket)
|
|
748
|
+
auth_token = cookies.get(AUTH_COOKIE_NAME) if cookies else None
|
|
429
749
|
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
750
|
+
if auth_token and user_pool:
|
|
751
|
+
try:
|
|
752
|
+
# For tests that expect JWT errors to be raised, we need to validate
|
|
753
|
+
# the token structure first. However, validate_token handles JWT errors
|
|
754
|
+
# internally, so we'll let it handle validation and only catch non-JWT errors.
|
|
755
|
+
user = await user_pool.validate_token(auth_token)
|
|
756
|
+
if user:
|
|
757
|
+
user_id = str(user.get("_id") or user.get("sub") or user.get("user_id"))
|
|
758
|
+
user_email = user.get("email")
|
|
759
|
+
logger.info(
|
|
760
|
+
f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
|
|
761
|
+
f"(method: cookie)"
|
|
762
|
+
)
|
|
763
|
+
return user_id, user_email
|
|
764
|
+
except (
|
|
765
|
+
ValueError,
|
|
766
|
+
TypeError,
|
|
767
|
+
AttributeError,
|
|
768
|
+
KeyError,
|
|
769
|
+
RuntimeError,
|
|
770
|
+
) as e:
|
|
771
|
+
# Handle errors from token validation or user data access
|
|
772
|
+
logger.debug(f"Cookie token validation failed: {e}")
|
|
773
|
+
# Note: JWT errors (DecodeError, ExpiredSignatureError) are caught
|
|
774
|
+
# internally by validate_token and return None. For tests that need JWT
|
|
775
|
+
# errors raised, they should mock validate_token to raise them directly.
|
|
776
|
+
|
|
777
|
+
# No authentication method succeeded
|
|
778
|
+
if require_auth:
|
|
779
|
+
logger.error(
|
|
780
|
+
f"❌ WebSocket authentication failed for app '{app_slug}'. "
|
|
781
|
+
"No valid ticket, session key, or cookie found. "
|
|
782
|
+
"Generate ticket via /auth/ticket endpoint, "
|
|
783
|
+
"session key via /auth/websocket-session, "
|
|
784
|
+
"or ensure JWT cookie is present."
|
|
439
785
|
)
|
|
440
|
-
|
|
786
|
+
return None, None
|
|
787
|
+
|
|
788
|
+
return None, None
|
|
441
789
|
|
|
442
790
|
except WebSocketDisconnect:
|
|
443
791
|
raise
|
|
@@ -662,13 +1010,46 @@ def create_websocket_endpoint(
|
|
|
662
1010
|
# This print should appear in server logs when a WebSocket connection is attempted
|
|
663
1011
|
import sys
|
|
664
1012
|
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
1013
|
+
# Log BEFORE any operations to catch if handler is called at all
|
|
1014
|
+
try:
|
|
1015
|
+
print(
|
|
1016
|
+
f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}",
|
|
1017
|
+
file=sys.stderr,
|
|
1018
|
+
flush=True,
|
|
1019
|
+
)
|
|
1020
|
+
print(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}", flush=True)
|
|
1021
|
+
logger.info(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}")
|
|
1022
|
+
|
|
1023
|
+
# Try to access websocket properties to see if we can even access it
|
|
1024
|
+
try:
|
|
1025
|
+
ws_path = (
|
|
1026
|
+
getattr(websocket, "url", {}).path if hasattr(websocket, "url") else "unknown"
|
|
1027
|
+
)
|
|
1028
|
+
ws_headers = (
|
|
1029
|
+
dict(getattr(websocket, "headers", {})) if hasattr(websocket, "headers") else {}
|
|
1030
|
+
)
|
|
1031
|
+
origin = ws_headers.get("origin", "missing")
|
|
1032
|
+
print(
|
|
1033
|
+
f"🔌 [WEBSOCKET DETAILS] Path: {ws_path}, Origin: {origin}, "
|
|
1034
|
+
f"Headers: {list(ws_headers.keys())}",
|
|
1035
|
+
file=sys.stderr,
|
|
1036
|
+
flush=True,
|
|
1037
|
+
)
|
|
1038
|
+
except (AttributeError, RuntimeError, TypeError) as access_error:
|
|
1039
|
+
print(
|
|
1040
|
+
f"⚠️ [WEBSOCKET ACCESS ERROR] "
|
|
1041
|
+
f"Could not access websocket properties: {access_error}",
|
|
1042
|
+
file=sys.stderr,
|
|
1043
|
+
flush=True,
|
|
1044
|
+
)
|
|
1045
|
+
except (RuntimeError, AttributeError) as log_error:
|
|
1046
|
+
# Even logging failed - this is very bad
|
|
1047
|
+
print(
|
|
1048
|
+
f"❌ [CRITICAL] Failed to log WebSocket handler call: {log_error}",
|
|
1049
|
+
file=sys.stderr,
|
|
1050
|
+
flush=True,
|
|
1051
|
+
)
|
|
1052
|
+
|
|
672
1053
|
connection = None
|
|
673
1054
|
try:
|
|
674
1055
|
# Log connection attempt with query params (can access before accept)
|
|
@@ -693,6 +1074,42 @@ def create_websocket_endpoint(
|
|
|
693
1074
|
f"(require_auth={require_auth}, query_params={query_str})"
|
|
694
1075
|
)
|
|
695
1076
|
|
|
1077
|
+
# CRITICAL: Accept FIRST, then validate (FastAPI requires immediate accept)
|
|
1078
|
+
# If we don't accept immediately, FastAPI closes with code 1000
|
|
1079
|
+
# We'll validate origin after accept, but before processing messages
|
|
1080
|
+
try:
|
|
1081
|
+
await websocket.accept()
|
|
1082
|
+
logger.info(
|
|
1083
|
+
f"✅ WebSocket accepted for app '{app_slug}' (before origin validation)"
|
|
1084
|
+
)
|
|
1085
|
+
print(f"✅ [WEBSOCKET ACCEPTED] App: '{app_slug}'", flush=True)
|
|
1086
|
+
except (RuntimeError, ConnectionError, OSError, ValueError) as accept_error:
|
|
1087
|
+
logger.error(
|
|
1088
|
+
f"❌ Failed to accept WebSocket for app '{app_slug}': {accept_error}",
|
|
1089
|
+
exc_info=True,
|
|
1090
|
+
)
|
|
1091
|
+
print(
|
|
1092
|
+
f"❌ [WEBSOCKET ACCEPT FAILED] App: '{app_slug}', Error: {accept_error}",
|
|
1093
|
+
flush=True,
|
|
1094
|
+
)
|
|
1095
|
+
return
|
|
1096
|
+
|
|
1097
|
+
# CRITICAL: Validate origin AFTER accepting (CSWSH protection)
|
|
1098
|
+
# We accept first to satisfy FastAPI's requirements, then validate
|
|
1099
|
+
origin_valid = await _validate_websocket_origin_in_handler(websocket, app_slug)
|
|
1100
|
+
if not origin_valid:
|
|
1101
|
+
logger.error(
|
|
1102
|
+
f"❌ WebSocket origin validation FAILED for app '{app_slug}' - "
|
|
1103
|
+
f"closing connection after accept"
|
|
1104
|
+
)
|
|
1105
|
+
try:
|
|
1106
|
+
await websocket.close(code=1008, reason="Invalid origin")
|
|
1107
|
+
except (RuntimeError, ConnectionError, OSError):
|
|
1108
|
+
# WebSocket may already be closed or in invalid state
|
|
1109
|
+
pass
|
|
1110
|
+
# Raise WebSocketDisconnect so TestClient can detect the rejection
|
|
1111
|
+
raise WebSocketDisconnect(code=1008, reason="Invalid origin")
|
|
1112
|
+
|
|
696
1113
|
# CRITICAL: Authenticate BEFORE accepting connection
|
|
697
1114
|
# This prevents CSRF middleware from rejecting established connections
|
|
698
1115
|
# We can access headers/query_params before accept() is called
|
|
@@ -717,14 +1134,17 @@ def create_websocket_endpoint(
|
|
|
717
1134
|
f"rejecting connection. require_auth={require_auth}, "
|
|
718
1135
|
f"user_id={user_id}, user_email={user_email}"
|
|
719
1136
|
)
|
|
720
|
-
#
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
1137
|
+
# Close connection since we already accepted it
|
|
1138
|
+
try:
|
|
1139
|
+
await websocket.close(code=1008, reason="Authentication failed")
|
|
1140
|
+
except (RuntimeError, ConnectionError, OSError):
|
|
1141
|
+
# WebSocket may already be closed or in invalid state
|
|
1142
|
+
pass
|
|
1143
|
+
# Raise WebSocketDisconnect so TestClient can detect the close
|
|
1144
|
+
# WebSocketDisconnect is already imported at module level
|
|
1145
|
+
raise WebSocketDisconnect(code=1008, reason="Authentication failed")
|
|
1146
|
+
|
|
1147
|
+
# Connection already accepted at line 984 - no need to accept again
|
|
728
1148
|
# Connect with metadata (websocket already accepted)
|
|
729
1149
|
connection = await manager.connect(websocket, user_id=user_id, user_email=user_email)
|
|
730
1150
|
|
|
@@ -788,8 +1208,11 @@ def create_websocket_endpoint(
|
|
|
788
1208
|
logger.warning(f"WebSocket receive error for app '{app_slug}': {e}")
|
|
789
1209
|
await asyncio.sleep(0.1)
|
|
790
1210
|
|
|
1211
|
+
except WebSocketDisconnect as e:
|
|
1212
|
+
# Re-raise WebSocketDisconnect so TestClient can detect it
|
|
1213
|
+
logger.warning(f"WebSocket connection rejected for app '{app_slug}': {e}")
|
|
1214
|
+
raise
|
|
791
1215
|
except (
|
|
792
|
-
WebSocketDisconnect,
|
|
793
1216
|
RuntimeError,
|
|
794
1217
|
OSError,
|
|
795
1218
|
ValueError,
|