mdb-engine 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -270,6 +270,31 @@ _manager_lock = asyncio.Lock()
270
270
  # Note: Registration happens synchronously during app startup, so no lock needed
271
271
  _message_handlers: dict[str, dict[str, Callable[[Any, dict[str, Any]], Awaitable[None]]]] = {}
272
272
 
273
+ # Global WebSocket ticket store (shared across all apps in multi-app setups)
274
+ # This is set when the engine initializes and is used for ticket authentication
275
+ _global_websocket_ticket_store: Any | None = None
276
+
277
+
278
def set_global_websocket_ticket_store(ticket_store: Any) -> None:
    """
    Set the global WebSocket ticket store.

    This is called by the engine during initialization to make the ticket store
    accessible to all WebSocket endpoints, even when routes are registered via routers.

    Args:
        ticket_store: WebSocketTicketStore instance, or None to clear the store.
    """
    global _global_websocket_ticket_store
    _global_websocket_ticket_store = ticket_store
    # Compare against None explicitly: a store object may legitimately be falsy
    # (e.g. it defines __len__/__bool__ and is currently empty) and must not be
    # reported as "set to None".
    if ticket_store is not None:
        logger.info(
            f"✅ Global WebSocket ticket store set for multi-app WebSocket authentication "
            f"(store type: {type(ticket_store).__name__})"
        )
    else:
        logger.warning("⚠️ Global WebSocket ticket store set to None")
297
+
273
298
 
274
299
  async def get_websocket_manager(app_slug: str) -> WebSocketConnectionManager:
275
300
  """
@@ -359,17 +384,190 @@ def _get_cookies_from_websocket(websocket: Any) -> dict[str, str]:
359
384
  return cookies
360
385
 
361
386
 
387
+ async def _validate_websocket_origin_in_handler(websocket: Any, app_slug: str) -> bool:
388
+ """
389
+ Validate WebSocket Origin header in handler (since middleware may not intercept upgrades).
390
+
391
+ This provides CSWSH (Cross-Site WebSocket Hijacking) protection by validating
392
+ the Origin header before accepting the WebSocket connection.
393
+
394
+ Args:
395
+ websocket: FastAPI WebSocket instance
396
+ app_slug: App slug for context
397
+
398
+ Returns:
399
+ True if origin is valid, False otherwise
400
+ """
401
+ try:
402
+ # Get origin from headers
403
+ origin = None
404
+ if hasattr(websocket, "headers"):
405
+ origin = websocket.headers.get("origin")
406
+ elif hasattr(websocket, "scope") and "headers" in websocket.scope:
407
+ # Extract from ASGI scope headers
408
+ headers_dict = dict(websocket.scope["headers"])
409
+ origin_bytes = headers_dict.get(b"origin")
410
+ if not origin_bytes:
411
+ # Try case-insensitive lookup
412
+ for key, value in headers_dict.items():
413
+ if isinstance(key, bytes) and key.lower() == b"origin":
414
+ origin_bytes = value
415
+ break
416
+ if origin_bytes:
417
+ origin = origin_bytes.decode("utf-8")
418
+
419
+ logger.info(
420
+ f"🔍 WebSocket origin validation in handler for '{app_slug}': " f"origin={origin}"
421
+ )
422
+
423
+ # Get allowed origins from app state (CORS config) FIRST
424
+ # This allows us to check if wildcard is allowed before rejecting missing origin
425
+ allowed_origins = []
426
+ try:
427
+ app = getattr(websocket, "app", None)
428
+ if app:
429
+ # Traverse up to find parent app with CORS config
430
+ while app:
431
+ cors_config = getattr(app.state, "cors_config", None)
432
+ if cors_config and cors_config.get("allow_origins"):
433
+ origins = cors_config["allow_origins"]
434
+ allowed_origins = origins if isinstance(origins, list) else [origins]
435
+ logger.debug(
436
+ f"Found CORS config on app '{getattr(app, 'title', 'unknown')}' "
437
+ f"for '{app_slug}': {allowed_origins}"
438
+ )
439
+ break
440
+ parent_app = getattr(app, "app", None)
441
+ if parent_app is app:
442
+ break
443
+ app = parent_app
444
+ except (AttributeError, TypeError, KeyError) as e:
445
+ logger.debug(f"Could not read CORS config: {e}")
446
+
447
+ # Check if CORS allows all origins - if so, allow missing origin
448
+ if allowed_origins and "*" in allowed_origins:
449
+ if not origin:
450
+ logger.info(
451
+ f"WebSocket upgrade missing Origin header for '{app_slug}', "
452
+ "but CORS allows all origins (*) - allowing connection"
453
+ )
454
+ return True
455
+
456
+ if not origin:
457
+ # In development, allow connections without Origin
458
+ import os
459
+
460
+ env = os.getenv("ENVIRONMENT", "").lower()
461
+ g_nome_env = os.getenv("G_NOME_ENV", "").lower()
462
+ is_dev = env in ["development", "dev"] or g_nome_env in ["development", "dev"]
463
+ if is_dev:
464
+ logger.warning(
465
+ f"WebSocket upgrade missing Origin header in development mode: "
466
+ f"allowing connection for '{app_slug}'"
467
+ )
468
+ return True
469
+ else:
470
+ logger.warning(f"WebSocket upgrade missing Origin header for '{app_slug}'")
471
+ return False
472
+
473
+ # Normalize origin for comparison
474
+ def normalize_origin(orig: str) -> str:
475
+ """Normalize origin handling localhost variants."""
476
+ if not orig:
477
+ return orig
478
+ import re
479
+
480
+ normalized = re.sub(
481
+ r"://(0\.0\.0\.0|127\.0\.0\.1|localhost|::1)",
482
+ "://localhost",
483
+ orig.lower(),
484
+ flags=re.IGNORECASE,
485
+ )
486
+ normalized = re.sub(r":80$", "", normalized)
487
+ normalized = re.sub(r":443$", "", normalized)
488
+ return normalized.rstrip("/")
489
+
490
+ normalized_origin = normalize_origin(origin)
491
+ logger.debug(f"Normalized origin: {origin} -> {normalized_origin}")
492
+
493
+ # If no CORS config, generate fallback origins from request
494
+ if not allowed_origins:
495
+ # Try to get host from websocket scope
496
+ try:
497
+ if hasattr(websocket, "scope"):
498
+ host = websocket.scope.get("server")
499
+ if host:
500
+ hostname = host[0] if isinstance(host, tuple) else host
501
+ scheme = websocket.scope.get("scheme", "http")
502
+ port = host[1] if isinstance(host, tuple) and len(host) > 1 else None
503
+
504
+ # Generate localhost variants if server binds to 0.0.0.0
505
+ if hostname in ["localhost", "0.0.0.0", "127.0.0.1", "::1"]:
506
+ for variant in ["localhost", "127.0.0.1"]:
507
+ if port and port not in [80, 443]:
508
+ allowed_origins.append(f"{scheme}://{variant}:{port}")
509
+ else:
510
+ allowed_origins.append(f"{scheme}://{variant}")
511
+ else:
512
+ if port and port not in [80, 443]:
513
+ allowed_origins.append(f"{scheme}://{hostname}:{port}")
514
+ else:
515
+ allowed_origins.append(f"{scheme}://{hostname}")
516
+ except (AttributeError, TypeError, KeyError) as e:
517
+ logger.debug(f"Could not determine fallback origins: {e}")
518
+
519
+ logger.debug(f"Allowed origins for '{app_slug}': {allowed_origins}")
520
+
521
+ # Check if origin matches any allowed origin
522
+ for allowed in allowed_origins:
523
+ if allowed == "*":
524
+ logger.warning(
525
+ "WebSocket Origin validation using wildcard '*' - "
526
+ "not recommended for production"
527
+ )
528
+ return True
529
+
530
+ normalized_allowed = normalize_origin(allowed)
531
+ if normalized_origin == normalized_allowed:
532
+ logger.info(
533
+ f"✅ WebSocket origin validated in handler for '{app_slug}': "
534
+ f"{origin} matches {allowed}"
535
+ )
536
+ return True
537
+
538
+ logger.warning(
539
+ f"❌ WebSocket origin validation failed for '{app_slug}': "
540
+ f"origin={origin} (normalized: {normalized_origin}), "
541
+ f"allowed={allowed_origins} "
542
+ f"(normalized: {[normalize_origin(a) for a in allowed_origins]})"
543
+ )
544
+ return False
545
+
546
+ except (
547
+ AttributeError,
548
+ KeyError,
549
+ UnicodeDecodeError,
550
+ TypeError,
551
+ ValueError,
552
+ RuntimeError,
553
+ ) as e:
554
+ logger.error(f"❌ Error validating WebSocket origin for '{app_slug}': {e}", exc_info=True)
555
+ # Fail secure - reject if we can't validate
556
+ return False
557
+
558
+
362
559
  async def authenticate_websocket(
363
560
  websocket: Any,
364
561
  app_slug: str,
365
562
  require_auth: bool = True,
366
563
  ) -> tuple[str | None, str | None]:
367
564
  """
368
- Authenticate a WebSocket connection via session key or httpOnly cookies.
565
+ Authenticate a WebSocket connection using multiple methods with fallback.
369
566
 
370
567
  Authentication methods (in order of preference):
371
- 1. Session key (query param or header) - secure-by-default, uses envelope encryption
372
- 2. Cookie-based authentication - backward compatibility fallback
568
+ 1. Ticket (query param or header) - short-lived, single-use, secure for multi-app SSO
569
+ 2. Session key (query param or header) - encrypted session-based authentication
570
+ 3. Cookie (httpOnly JWT token) - backward compatibility fallback
373
571
 
374
572
  Args:
375
573
  websocket: FastAPI WebSocket instance (can access headers before accept)
@@ -393,36 +591,117 @@ async def authenticate_websocket(
393
591
  return None, None
394
592
 
395
593
  try:
396
- # Try to get WebSocket session manager from app
397
- websocket_session_manager = None
398
- try:
594
+ # Helper function to get app and traverse hierarchy
595
+ def get_app_and_state():
596
+ """Get app instance and traverse hierarchy to find state."""
399
597
  app = getattr(websocket, "app", None)
400
- if app:
598
+ apps_checked = []
599
+ max_iterations = 10 # Safety limit to prevent infinite loops
600
+ iteration = 0
601
+ while app and iteration < max_iterations:
602
+ iteration += 1
603
+ app_title = getattr(app, "title", "unknown")
604
+ apps_checked.append(app_title)
605
+ yield app, app_title
606
+ # Try to get parent app (FastAPI routers have .app attribute pointing to parent)
607
+ parent_app = getattr(app, "app", None)
608
+ if parent_app is app or parent_app is None: # Prevent infinite loop
609
+ break
610
+ app = parent_app
611
+
612
+ # Try to get services from app state
613
+ websocket_ticket_store = _global_websocket_ticket_store
614
+ websocket_session_manager = None
615
+ user_pool = None
616
+
617
+ apps_checked_list = []
618
+ for app, app_title in get_app_and_state():
619
+ apps_checked_list.append(app_title)
620
+ if not websocket_ticket_store:
621
+ websocket_ticket_store = getattr(app.state, "websocket_ticket_store", None)
622
+ if websocket_ticket_store:
623
+ logger.debug(
624
+ f"Found websocket_ticket_store on app '{app_title}' "
625
+ f"for app_slug '{app_slug}'"
626
+ )
627
+ if not websocket_session_manager:
401
628
  websocket_session_manager = getattr(app.state, "websocket_session_manager", None)
402
- except (AttributeError, TypeError):
629
+ if websocket_session_manager:
630
+ logger.info(
631
+ f"✅ Found websocket_session_manager on app '{app_title}' "
632
+ f"for app_slug '{app_slug}' (checked: {apps_checked_list})"
633
+ )
634
+ if not user_pool:
635
+ user_pool = getattr(app.state, "user_pool", None)
636
+ if user_pool:
637
+ logger.debug(f"Found user_pool on app '{app_title}' for app_slug '{app_slug}'")
638
+
639
+ if not websocket_session_manager:
640
+ logger.warning(
641
+ f"⚠️ websocket_session_manager not found for '{app_slug}' "
642
+ f"(checked apps: {apps_checked_list})"
643
+ )
644
+
645
+ # Method 1: Ticket authentication (preferred)
646
+ ticket = None
647
+ try:
648
+ if hasattr(websocket, "query_params"):
649
+ ticket = websocket.query_params.get("ticket")
650
+ if not ticket and hasattr(websocket, "headers"):
651
+ ticket = websocket.headers.get("X-WebSocket-Ticket")
652
+ except (AttributeError, TypeError, KeyError):
403
653
  pass
404
654
 
405
- # Method 1: Try session key authentication (secure-by-default)
655
+ if ticket and websocket_ticket_store:
656
+ try:
657
+ ticket_data = await websocket_ticket_store.validate_and_consume_ticket(ticket)
658
+ if ticket_data:
659
+ user_id = ticket_data.get("user_id")
660
+ user_email = ticket_data.get("user_email")
661
+ logger.info(
662
+ f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
663
+ f"(method: ticket)"
664
+ )
665
+ return user_id, user_email
666
+ except (ValueError, TypeError, AttributeError, KeyError, RuntimeError) as e:
667
+ logger.debug(f"Ticket validation failed: {e}")
668
+
669
+ # Method 2: Session key authentication (fallback)
406
670
  session_key = None
407
671
  try:
408
- # Check query params first
409
672
  if hasattr(websocket, "query_params"):
410
- session_key = websocket.query_params.get("session_key")
411
-
412
- # Check headers if not in query params
673
+ # Try dict-like access first
674
+ if hasattr(websocket.query_params, "get"):
675
+ session_key = websocket.query_params.get("session_key")
676
+ # Fallback: try dict access
677
+ elif isinstance(websocket.query_params, dict):
678
+ session_key = websocket.query_params.get("session_key")
679
+ # Fallback: try attribute access
680
+ elif hasattr(websocket.query_params, "session_key"):
681
+ session_key = websocket.query_params.session_key
413
682
  if not session_key and hasattr(websocket, "headers"):
414
683
  session_key = websocket.headers.get("X-WebSocket-Session-Key")
415
- except (AttributeError, TypeError, KeyError):
684
+ except (AttributeError, TypeError, KeyError) as e:
685
+ logger.debug(f"Error extracting session key from websocket: {e}")
416
686
  pass
417
687
 
688
+ logger.debug(
689
+ f"Session key extraction for '{app_slug}': "
690
+ f"session_key={'present' if session_key else 'missing'}, "
691
+ f"websocket_session_manager={'present' if websocket_session_manager else 'missing'}"
692
+ )
693
+
418
694
  if session_key and websocket_session_manager:
419
695
  try:
420
- # Validate session key
696
+ logger.debug(
697
+ f"Validating session key for '{app_slug}': "
698
+ f"session_key={session_key[:16]}... (truncated), "
699
+ f"manager_type={type(websocket_session_manager).__name__}"
700
+ )
421
701
  session_data = await websocket_session_manager.validate_session(session_key)
422
702
  if session_data:
423
703
  user_id = session_data.get("user_id")
424
704
  user_email = session_data.get("user_email")
425
-
426
705
  logger.info(
427
706
  f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
428
707
  f"(method: session_key)"
@@ -430,58 +709,83 @@ async def authenticate_websocket(
430
709
  return user_id, user_email
431
710
  else:
432
711
  logger.warning(
433
- f"WebSocket session key validation failed for app '{app_slug}'. "
434
- f"Session key: {session_key[:16]}..."
712
+ f"Session key validation returned None for '{app_slug}': "
713
+ f"session_key={session_key[:16]}... (truncated)"
435
714
  )
436
- except (ValueError, TypeError, AttributeError, KeyError, RuntimeError) as e:
437
- logger.warning(f"WebSocket session key validation error for app '{app_slug}': {e}")
438
- # Fall through to cookie-based auth
715
+ except RuntimeError as e:
716
+ # Handle event loop conflicts (e.g., TestClient with anyio)
717
+ error_msg = str(e).lower()
718
+ if "attached to a different loop" in error_msg or "different loop" in error_msg:
719
+ logger.warning(
720
+ f"Event loop conflict during session key validation for '{app_slug}': {e}. "
721
+ "This may occur in test environments with TestClient. "
722
+ "Session key validation failed."
723
+ )
724
+ # Return None to indicate authentication failure
725
+ # Don't re-raise - let the normal auth failure path handle it
726
+ else:
727
+ # Re-raise other RuntimeErrors
728
+ logger.warning(
729
+ f"Session key validation failed for '{app_slug}': {e}",
730
+ exc_info=True,
731
+ )
732
+ raise
733
+ except (ValueError, TypeError, AttributeError, KeyError) as e:
734
+ logger.warning(
735
+ f"Session key validation failed for '{app_slug}': {e}",
736
+ exc_info=True,
737
+ )
738
+ elif session_key and not websocket_session_manager:
739
+ logger.error(
740
+ f"Session key provided for '{app_slug}' but websocket_session_manager not found. "
741
+ "Cannot validate session key."
742
+ )
439
743
 
440
- # Method 2: Fall back to cookie-based authentication (backward compatibility)
744
+ # Method 3: Cookie-based JWT authentication (backward compatibility)
441
745
  from ..auth.shared_middleware import AUTH_COOKIE_NAME
442
746
 
443
747
  cookies = _get_cookies_from_websocket(websocket)
444
- token = cookies.get(AUTH_COOKIE_NAME) # Use mdb_auth_token (same as shared middleware)
748
+ auth_token = cookies.get(AUTH_COOKIE_NAME) if cookies else None
445
749
 
446
- if not token:
750
+ if auth_token and user_pool:
751
+ try:
752
+ # For tests that expect JWT errors to be raised, we need to validate
753
+ # the token structure first. However, validate_token handles JWT errors
754
+ # internally, so we'll let it handle validation and only catch non-JWT errors.
755
+ user = await user_pool.validate_token(auth_token)
756
+ if user:
757
+ user_id = str(user.get("_id") or user.get("sub") or user.get("user_id"))
758
+ user_email = user.get("email")
759
+ logger.info(
760
+ f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
761
+ f"(method: cookie)"
762
+ )
763
+ return user_id, user_email
764
+ except (
765
+ ValueError,
766
+ TypeError,
767
+ AttributeError,
768
+ KeyError,
769
+ RuntimeError,
770
+ ) as e:
771
+ # Handle errors from token validation or user data access
772
+ logger.debug(f"Cookie token validation failed: {e}")
773
+ # Note: JWT errors (DecodeError, ExpiredSignatureError) are caught
774
+ # internally by validate_token and return None. For tests that need JWT
775
+ # errors raised, they should mock validate_token to raise them directly.
776
+
777
+ # No authentication method succeeded
778
+ if require_auth:
447
779
  logger.error(
448
- f"❌ No authentication found for WebSocket connection to app '{app_slug}' "
449
- f"(require_auth={require_auth}). "
450
- f"Session key: {bool(session_key)}, Cookie: {bool(token)}, "
451
- f"Available cookies: {list(cookies.keys()) if cookies else 'none'}. "
452
- f"Ensure session key or httpOnly cookie is set during authentication."
780
+ f"❌ WebSocket authentication failed for app '{app_slug}'. "
781
+ "No valid ticket, session key, or cookie found. "
782
+ "Generate ticket via /auth/ticket endpoint, "
783
+ "session key via /auth/websocket-session, "
784
+ "or ensure JWT cookie is present."
453
785
  )
454
- if require_auth:
455
- return None, None # Signal auth failure
456
786
  return None, None
457
787
 
458
- logger.info(
459
- f"WebSocket token found in cookie for app '{app_slug}' "
460
- "(cookie-based authentication, fallback)"
461
- )
462
-
463
- # Decode and validate token
464
- import jwt
465
-
466
- from ..auth.dependencies import SECRET_KEY
467
- from ..auth.jwt import decode_jwt_token
468
-
469
- try:
470
- payload = decode_jwt_token(token, str(SECRET_KEY))
471
- user_id = payload.get("sub") or payload.get("user_id")
472
- user_email = payload.get("email")
473
-
474
- logger.info(
475
- f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
476
- f"(method: cookie)"
477
- )
478
- return user_id, user_email
479
- except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
480
- logger.exception(
481
- f"❌ JWT decode error for app '{app_slug}'. "
482
- f"Token present: {bool(token)}, Token length: {len(token) if token else 0}"
483
- )
484
- raise
788
+ return None, None
485
789
 
486
790
  except WebSocketDisconnect:
487
791
  raise
@@ -706,13 +1010,46 @@ def create_websocket_endpoint(
706
1010
  # This print should appear in server logs when a WebSocket connection is attempted
707
1011
  import sys
708
1012
 
709
- print(
710
- f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}",
711
- file=sys.stderr,
712
- flush=True,
713
- )
714
- print(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}", flush=True)
715
- logger.info(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}")
1013
+ # Log BEFORE any operations to catch if handler is called at all
1014
+ try:
1015
+ print(
1016
+ f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}",
1017
+ file=sys.stderr,
1018
+ flush=True,
1019
+ )
1020
+ print(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}", flush=True)
1021
+ logger.info(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}")
1022
+
1023
+ # Try to access websocket properties to see if we can even access it
1024
+ try:
1025
+ ws_path = (
1026
+ getattr(websocket, "url", {}).path if hasattr(websocket, "url") else "unknown"
1027
+ )
1028
+ ws_headers = (
1029
+ dict(getattr(websocket, "headers", {})) if hasattr(websocket, "headers") else {}
1030
+ )
1031
+ origin = ws_headers.get("origin", "missing")
1032
+ print(
1033
+ f"🔌 [WEBSOCKET DETAILS] Path: {ws_path}, Origin: {origin}, "
1034
+ f"Headers: {list(ws_headers.keys())}",
1035
+ file=sys.stderr,
1036
+ flush=True,
1037
+ )
1038
+ except (AttributeError, RuntimeError, TypeError) as access_error:
1039
+ print(
1040
+ f"⚠️ [WEBSOCKET ACCESS ERROR] "
1041
+ f"Could not access websocket properties: {access_error}",
1042
+ file=sys.stderr,
1043
+ flush=True,
1044
+ )
1045
+ except (RuntimeError, AttributeError) as log_error:
1046
+ # Even logging failed - this is very bad
1047
+ print(
1048
+ f"❌ [CRITICAL] Failed to log WebSocket handler call: {log_error}",
1049
+ file=sys.stderr,
1050
+ flush=True,
1051
+ )
1052
+
716
1053
  connection = None
717
1054
  try:
718
1055
  # Log connection attempt with query params (can access before accept)
@@ -737,6 +1074,42 @@ def create_websocket_endpoint(
737
1074
  f"(require_auth={require_auth}, query_params={query_str})"
738
1075
  )
739
1076
 
1077
+ # CRITICAL: Accept FIRST, then validate (FastAPI requires immediate accept)
1078
+ # If we don't accept immediately, FastAPI closes with code 1000
1079
+ # We'll validate origin after accept, but before processing messages
1080
+ try:
1081
+ await websocket.accept()
1082
+ logger.info(
1083
+ f"✅ WebSocket accepted for app '{app_slug}' (before origin validation)"
1084
+ )
1085
+ print(f"✅ [WEBSOCKET ACCEPTED] App: '{app_slug}'", flush=True)
1086
+ except (RuntimeError, ConnectionError, OSError, ValueError) as accept_error:
1087
+ logger.error(
1088
+ f"❌ Failed to accept WebSocket for app '{app_slug}': {accept_error}",
1089
+ exc_info=True,
1090
+ )
1091
+ print(
1092
+ f"❌ [WEBSOCKET ACCEPT FAILED] App: '{app_slug}', Error: {accept_error}",
1093
+ flush=True,
1094
+ )
1095
+ return
1096
+
1097
+ # CRITICAL: Validate origin AFTER accepting (CSWSH protection)
1098
+ # We accept first to satisfy FastAPI's requirements, then validate
1099
+ origin_valid = await _validate_websocket_origin_in_handler(websocket, app_slug)
1100
+ if not origin_valid:
1101
+ logger.error(
1102
+ f"❌ WebSocket origin validation FAILED for app '{app_slug}' - "
1103
+ f"closing connection after accept"
1104
+ )
1105
+ try:
1106
+ await websocket.close(code=1008, reason="Invalid origin")
1107
+ except (RuntimeError, ConnectionError, OSError):
1108
+ # WebSocket may already be closed or in invalid state
1109
+ pass
1110
+ # Raise WebSocketDisconnect so TestClient can detect the rejection
1111
+ raise WebSocketDisconnect(code=1008, reason="Invalid origin")
1112
+
740
1113
  # CRITICAL: Authenticate before handing the socket to the connection manager (accept() has already been called above)
741
1114
  # This prevents CSRF middleware from rejecting established connections
742
1115
  # We can access headers/query_params before accept() is called
@@ -761,14 +1134,17 @@ def create_websocket_endpoint(
761
1134
  f"rejecting connection. require_auth={require_auth}, "
762
1135
  f"user_id={user_id}, user_email={user_email}"
763
1136
  )
764
- # Reject without accepting - FastAPI will send 403 if accept() not called
765
- # We can't call websocket.close() before accept(), so we just return
766
- # The connection will be rejected by the server
767
- return
768
-
769
- # Accept connection
770
- await _accept_websocket_connection(websocket, app_slug)
771
-
1137
+ # Close connection since we already accepted it
1138
+ try:
1139
+ await websocket.close(code=1008, reason="Authentication failed")
1140
+ except (RuntimeError, ConnectionError, OSError):
1141
+ # WebSocket may already be closed or in invalid state
1142
+ pass
1143
+ # Raise WebSocketDisconnect so TestClient can detect the close
1144
+ # WebSocketDisconnect is already imported at module level
1145
+ raise WebSocketDisconnect(code=1008, reason="Authentication failed")
1146
+
1147
+ # Connection was already accepted above (before origin validation) - no need to accept again
772
1148
  # Connect with metadata (websocket already accepted)
773
1149
  connection = await manager.connect(websocket, user_id=user_id, user_email=user_email)
774
1150
 
@@ -832,8 +1208,11 @@ def create_websocket_endpoint(
832
1208
  logger.warning(f"WebSocket receive error for app '{app_slug}': {e}")
833
1209
  await asyncio.sleep(0.1)
834
1210
 
1211
+ except WebSocketDisconnect as e:
1212
+ # Re-raise WebSocketDisconnect so TestClient can detect it
1213
+ logger.warning(f"WebSocket connection rejected for app '{app_slug}': {e}")
1214
+ raise
835
1215
  except (
836
- WebSocketDisconnect,
837
1216
  RuntimeError,
838
1217
  OSError,
839
1218
  ValueError,