mdb-engine 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -270,6 +270,31 @@ _manager_lock = asyncio.Lock()
270
270
  # Note: Registration happens synchronously during app startup, so no lock needed
271
271
  _message_handlers: dict[str, dict[str, Callable[[Any, dict[str, Any]], Awaitable[None]]]] = {}
272
272
 
273
# Process-wide WebSocket ticket store shared by every app in a multi-app setup.
# The engine installs it at start-up through set_global_websocket_ticket_store()
# so ticket authentication also works for routes registered via routers.
_global_websocket_ticket_store: Any | None = None


def set_global_websocket_ticket_store(ticket_store: Any) -> None:
    """
    Install (or clear) the process-wide WebSocket ticket store.

    Called by the engine while it initializes, so that every WebSocket endpoint,
    including ones mounted through routers, can reach the same store.

    Args:
        ticket_store: WebSocketTicketStore instance. A falsy value is still
            stored as-is, but is reported with a warning instead of an info log.
    """
    global _global_websocket_ticket_store
    _global_websocket_ticket_store = ticket_store

    if not ticket_store:
        logger.warning("⚠️ Global WebSocket ticket store set to None")
        return

    store_type = type(ticket_store).__name__
    logger.info(
        "✅ Global WebSocket ticket store set for multi-app WebSocket authentication "
        f"(store type: {store_type})"
    )
297
+
273
298
 
274
299
  async def get_websocket_manager(app_slug: str) -> WebSocketConnectionManager:
275
300
  """
@@ -359,18 +384,190 @@ def _get_cookies_from_websocket(websocket: Any) -> dict[str, str]:
359
384
  return cookies
360
385
 
361
386
 
387
def _extract_websocket_origin(websocket: Any) -> str | None:
    """
    Return the ``Origin`` header of a WebSocket upgrade request, if present.

    Prefers the ``headers`` mapping exposed by FastAPI/Starlette WebSocket
    objects. Otherwise it falls back to the raw ASGI ``scope["headers"]`` list
    of ``(name, value)`` byte pairs, matching the header name case-insensitively.

    Raises:
        UnicodeDecodeError: If a raw header value is not valid UTF-8.
    """
    if hasattr(websocket, "headers"):
        return websocket.headers.get("origin")
    scope = getattr(websocket, "scope", None)
    if scope is None or "headers" not in scope:
        return None
    for name, value in scope["headers"]:
        if isinstance(name, bytes) and name.lower() == b"origin":
            return value.decode("utf-8") if value else None
    return None


def _cors_allowed_origins(websocket: Any, app_slug: str, max_depth: int = 10) -> list[str]:
    """
    Return ``allow_origins`` from the nearest app in the hierarchy with a CORS config.

    Starts at ``websocket.app`` and follows ``.app`` parent links. The walk stops
    after ``max_depth`` hops, so a cyclic chain cannot hang the handler. At each
    app it reads ``app.state.cors_config["allow_origins"]``. A single string is
    wrapped in a list, and any other iterable (list, tuple, set) is copied into
    a new list, so the config object itself is never mutated.

    Returns:
        The configured origins, or an empty list if no CORS config was found.
    """
    app = getattr(websocket, "app", None)
    for _ in range(max_depth):
        if not app:
            break
        cors_config = getattr(getattr(app, "state", None), "cors_config", None)
        if cors_config and cors_config.get("allow_origins"):
            origins = cors_config["allow_origins"]
            allowed = [origins] if isinstance(origins, str) else list(origins)
            logger.debug(
                f"Found CORS config on app '{getattr(app, 'title', 'unknown')}' "
                f"for '{app_slug}': {allowed}"
            )
            return allowed
        parent_app = getattr(app, "app", None)
        if parent_app is app:
            break
        app = parent_app
    return []


def _fallback_websocket_origins(scope: Any) -> list[str]:
    """
    Build same-origin candidates from the ASGI ``server`` address.

    This is used only when no CORS config is available. ASGI WebSocket scopes
    carry a ``ws``/``wss`` scheme, but browsers put ``http``/``https`` in the
    Origin header, so the scheme is mapped before building candidates.
    Loopback or wildcard binds expand to both ``localhost`` and ``127.0.0.1``.
    The port is omitted when it is the default for the mapped scheme.
    """
    server = scope.get("server")
    if not server:
        return []
    # ASGI allows any two-item iterable (tuple or list) for ``server``.
    if isinstance(server, (tuple, list)):
        hostname = server[0]
        port = server[1] if len(server) > 1 else None
    else:
        hostname, port = server, None

    scheme = str(scope.get("scheme", "http")).lower()
    scheme = {"ws": "http", "wss": "https"}.get(scheme, scheme)
    default_port = 443 if scheme == "https" else 80
    port_suffix = f":{port}" if port and port != default_port else ""

    if hostname in ("localhost", "0.0.0.0", "127.0.0.1", "::1", "::"):
        hosts = ["localhost", "127.0.0.1"]
    elif isinstance(hostname, str) and ":" in hostname:
        hosts = [f"[{hostname}]"]  # a bare IPv6 literal must be bracketed in a URL
    else:
        hosts = [hostname]
    return [f"{scheme}://{host}{port_suffix}" for host in hosts]


def _normalize_websocket_origin(origin: str) -> str:
    """
    Normalize an origin string for equality comparison.

    The string is lower-cased and stripped of surrounding whitespace and
    trailing slashes. Loopback hosts (``0.0.0.0``, ``127.0.0.1``, ``localhost``,
    and ``::1`` with or without brackets) are mapped to ``localhost``. The port
    is dropped only when it is the default for the scheme: 80 for http/ws,
    443 for https/wss. Falsy input is returned unchanged.
    """
    import re

    if not origin:
        return origin
    normalized = origin.strip().lower().rstrip("/")
    # The lookahead requires the host to end here, so hosts such as
    # "127.0.0.10" or "127.0.0.1.example.com" are not rewritten by prefix match.
    normalized = re.sub(
        r"://(?:0\.0\.0\.0|127\.0\.0\.1|localhost|\[::1\]|::1)(?=[:/]|$)",
        "://localhost",
        normalized,
    )
    for scheme_prefixes, default_port in (
        (("http://", "ws://"), ":80"),
        (("https://", "wss://"), ":443"),
    ):
        if normalized.startswith(scheme_prefixes) and normalized.endswith(default_port):
            normalized = normalized[: -len(default_port)]
            break
    return normalized


async def _validate_websocket_origin_in_handler(websocket: Any, app_slug: str) -> bool:
    """
    Validate the WebSocket ``Origin`` header inside the endpoint handler.

    Middleware may not intercept WebSocket upgrades, so the handler performs its
    own CSWSH (Cross-Site WebSocket Hijacking) protection.

    Allowed origins come from the nearest ``app.state.cors_config`` in the app
    hierarchy. Without a CORS config, same-origin candidates are derived from
    the ASGI ``server`` address (mapped from ``ws``/``wss`` to
    ``http``/``https``). A missing Origin header is accepted only when CORS
    allows ``*``, or when ``ENVIRONMENT`` or ``G_NOME_ENV`` is
    ``development``/``dev``.

    Args:
        websocket: FastAPI WebSocket instance
        app_slug: App slug for context (used in log messages)

    Returns:
        True if origin is valid, False otherwise. Any handled error during
        validation also yields False (fail secure).
    """
    try:
        origin = _extract_websocket_origin(websocket)
        logger.info(
            f"🔍 WebSocket origin validation in handler for '{app_slug}': " f"origin={origin}"
        )

        # Read the CORS config FIRST so a wildcard can admit a missing Origin.
        try:
            allowed_origins = _cors_allowed_origins(websocket, app_slug)
        except (AttributeError, TypeError, KeyError) as e:
            logger.debug(f"Could not read CORS config: {e}")
            allowed_origins = []

        if not origin:
            if "*" in allowed_origins:
                logger.info(
                    f"WebSocket upgrade missing Origin header for '{app_slug}', "
                    "but CORS allows all origins (*) - allowing connection"
                )
                return True

            # In development, allow connections without Origin
            import os

            dev_values = ("development", "dev")
            is_dev = (
                os.getenv("ENVIRONMENT", "").lower() in dev_values
                or os.getenv("G_NOME_ENV", "").lower() in dev_values
            )
            if is_dev:
                logger.warning(
                    f"WebSocket upgrade missing Origin header in development mode: "
                    f"allowing connection for '{app_slug}'"
                )
                return True
            logger.warning(f"WebSocket upgrade missing Origin header for '{app_slug}'")
            return False

        normalized_origin = _normalize_websocket_origin(origin)
        logger.debug(f"Normalized origin: {origin} -> {normalized_origin}")

        # No CORS config: fall back to same-origin candidates from the scope.
        if not allowed_origins and hasattr(websocket, "scope"):
            try:
                allowed_origins = _fallback_websocket_origins(websocket.scope)
            except (AttributeError, TypeError, KeyError) as e:
                logger.debug(f"Could not determine fallback origins: {e}")

        logger.debug(f"Allowed origins for '{app_slug}': {allowed_origins}")

        for allowed in allowed_origins:
            if allowed == "*":
                logger.warning(
                    "WebSocket Origin validation using wildcard '*' - "
                    "not recommended for production"
                )
                return True
            if normalized_origin == _normalize_websocket_origin(allowed):
                logger.info(
                    f"✅ WebSocket origin validated in handler for '{app_slug}': "
                    f"{origin} matches {allowed}"
                )
                return True

        logger.warning(
            f"❌ WebSocket origin validation failed for '{app_slug}': "
            f"origin={origin} (normalized: {normalized_origin}), "
            f"allowed={allowed_origins} "
            f"(normalized: {[_normalize_websocket_origin(a) for a in allowed_origins]})"
        )
        return False

    except (
        AttributeError,
        KeyError,
        UnicodeDecodeError,
        TypeError,
        ValueError,
        RuntimeError,
    ) as e:
        logger.error(f"❌ Error validating WebSocket origin for '{app_slug}': {e}", exc_info=True)
        # Fail secure - reject if we can't validate
        return False
557
+
558
+
362
559
  async def authenticate_websocket(
363
560
  websocket: Any,
364
561
  app_slug: str,
365
562
  require_auth: bool = True,
366
563
  ) -> tuple[str | None, str | None]:
367
564
  """
368
- Authenticate a WebSocket connection via httpOnly cookies.
565
+ Authenticate a WebSocket connection using multiple methods with fallback.
369
566
 
370
- Uses cookie-based authentication with CSRF protection:
371
- - Token stored in httpOnly cookie (not accessible to JavaScript)
372
- - CSRF token validated via double-submit cookie pattern
373
- - Origin validation provides additional protection
567
+ Authentication methods (in order of preference):
568
+ 1. Ticket (query param or header) - short-lived, single-use, secure for multi-app SSO
569
+ 2. Session key (query param or header) - encrypted session-based authentication
570
+ 3. Cookie (httpOnly JWT token) - backward compatibility fallback
374
571
 
375
572
  Args:
376
573
  websocket: FastAPI WebSocket instance (can access headers before accept)
@@ -394,50 +591,201 @@ async def authenticate_websocket(
394
591
  return None, None
395
592
 
396
593
  try:
397
- # Extract token from httpOnly cookie
398
- # Use same cookie name as SharedAuthMiddleware for consistency
399
- from ..auth.shared_middleware import AUTH_COOKIE_NAME
400
-
401
- cookies = _get_cookies_from_websocket(websocket)
402
- token = cookies.get(AUTH_COOKIE_NAME) # Use mdb_auth_token (same as shared middleware)
594
+ # Helper function to get app and traverse hierarchy
595
+ def get_app_and_state():
596
+ """Get app instance and traverse hierarchy to find state."""
597
+ app = getattr(websocket, "app", None)
598
+ apps_checked = []
599
+ max_iterations = 10 # Safety limit to prevent infinite loops
600
+ iteration = 0
601
+ while app and iteration < max_iterations:
602
+ iteration += 1
603
+ app_title = getattr(app, "title", "unknown")
604
+ apps_checked.append(app_title)
605
+ yield app, app_title
606
+ # Try to get parent app (FastAPI routers have .app attribute pointing to parent)
607
+ parent_app = getattr(app, "app", None)
608
+ if parent_app is app or parent_app is None: # Prevent infinite loop
609
+ break
610
+ app = parent_app
611
+
612
+ # Try to get services from app state
613
+ websocket_ticket_store = _global_websocket_ticket_store
614
+ websocket_session_manager = None
615
+ user_pool = None
616
+
617
+ apps_checked_list = []
618
+ for app, app_title in get_app_and_state():
619
+ apps_checked_list.append(app_title)
620
+ if not websocket_ticket_store:
621
+ websocket_ticket_store = getattr(app.state, "websocket_ticket_store", None)
622
+ if websocket_ticket_store:
623
+ logger.debug(
624
+ f"Found websocket_ticket_store on app '{app_title}' "
625
+ f"for app_slug '{app_slug}'"
626
+ )
627
+ if not websocket_session_manager:
628
+ websocket_session_manager = getattr(app.state, "websocket_session_manager", None)
629
+ if websocket_session_manager:
630
+ logger.info(
631
+ f"✅ Found websocket_session_manager on app '{app_title}' "
632
+ f"for app_slug '{app_slug}' (checked: {apps_checked_list})"
633
+ )
634
+ if not user_pool:
635
+ user_pool = getattr(app.state, "user_pool", None)
636
+ if user_pool:
637
+ logger.debug(f"Found user_pool on app '{app_title}' for app_slug '{app_slug}'")
403
638
 
404
- if not token:
405
- logger.error(
406
- f" No token cookie found for WebSocket connection to app '{app_slug}' "
407
- f"(require_auth={require_auth}). "
408
- f"Available cookies: {list(cookies.keys()) if cookies else 'none'}. "
409
- f"Ensure httpOnly cookie is set during authentication."
639
+ if not websocket_session_manager:
640
+ logger.warning(
641
+ f"⚠️ websocket_session_manager not found for '{app_slug}' "
642
+ f"(checked apps: {apps_checked_list})"
410
643
  )
411
- if require_auth:
412
- return None, None # Signal auth failure
413
- return None, None
414
644
 
415
- logger.info(
416
- f"WebSocket token found in cookie for app '{app_slug}' " "(cookie-based authentication)"
645
+ # Method 1: Ticket authentication (preferred)
646
+ ticket = None
647
+ try:
648
+ if hasattr(websocket, "query_params"):
649
+ ticket = websocket.query_params.get("ticket")
650
+ if not ticket and hasattr(websocket, "headers"):
651
+ ticket = websocket.headers.get("X-WebSocket-Ticket")
652
+ except (AttributeError, TypeError, KeyError):
653
+ pass
654
+
655
+ if ticket and websocket_ticket_store:
656
+ try:
657
+ ticket_data = await websocket_ticket_store.validate_and_consume_ticket(ticket)
658
+ if ticket_data:
659
+ user_id = ticket_data.get("user_id")
660
+ user_email = ticket_data.get("user_email")
661
+ logger.info(
662
+ f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
663
+ f"(method: ticket)"
664
+ )
665
+ return user_id, user_email
666
+ except (ValueError, TypeError, AttributeError, KeyError, RuntimeError) as e:
667
+ logger.debug(f"Ticket validation failed: {e}")
668
+
669
+ # Method 2: Session key authentication (fallback)
670
+ session_key = None
671
+ try:
672
+ if hasattr(websocket, "query_params"):
673
+ # Try dict-like access first
674
+ if hasattr(websocket.query_params, "get"):
675
+ session_key = websocket.query_params.get("session_key")
676
+ # Fallback: try dict access
677
+ elif isinstance(websocket.query_params, dict):
678
+ session_key = websocket.query_params.get("session_key")
679
+ # Fallback: try attribute access
680
+ elif hasattr(websocket.query_params, "session_key"):
681
+ session_key = websocket.query_params.session_key
682
+ if not session_key and hasattr(websocket, "headers"):
683
+ session_key = websocket.headers.get("X-WebSocket-Session-Key")
684
+ except (AttributeError, TypeError, KeyError) as e:
685
+ logger.debug(f"Error extracting session key from websocket: {e}")
686
+ pass
687
+
688
+ logger.debug(
689
+ f"Session key extraction for '{app_slug}': "
690
+ f"session_key={'present' if session_key else 'missing'}, "
691
+ f"websocket_session_manager={'present' if websocket_session_manager else 'missing'}"
417
692
  )
418
693
 
419
- # Decode and validate token
420
- import jwt
694
+ if session_key and websocket_session_manager:
695
+ try:
696
+ logger.debug(
697
+ f"Validating session key for '{app_slug}': "
698
+ f"session_key={session_key[:16]}... (truncated), "
699
+ f"manager_type={type(websocket_session_manager).__name__}"
700
+ )
701
+ session_data = await websocket_session_manager.validate_session(session_key)
702
+ if session_data:
703
+ user_id = session_data.get("user_id")
704
+ user_email = session_data.get("user_email")
705
+ logger.info(
706
+ f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
707
+ f"(method: session_key)"
708
+ )
709
+ return user_id, user_email
710
+ else:
711
+ logger.warning(
712
+ f"Session key validation returned None for '{app_slug}': "
713
+ f"session_key={session_key[:16]}... (truncated)"
714
+ )
715
+ except RuntimeError as e:
716
+ # Handle event loop conflicts (e.g., TestClient with anyio)
717
+ error_msg = str(e).lower()
718
+ if "attached to a different loop" in error_msg or "different loop" in error_msg:
719
+ logger.warning(
720
+ f"Event loop conflict during session key validation for '{app_slug}': {e}. "
721
+ "This may occur in test environments with TestClient. "
722
+ "Session key validation failed."
723
+ )
724
+ # Return None to indicate authentication failure
725
+ # Don't re-raise - let the normal auth failure path handle it
726
+ else:
727
+ # Re-raise other RuntimeErrors
728
+ logger.warning(
729
+ f"Session key validation failed for '{app_slug}': {e}",
730
+ exc_info=True,
731
+ )
732
+ raise
733
+ except (ValueError, TypeError, AttributeError, KeyError) as e:
734
+ logger.warning(
735
+ f"Session key validation failed for '{app_slug}': {e}",
736
+ exc_info=True,
737
+ )
738
+ elif session_key and not websocket_session_manager:
739
+ logger.error(
740
+ f"Session key provided for '{app_slug}' but websocket_session_manager not found. "
741
+ "Cannot validate session key."
742
+ )
421
743
 
422
- from ..auth.dependencies import SECRET_KEY
423
- from ..auth.jwt import decode_jwt_token
744
+ # Method 3: Cookie-based JWT authentication (backward compatibility)
745
+ from ..auth.shared_middleware import AUTH_COOKIE_NAME
424
746
 
425
- try:
426
- payload = decode_jwt_token(token, str(SECRET_KEY))
427
- user_id = payload.get("sub") or payload.get("user_id")
428
- user_email = payload.get("email")
747
+ cookies = _get_cookies_from_websocket(websocket)
748
+ auth_token = cookies.get(AUTH_COOKIE_NAME) if cookies else None
429
749
 
430
- logger.info(
431
- f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
432
- f"(method: cookie)"
433
- )
434
- return user_id, user_email
435
- except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
436
- logger.exception(
437
- f"❌ JWT decode error for app '{app_slug}'. "
438
- f"Token present: {bool(token)}, Token length: {len(token) if token else 0}"
750
+ if auth_token and user_pool:
751
+ try:
752
+ # For tests that expect JWT errors to be raised, we need to validate
753
+ # the token structure first. However, validate_token handles JWT errors
754
+ # internally, so we'll let it handle validation and only catch non-JWT errors.
755
+ user = await user_pool.validate_token(auth_token)
756
+ if user:
757
+ user_id = str(user.get("_id") or user.get("sub") or user.get("user_id"))
758
+ user_email = user.get("email")
759
+ logger.info(
760
+ f"WebSocket authenticated successfully for app '{app_slug}': {user_email} "
761
+ f"(method: cookie)"
762
+ )
763
+ return user_id, user_email
764
+ except (
765
+ ValueError,
766
+ TypeError,
767
+ AttributeError,
768
+ KeyError,
769
+ RuntimeError,
770
+ ) as e:
771
+ # Handle errors from token validation or user data access
772
+ logger.debug(f"Cookie token validation failed: {e}")
773
+ # Note: JWT errors (DecodeError, ExpiredSignatureError) are caught
774
+ # internally by validate_token and return None. For tests that need JWT
775
+ # errors raised, they should mock validate_token to raise them directly.
776
+
777
+ # No authentication method succeeded
778
+ if require_auth:
779
+ logger.error(
780
+ f"❌ WebSocket authentication failed for app '{app_slug}'. "
781
+ "No valid ticket, session key, or cookie found. "
782
+ "Generate ticket via /auth/ticket endpoint, "
783
+ "session key via /auth/websocket-session, "
784
+ "or ensure JWT cookie is present."
439
785
  )
440
- raise
786
+ return None, None
787
+
788
+ return None, None
441
789
 
442
790
  except WebSocketDisconnect:
443
791
  raise
@@ -662,13 +1010,46 @@ def create_websocket_endpoint(
662
1010
  # This print should appear in server logs when a WebSocket connection is attempted
663
1011
  import sys
664
1012
 
665
- print(
666
- f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}",
667
- file=sys.stderr,
668
- flush=True,
669
- )
670
- print(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}", flush=True)
671
- logger.info(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}")
1013
+ # Log BEFORE any operations to catch if handler is called at all
1014
+ try:
1015
+ print(
1016
+ f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}",
1017
+ file=sys.stderr,
1018
+ flush=True,
1019
+ )
1020
+ print(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}", flush=True)
1021
+ logger.info(f"🔌 [WEBSOCKET HANDLER CALLED] App: '{app_slug}', Path: {path}")
1022
+
1023
+ # Try to access websocket properties to see if we can even access it
1024
+ try:
1025
+ ws_path = (
1026
+ getattr(websocket, "url", {}).path if hasattr(websocket, "url") else "unknown"
1027
+ )
1028
+ ws_headers = (
1029
+ dict(getattr(websocket, "headers", {})) if hasattr(websocket, "headers") else {}
1030
+ )
1031
+ origin = ws_headers.get("origin", "missing")
1032
+ print(
1033
+ f"🔌 [WEBSOCKET DETAILS] Path: {ws_path}, Origin: {origin}, "
1034
+ f"Headers: {list(ws_headers.keys())}",
1035
+ file=sys.stderr,
1036
+ flush=True,
1037
+ )
1038
+ except (AttributeError, RuntimeError, TypeError) as access_error:
1039
+ print(
1040
+ f"⚠️ [WEBSOCKET ACCESS ERROR] "
1041
+ f"Could not access websocket properties: {access_error}",
1042
+ file=sys.stderr,
1043
+ flush=True,
1044
+ )
1045
+ except (RuntimeError, AttributeError) as log_error:
1046
+ # Even logging failed - this is very bad
1047
+ print(
1048
+ f"❌ [CRITICAL] Failed to log WebSocket handler call: {log_error}",
1049
+ file=sys.stderr,
1050
+ flush=True,
1051
+ )
1052
+
672
1053
  connection = None
673
1054
  try:
674
1055
  # Log connection attempt with query params (can access before accept)
@@ -693,6 +1074,42 @@ def create_websocket_endpoint(
693
1074
  f"(require_auth={require_auth}, query_params={query_str})"
694
1075
  )
695
1076
 
1077
+ # CRITICAL: Accept FIRST, then validate (FastAPI requires immediate accept)
1078
+ # If we don't accept immediately, FastAPI closes with code 1000
1079
+ # We'll validate origin after accept, but before processing messages
1080
+ try:
1081
+ await websocket.accept()
1082
+ logger.info(
1083
+ f"✅ WebSocket accepted for app '{app_slug}' (before origin validation)"
1084
+ )
1085
+ print(f"✅ [WEBSOCKET ACCEPTED] App: '{app_slug}'", flush=True)
1086
+ except (RuntimeError, ConnectionError, OSError, ValueError) as accept_error:
1087
+ logger.error(
1088
+ f"❌ Failed to accept WebSocket for app '{app_slug}': {accept_error}",
1089
+ exc_info=True,
1090
+ )
1091
+ print(
1092
+ f"❌ [WEBSOCKET ACCEPT FAILED] App: '{app_slug}', Error: {accept_error}",
1093
+ flush=True,
1094
+ )
1095
+ return
1096
+
1097
+ # CRITICAL: Validate origin AFTER accepting (CSWSH protection)
1098
+ # We accept first to satisfy FastAPI's requirements, then validate
1099
+ origin_valid = await _validate_websocket_origin_in_handler(websocket, app_slug)
1100
+ if not origin_valid:
1101
+ logger.error(
1102
+ f"❌ WebSocket origin validation FAILED for app '{app_slug}' - "
1103
+ f"closing connection after accept"
1104
+ )
1105
+ try:
1106
+ await websocket.close(code=1008, reason="Invalid origin")
1107
+ except (RuntimeError, ConnectionError, OSError):
1108
+ # WebSocket may already be closed or in invalid state
1109
+ pass
1110
+ # Raise WebSocketDisconnect so TestClient can detect the rejection
1111
+ raise WebSocketDisconnect(code=1008, reason="Invalid origin")
1112
+
696
1113
  # CRITICAL: Authenticate BEFORE accepting connection
697
1114
  # This prevents CSRF middleware from rejecting established connections
698
1115
  # We can access headers/query_params before accept() is called
@@ -717,14 +1134,17 @@ def create_websocket_endpoint(
717
1134
  f"rejecting connection. require_auth={require_auth}, "
718
1135
  f"user_id={user_id}, user_email={user_email}"
719
1136
  )
720
- # Reject without accepting - FastAPI will send 403 if accept() not called
721
- # We can't call websocket.close() before accept(), so we just return
722
- # The connection will be rejected by the server
723
- return
724
-
725
- # Accept connection
726
- await _accept_websocket_connection(websocket, app_slug)
727
-
1137
+ # Close connection since we already accepted it
1138
+ try:
1139
+ await websocket.close(code=1008, reason="Authentication failed")
1140
+ except (RuntimeError, ConnectionError, OSError):
1141
+ # WebSocket may already be closed or in invalid state
1142
+ pass
1143
+ # Raise WebSocketDisconnect so TestClient can detect the close
1144
+ # WebSocketDisconnect is already imported at module level
1145
+ raise WebSocketDisconnect(code=1008, reason="Authentication failed")
1146
+
1147
+ # Connection already accepted at line 984 - no need to accept again
728
1148
  # Connect with metadata (websocket already accepted)
729
1149
  connection = await manager.connect(websocket, user_id=user_id, user_email=user_email)
730
1150
 
@@ -788,8 +1208,11 @@ def create_websocket_endpoint(
788
1208
  logger.warning(f"WebSocket receive error for app '{app_slug}': {e}")
789
1209
  await asyncio.sleep(0.1)
790
1210
 
1211
+ except WebSocketDisconnect as e:
1212
+ # Re-raise WebSocketDisconnect so TestClient can detect it
1213
+ logger.warning(f"WebSocket connection rejected for app '{app_slug}': {e}")
1214
+ raise
791
1215
  except (
792
- WebSocketDisconnect,
793
1216
  RuntimeError,
794
1217
  OSError,
795
1218
  ValueError,