mdb-engine 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mdb_engine/core/engine.py CHANGED
@@ -148,6 +148,8 @@ class MongoDBEngine:
148
148
  self._service_initializer: ServiceInitializer | None = None
149
149
  self._encryption_service: EnvelopeEncryptionService | None = None
150
150
  self._app_secrets_manager: AppSecretsManager | None = None
151
+ self._websocket_session_manager: Any | None = None # WebSocketSessionManager
152
+ self._websocket_ticket_store: Any | None = None # WebSocketTicketStore
151
153
 
152
154
  # Store app read_scopes mapping for validation
153
155
  self._app_read_scopes: dict[str, list[str]] = {}
@@ -201,6 +203,20 @@ class MongoDBEngine:
201
203
  mongo_db=self._connection_manager.mongo_db,
202
204
  encryption_service=self._encryption_service,
203
205
  )
206
+ # Initialize WebSocket session manager for secure-by-default WebSocket auth
207
+ from ..auth.websocket_sessions import WebSocketSessionManager
208
+
209
+ self._websocket_session_manager = WebSocketSessionManager(
210
+ mongo_db=self._connection_manager.mongo_db,
211
+ encryption_service=self._encryption_service,
212
+ )
213
+
214
+ # Initialize WebSocket ticket store (in-memory, no dependencies needed)
215
+ # Tickets are preferred for multi-app SSO setups (short-lived, single-use)
216
+ from ..auth.websocket_tickets import WebSocketTicketStore
217
+
218
+ self._websocket_ticket_store = WebSocketTicketStore()
219
+ logger.info("WebSocket ticket store initialized")
204
220
 
205
221
  # Set up component managers
206
222
  self._app_registration_manager = AppRegistrationManager(
@@ -257,6 +273,16 @@ class MongoDBEngine:
257
273
  """Check if Ray is enabled and initialized."""
258
274
  return self.enable_ray and self.ray_actor is not None
259
275
 
276
@property
def connection_manager(self):
    """
    Expose the engine's connection manager.

    Returns:
        The object currently held in ``self._connection_manager``
        (documented upstream as a ConnectionManager instance).
    """
    return self._connection_manager
285
+
260
286
  @property
261
287
  def mongo_client(self) -> AsyncIOMotorClient:
262
288
  """
@@ -701,6 +727,25 @@ class MongoDBEngine:
701
727
  app: FastAPI application instance
702
728
  slug: App slug
703
729
  """
730
+ # CRITICAL: Ensure websocket_ticket_store is available
731
+ # Ticket authentication is required for WebSocket connections
732
+ if not self._websocket_ticket_store:
733
+ error_msg = (
734
+ f"WebSocket routes cannot be registered for app '{slug}': "
735
+ "websocket_ticket_store is not available. "
736
+ "WebSocket authentication requires ticket store to be initialized."
737
+ )
738
+ contextual_logger.error(error_msg)
739
+ raise RuntimeError(error_msg)
740
+
741
+ # Ensure ticket store is in app state (may have been set in create_app)
742
+ if (
743
+ not hasattr(app.state, "websocket_ticket_store")
744
+ or app.state.websocket_ticket_store is None
745
+ ):
746
+ app.state.websocket_ticket_store = self._websocket_ticket_store
747
+ contextual_logger.debug(f"WebSocket ticket store stored in app state for '{slug}'")
748
+
704
749
  # Check if WebSockets are configured for this app
705
750
  websockets_config = self.get_websocket_config(slug)
706
751
  if not websockets_config:
@@ -907,9 +952,9 @@ class MongoDBEngine:
907
952
  return get_embedding_service_for_app(slug, self)
908
953
 
909
954
  @property
910
- def _apps(self) -> dict[str, Any]:
955
+ def apps(self) -> dict[str, Any]:
911
956
  """
912
- Get the apps dictionary (for backward compatibility with tests).
957
+ Get all registered apps.
913
958
 
914
959
  Returns:
915
960
  Dictionary of registered apps
@@ -919,7 +964,27 @@ class MongoDBEngine:
919
964
  """
920
965
  if not self._app_registration_manager:
921
966
  raise RuntimeError("MongoDBEngine not initialized. Call initialize() first.")
922
- return self._app_registration_manager._apps
967
+ return self._app_registration_manager.apps
968
+
969
@property
def websocket_ticket_store(self):
    """
    Expose the engine's WebSocket ticket store.

    Returns:
        The object held in ``self._websocket_ticket_store``: a
        WebSocketTicketStore instance, or None if not initialized.
    """
    return self._websocket_ticket_store
978
+
979
@property
def websocket_session_manager(self):
    """
    Expose the engine's WebSocket session manager.

    Returns:
        The object held in ``self._websocket_session_manager``: a
        WebSocketSessionManager instance, or None if not initialized.
    """
    return self._websocket_session_manager
923
988
 
924
989
  def list_apps(self) -> list[str]:
925
990
  """
@@ -1092,6 +1157,73 @@ class MongoDBEngine:
1092
1157
  # FastAPI Integration Methods
1093
1158
  # =========================================================================
1094
1159
 
1160
async def _register_websocket_endpoints(self, app: "FastAPI", engine: "MongoDBEngine") -> None:
    """
    Publish the engine's WebSocket auth helpers on ``app``.

    For each helper the engine exposes (ticket store, session manager) the
    helper is stored on ``app.state`` and its HTTP endpoint is registered on
    ``app``. The ticket store is additionally installed as the global ticket
    store used by the routing layer. A helper that is ``None`` is skipped.

    Intended to run after ``engine.initialize()``, because both helpers are
    only created during initialization.

    Args:
        app: FastAPI application receiving the state entries and routes
        engine: Engine whose ``websocket_ticket_store`` and
            ``websocket_session_manager`` properties are read
    """
    # Read each property once and compare against None explicitly: the
    # properties are documented as "instance or None", and a plain truthiness
    # test would also skip a helper that happens to evaluate as falsy
    # (for example a container-like store that is currently empty).
    ticket_store = engine.websocket_ticket_store
    if ticket_store is not None:
        app.state.websocket_ticket_store = ticket_store
        logger.info("WebSocket ticket store stored in app state")

        # Set global ticket store for WebSocket authentication (works with routers)
        from ..routing.websockets import set_global_websocket_ticket_store

        set_global_websocket_ticket_store(ticket_store)
        logger.info("Global WebSocket ticket store set for multi-app authentication")

        # Register WebSocket ticket endpoint
        from ..auth.websocket_tickets import create_websocket_ticket_endpoint

        ticket_endpoint = create_websocket_ticket_endpoint(ticket_store)
        app.post("/auth/ticket")(ticket_endpoint)
        logger.info("WebSocket ticket endpoint registered at /auth/ticket")

    session_manager = engine.websocket_session_manager
    if session_manager is not None:
        app.state.websocket_session_manager = session_manager
        logger.info("WebSocket session manager stored in app state")

        # Register WebSocket session endpoint
        from ..auth.websocket_sessions import create_websocket_session_endpoint

        session_endpoint = create_websocket_session_endpoint(session_manager)
        app.get("/auth/websocket-session")(session_endpoint)
        logger.info("WebSocket session endpoint registered at /auth/websocket-session")
1193
+
1194
async def _configure_websocket_ticket_ttl(
    self, app: "FastAPI", app_manifest: dict[str, Any], slug: str
) -> None:
    """
    Apply the ticket TTL requested by an app manifest to the ticket store.

    Each endpoint entry under the manifest's ``websockets`` section may carry
    a ``ticket_ttl_seconds`` value. The smallest valid value is selected. If
    no ticket store exists yet, or the current store's ``ticket_ttl`` differs
    from the selected value, a new ``WebSocketTicketStore`` is created with
    that TTL and stored on the engine. ``app.state.websocket_ticket_store`` is
    updated as well, but only when that attribute already exists.

    Values that are not positive numbers are ignored with a warning instead
    of being passed on to ``min()`` and the store constructor.

    Args:
        app: FastAPI application whose state may hold the ticket store
        app_manifest: Loaded manifest dictionary for the app
        slug: App slug, used in log messages only
    """
    websockets_config = app_manifest.get("websockets", {})
    if not websockets_config:
        return
    if not isinstance(websockets_config, dict):
        logger.warning(
            f"Ignoring 'websockets' section of app '{slug}' manifest: "
            f"expected a mapping, got {type(websockets_config).__name__}"
        )
        return

    ticket_ttl_values: list[int] = []
    for endpoint_name, endpoint_config in websockets_config.items():
        if not isinstance(endpoint_config, dict):
            continue
        ticket_ttl = endpoint_config.get("ticket_ttl_seconds")
        if ticket_ttl is None:
            continue
        # bool is a subclass of int, so it has to be rejected explicitly.
        is_number = isinstance(ticket_ttl, (int, float)) and not isinstance(ticket_ttl, bool)
        if not is_number or ticket_ttl <= 0:
            logger.warning(
                f"Ignoring invalid ticket_ttl_seconds={ticket_ttl!r} for WebSocket "
                f"endpoint '{endpoint_name}' of app '{slug}'"
            )
            continue
        ticket_ttl_values.append(ticket_ttl)

    if not ticket_ttl_values:
        return

    configured_ticket_ttl = min(ticket_ttl_values)  # Use minimum for maximum security

    ticket_store = self._websocket_ticket_store
    if ticket_store is not None and ticket_store.ticket_ttl == configured_ticket_ttl:
        # Existing store already uses the requested TTL; nothing to do.
        return

    from ..auth.websocket_tickets import WebSocketTicketStore

    # NOTE(review): this replaces the store object rather than changing its
    # TTL. In create_app, _register_websocket_endpoints runs before this
    # method, so the ticket endpoint and the global ticket store registered
    # there still reference the previous store object - confirm that the
    # issuing endpoint and the validating side end up sharing one store.
    self._websocket_ticket_store = WebSocketTicketStore(ticket_ttl_seconds=configured_ticket_ttl)
    logger.info(
        f"WebSocket ticket store initialized with TTL: "
        f"{configured_ticket_ttl}s (from app '{slug}' manifest)"
    )
    # Update app state if ticket store was already set
    if hasattr(app.state, "websocket_ticket_store"):
        app.state.websocket_ticket_store = self._websocket_ticket_store
1226
+
1095
1227
  def create_app(
1096
1228
  self,
1097
1229
  slug: str,
@@ -1191,8 +1323,15 @@ class MongoDBEngine:
1191
1323
  if not is_sub_app:
1192
1324
  await engine.initialize()
1193
1325
 
1326
+ # Register WebSocket endpoints
1327
+ await self._register_websocket_endpoints(app, engine)
1328
+
1194
1329
  # Load and register manifest
1195
1330
  app_manifest = await engine.load_manifest(manifest_path)
1331
+
1332
+ # Configure WebSocket ticket TTL from manifest
1333
+ await self._configure_websocket_ticket_ttl(app, app_manifest, slug)
1334
+
1196
1335
  await engine.register_app(app_manifest)
1197
1336
 
1198
1337
  # Auto-detect multi-site mode from manifest
@@ -1227,11 +1366,11 @@ class MongoDBEngine:
1227
1366
  f"Sub-app '{slug}' uses shared auth but user_pool not found. "
1228
1367
  "Initializing now (parent should have initialized it)."
1229
1368
  )
1230
- await engine._initialize_shared_user_pool(app, app_manifest)
1369
+ await self._initialize_shared_user_pool(app, app_manifest)
1231
1370
  else:
1232
1371
  logger.debug(f"Sub-app '{slug}' using shared user_pool from parent app")
1233
1372
  else:
1234
- await engine._initialize_shared_user_pool(app, app_manifest)
1373
+ await self._initialize_shared_user_pool(app, app_manifest)
1235
1374
  else:
1236
1375
  logger.info(f"Per-app auth mode for '{slug}'")
1237
1376
  # Auto-retrieve app token for "app" mode
@@ -1497,6 +1636,10 @@ class MongoDBEngine:
1497
1636
  f"(require_role={auth_config.get('require_role')})"
1498
1637
  )
1499
1638
 
1639
+ # NOTE: WebSocket ticket endpoint registration is moved to lifespan context manager
1640
+ # (after engine.initialize()) because ticket store is only available after initialization.
1641
+ # This ensures consistency with create_multi_app() behavior.
1642
+
1500
1643
  # Add CSRF middleware (after auth - auto-enabled for shared mode)
1501
1644
  # CSRF protection is enabled by default for shared auth mode
1502
1645
  # SKIP for sub-apps in multi-app setups - parent app handles CSRF
@@ -1504,8 +1647,17 @@ class MongoDBEngine:
1504
1647
  if csrf_config and not is_sub_app: # Don't add CSRF to child apps
1505
1648
  from ..auth.csrf import create_csrf_middleware
1506
1649
 
1650
+ # Add ticket endpoint to public routes (it handles its own auth)
1651
+ public_routes = auth_config.get("public_routes", [])
1652
+ public_routes_with_ticket = list(public_routes) + ["/auth/ticket"]
1653
+
1654
+ csrf_config_with_routes = {
1655
+ **auth_config,
1656
+ "public_routes": public_routes_with_ticket,
1657
+ }
1658
+
1507
1659
  csrf_middleware = create_csrf_middleware(
1508
- manifest_auth=auth_config,
1660
+ manifest_auth=csrf_config_with_routes,
1509
1661
  )
1510
1662
  app.add_middleware(csrf_middleware)
1511
1663
  logger.info(f"CSRFMiddleware added for '{slug}'")
@@ -2134,6 +2286,7 @@ class MongoDBEngine:
2134
2286
  )
2135
2287
 
2136
2288
  # Check if any app uses shared auth and collect public routes for CSRF exemption
2289
+ # Also collect ticket TTL values from websocket configs
2137
2290
  has_shared_auth = False
2138
2291
  all_public_routes = [
2139
2292
  "/health",
@@ -2141,6 +2294,7 @@ class MongoDBEngine:
2141
2294
  "/openapi.json",
2142
2295
  "/_mdb/routes",
2143
2296
  ] # Base exempt routes
2297
+ ticket_ttl_values: list[int] = [] # Collect ticket TTLs from all apps
2144
2298
  for app_config in apps:
2145
2299
  try:
2146
2300
  manifest_path = app_config["manifest"]
@@ -2160,9 +2314,45 @@ class MongoDBEngine:
2160
2314
  prefixed_route = f"{path_prefix.rstrip('/')}/{route}"
2161
2315
  if prefixed_route not in all_public_routes:
2162
2316
  all_public_routes.append(prefixed_route)
2317
+
2318
+ # Collect ticket TTL from websocket configs
2319
+ websockets_config = app_manifest_pre.get("websockets", {})
2320
+ if websockets_config:
2321
+ for endpoint_config in websockets_config.values():
2322
+ if isinstance(endpoint_config, dict):
2323
+ ticket_ttl = endpoint_config.get("ticket_ttl_seconds")
2324
+ if ticket_ttl is not None:
2325
+ ticket_ttl_values.append(ticket_ttl)
2163
2326
  except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
2164
2327
  logger.warning(f"Could not check auth mode for app '{app_config.get('slug')}': {e}")
2165
2328
 
2329
+ # Determine ticket TTL: use minimum from app configs (most secure), or default
2330
+ from ..auth.websocket_tickets import DEFAULT_TICKET_TTL_SECONDS
2331
+
2332
+ if ticket_ttl_values:
2333
+ configured_ticket_ttl = min(ticket_ttl_values) # Use minimum for maximum security
2334
+ logger.info(
2335
+ f"Ticket TTL configured from app manifests: {configured_ticket_ttl}s "
2336
+ f"(found values: {ticket_ttl_values}, using minimum)"
2337
+ )
2338
+ else:
2339
+ configured_ticket_ttl = DEFAULT_TICKET_TTL_SECONDS
2340
+ logger.debug(
2341
+ f"No ticket TTL specified in app manifests, using default: {configured_ticket_ttl}s"
2342
+ )
2343
+
2344
+ # Reinitialize ticket store with configured TTL if different from current
2345
+ if (
2346
+ self._websocket_ticket_store is None
2347
+ or self._websocket_ticket_store.ticket_ttl != configured_ticket_ttl
2348
+ ):
2349
+ from ..auth.websocket_tickets import WebSocketTicketStore
2350
+
2351
+ self._websocket_ticket_store = WebSocketTicketStore(
2352
+ ticket_ttl_seconds=configured_ticket_ttl
2353
+ )
2354
+ logger.info(f"WebSocket ticket store initialized with TTL: {configured_ticket_ttl}s")
2355
+
2166
2356
  # Validate hooks before creating lifespan (fail fast)
2167
2357
  for app_config in apps:
2168
2358
  slug = app_config.get("slug", "unknown")
@@ -2283,6 +2473,50 @@ class MongoDBEngine:
2283
2473
  logger.debug(f"No WebSocket configuration found for app '{slug}'")
2284
2474
  return
2285
2475
 
2476
+ # CRITICAL: Check if session manager is required and available
2477
+ # Some endpoints require session keys (csrf_required=True), which need
2478
+ # session manager
2479
+ requires_session_manager = False
2480
+ for _endpoint_name, endpoint_config in websockets_config.items():
2481
+ auth_config = endpoint_config.get("auth", {})
2482
+ if isinstance(auth_config, dict):
2483
+ csrf_required = auth_config.get("csrf_required", True) # Default to True
2484
+ if csrf_required:
2485
+ requires_session_manager = True
2486
+ break
2487
+
2488
+ if requires_session_manager and not engine.websocket_session_manager:
2489
+ error_msg = (
2490
+ f"WebSocket routes cannot be registered for app '{slug}': "
2491
+ "websocket_session_manager is not available. "
2492
+ "WebSocket endpoints with csrf_required=True require "
2493
+ "session manager to be initialized. "
2494
+ "Set MDB_ENGINE_MASTER_KEY environment variable to enable "
2495
+ "session manager."
2496
+ )
2497
+ logger.error(error_msg)
2498
+ raise RuntimeError(error_msg)
2499
+
2500
+ # CRITICAL: Ensure websocket_ticket_store is available
2501
+ # Ticket authentication is required for WebSocket connections
2502
+ if not engine.websocket_ticket_store:
2503
+ error_msg = (
2504
+ f"WebSocket routes cannot be registered for app '{slug}': "
2505
+ "websocket_ticket_store is not available. "
2506
+ "WebSocket authentication requires ticket store to be initialized."
2507
+ )
2508
+ logger.error(error_msg)
2509
+ raise RuntimeError(error_msg)
2510
+
2511
+ # Store WebSocket config in parent app state for CSRF middleware to access
2512
+ if not hasattr(parent_app.state, "websocket_configs"):
2513
+ parent_app.state.websocket_configs = {}
2514
+ parent_app.state.websocket_configs[slug] = websockets_config
2515
+ logger.info(
2516
+ f"✅ Stored WebSocket config for '{slug}' in parent app state "
2517
+ f"({len(websockets_config)} endpoint(s))"
2518
+ )
2519
+
2286
2520
  try:
2287
2521
  from fastapi import APIRouter
2288
2522
 
@@ -2322,16 +2556,51 @@ class MongoDBEngine:
2322
2556
  ping_interval=ping_interval,
2323
2557
  )
2324
2558
 
2325
- # Register on parent app with full path
2326
- ws_router = APIRouter()
2327
- ws_router.websocket(full_ws_path)(handler)
2328
- parent_app.include_router(ws_router)
2559
+ # Register on parent app with full path using FastAPI's
2560
+ # proper WebSocket registration
2561
+ # We register BEFORE mounting apps to ensure WebSocket
2562
+ # routes are checked first
2563
+ try:
2564
+ # Use FastAPI's APIRouter approach (same as single-app mode)
2565
+ # This maintains FastAPI features
2566
+ # (dependency injection, OpenAPI docs, etc.)
2567
+ ws_router = APIRouter()
2568
+ ws_router.websocket(full_ws_path)(handler)
2569
+
2570
+ # Include router BEFORE mounting child app to ensure route priority
2571
+ parent_app.include_router(ws_router)
2572
+
2573
+ logger.info(
2574
+ f"✅ Registered WebSocket route '{full_ws_path}' "
2575
+ f"using FastAPI APIRouter "
2576
+ f"(registered before app mount to ensure priority)"
2577
+ )
2578
+ except (
2579
+ ValueError,
2580
+ RuntimeError,
2581
+ AttributeError,
2582
+ TypeError,
2583
+ ) as fastapi_error:
2584
+ logger.error(
2585
+ f"❌ Failed to register WebSocket route "
2586
+ f"'{full_ws_path}' with FastAPI: {fastapi_error}",
2587
+ exc_info=True,
2588
+ )
2589
+ raise
2329
2590
 
2330
2591
  logger.info(
2331
2592
  f"✅ Registered WebSocket route '{full_ws_path}' "
2332
2593
  f"for mounted app '{slug}' (mounted at '{path_prefix}', "
2333
2594
  f"auth: {require_auth}, ping: {ping_interval}s)"
2334
2595
  )
2596
+ import sys
2597
+
2598
+ print(
2599
+ f"✅ [ROUTE REGISTRATION] WebSocket route '{full_ws_path}' "
2600
+ f"registered for '{slug}' using FastAPI APIRouter",
2601
+ file=sys.stderr,
2602
+ flush=True,
2603
+ )
2335
2604
 
2336
2605
  # Verify route was actually registered
2337
2606
  registered_routes = [
@@ -2339,6 +2608,26 @@ class MongoDBEngine:
2339
2608
  for r in parent_app.routes
2340
2609
  if hasattr(r, "path") and full_ws_path in str(getattr(r, "path", ""))
2341
2610
  ]
2611
+
2612
+ # CRITICAL: Log all WebSocket routes to verify registration
2613
+ # FastAPI APIRouter creates routes of type 'APIWebSocketRoute'
2614
+ all_ws_routes = [
2615
+ (r.path, type(r).__name__)
2616
+ for r in parent_app.routes
2617
+ if hasattr(r, "path")
2618
+ and ("ws" in str(r.path).lower() or hasattr(r, "endpoint"))
2619
+ ]
2620
+ import sys
2621
+
2622
+ print(
2623
+ f"📋 [ROUTE VERIFICATION] All WebSocket-like routes: {all_ws_routes}",
2624
+ file=sys.stderr,
2625
+ flush=True,
2626
+ )
2627
+ logger.info(
2628
+ f"📋 [ROUTE VERIFICATION] All WebSocket-like routes: {all_ws_routes}"
2629
+ )
2630
+
2342
2631
  if registered_routes:
2343
2632
  registered_count += 1
2344
2633
  logger.debug(
@@ -2390,6 +2679,40 @@ class MongoDBEngine:
2390
2679
  # Initialize engine
2391
2680
  await engine.initialize()
2392
2681
 
2682
+ # Register WebSocket ticket endpoint AFTER initialization
2683
+ # (ticket store is now available)
2684
+ if engine.websocket_ticket_store:
2685
+ app.state.websocket_ticket_store = engine.websocket_ticket_store
2686
+ logger.info("WebSocket ticket store stored in parent app state")
2687
+
2688
+ # Set global ticket store for WebSocket authentication (works with routers)
2689
+ from ..routing.websockets import set_global_websocket_ticket_store
2690
+
2691
+ set_global_websocket_ticket_store(engine.websocket_ticket_store)
2692
+ logger.info("Global WebSocket ticket store set for multi-app authentication")
2693
+
2694
+ # Register WebSocket ticket endpoint on parent app
2695
+ from ..auth.websocket_tickets import create_websocket_ticket_endpoint
2696
+
2697
+ ticket_endpoint = create_websocket_ticket_endpoint(engine.websocket_ticket_store)
2698
+ app.post("/auth/ticket")(ticket_endpoint)
2699
+ logger.info("WebSocket ticket endpoint registered at /auth/ticket")
2700
+
2701
+ # Register WebSocket session endpoint AFTER initialization
2702
+ # (session manager is now available)
2703
+ if engine.websocket_session_manager:
2704
+ app.state.websocket_session_manager = engine.websocket_session_manager
2705
+ logger.info("WebSocket session manager stored in parent app state")
2706
+
2707
+ # Register WebSocket session endpoint on parent app
2708
+ from ..auth.websocket_sessions import create_websocket_session_endpoint
2709
+
2710
+ session_endpoint = create_websocket_session_endpoint(
2711
+ engine.websocket_session_manager
2712
+ )
2713
+ app.get("/auth/websocket-session")(session_endpoint)
2714
+ logger.info("WebSocket session endpoint registered at /auth/websocket-session")
2715
+
2393
2716
  # Initialize shared user pool once if any app uses shared auth
2394
2717
  if has_shared_auth:
2395
2718
  logger.info("Initializing shared user pool for multi-app deployment")
@@ -2401,7 +2724,7 @@ class MongoDBEngine:
2401
2724
  app_manifest_pre = json.load(f)
2402
2725
  auth_config = app_manifest_pre.get("auth", {})
2403
2726
  if auth_config.get("mode") == "shared":
2404
- await engine._initialize_shared_user_pool(app, app_manifest_pre)
2727
+ await self._initialize_shared_user_pool(app, app_manifest_pre)
2405
2728
  shared_user_pool_initialized = True
2406
2729
  logger.info("Shared user pool initialized for multi-app deployment")
2407
2730
  break
@@ -2489,6 +2812,18 @@ class MongoDBEngine:
2489
2812
  child_app.state.audit_log = app.state.audit_log
2490
2813
  logger.debug(f"Shared user_pool with child app '{slug}'")
2491
2814
 
2815
+ # Share WebSocket session manager with child app
2816
+ if hasattr(app.state, "websocket_session_manager"):
2817
+ child_app.state.websocket_session_manager = (
2818
+ app.state.websocket_session_manager
2819
+ )
2820
+ logger.debug(f"Shared WebSocket session manager with child app '{slug}'")
2821
+
2822
+ # Share WebSocket ticket store with child app
2823
+ if hasattr(app.state, "websocket_ticket_store"):
2824
+ child_app.state.websocket_ticket_store = app.state.websocket_ticket_store
2825
+ logger.debug(f"Shared WebSocket ticket store with child app '{slug}'")
2826
+
2492
2827
  # Add middleware for app context helpers
2493
2828
  from starlette.middleware.base import BaseHTTPMiddleware
2494
2829
  from starlette.requests import Request
@@ -2576,15 +2911,17 @@ class MongoDBEngine:
2576
2911
  child_app.add_middleware(middleware_class)
2577
2912
  logger.debug(f"Added AppContextMiddleware to child app '{slug}'")
2578
2913
 
2579
- # Mount child app at path prefix
2914
+ # CRITICAL FIX: Register WebSocket routes on parent app BEFORE mounting
2915
+ # This ensures WebSocket routes are checked before mounted app routes
2916
+ # Mounted apps create catch-all routes that intercept /app-slug/* paths
2917
+ await _register_websocket_routes(app, app_manifest_data, slug, path_prefix)
2918
+
2919
+ # Mount child app at path prefix (AFTER WebSocket routes are registered)
2580
2920
  app.mount(path_prefix, child_app)
2581
2921
 
2582
2922
  # CRITICAL FIX: Merge CORS config from child app to parent app
2583
2923
  await _merge_cors_config_to_parent(app, child_app, app_manifest_data, slug)
2584
2924
 
2585
- # CRITICAL FIX: Register WebSocket routes on parent app with full path
2586
- await _register_websocket_routes(app, app_manifest_data, slug, path_prefix)
2587
-
2588
2925
  # Update existing entry instead of appending
2589
2926
  entry = _find_mounted_app_entry(slug)
2590
2927
  if entry:
@@ -2713,6 +3050,11 @@ class MongoDBEngine:
2713
3050
  "manifest_path": str(manifest_path),
2714
3051
  }
2715
3052
  )
3053
+ # Always re-raise RuntimeError for critical failures
3054
+ # (like missing session manager)
3055
+ # These are configuration errors that should fail fast
3056
+ if isinstance(e, RuntimeError) and "websocket_session_manager" in str(e):
3057
+ raise RuntimeError(error_msg) from e
2716
3058
  if strict:
2717
3059
  raise RuntimeError(error_msg) from e
2718
3060
  continue
@@ -2806,11 +3148,78 @@ class MongoDBEngine:
2806
3148
  logger.debug("Set default CORS config on parent app for WebSocket origin validation")
2807
3149
 
2808
3150
  # Store app reference in engine for get_mounted_apps()
2809
- engine._multi_app_instance = parent_app
3151
+ self._multi_app_instance = parent_app
2810
3152
 
2811
- # Add request scope middleware
3153
+ # Add diagnostic ASGI middleware FIRST (outermost - runs before everything)
3154
+ # This will catch WebSocket upgrades before any other middleware
2812
3155
  from starlette.middleware.base import BaseHTTPMiddleware
2813
3156
 
3157
class DiagnosticMiddleware(BaseHTTPMiddleware):
    """
    Diagnostic middleware that logs WebSocket-looking requests and their outcome.

    A request counts as WebSocket-looking when its ``Upgrade`` header equals
    ``websocket`` or its path contains ``websocket`` (case-insensitive).
    All other requests pass through without any logging.

    NOTE(review): Starlette's BaseHTTPMiddleware hands non-HTTP scopes straight
    to the wrapped app, so genuine WebSocket connections are not expected to
    reach ``dispatch`` - confirm this middleware observes the traffic it is
    meant to diagnose.
    """

    async def dispatch(self, request, call_next):
        """
        Forward the request, logging it when it looks like a WebSocket upgrade.

        Args:
            request: Incoming Starlette request
            call_next: Callable that forwards the request down the stack

        Returns:
            The response produced by ``call_next``

        Raises:
            Whatever ``call_next`` raises; RuntimeError, ConnectionError,
            ValueError and AttributeError are logged first, then re-raised.
        """
        path = request.url.path
        method = request.method
        upgrade_header = request.headers.get("upgrade", "").lower()

        # Evaluate the predicate once instead of before, after and on error.
        looks_like_websocket = upgrade_header == "websocket" or "websocket" in path.lower()
        if not looks_like_websocket:
            return await call_next(request)

        connection_header = request.headers.get("connection", "").lower()
        origin_header = request.headers.get("origin")

        # Diagnostics go through the logger only; unconditional print() to
        # stderr with flush=True on every matching request bypasses the
        # application's logging configuration.
        logger.info(
            "🔬 [DIAGNOSTIC MIDDLEWARE] WebSocket upgrade: %s %s, origin=%s",
            method,
            path,
            origin_header,
        )
        logger.debug(
            "🔬 [DIAGNOSTIC MIDDLEWARE] WebSocket upgrade detected: "
            "%s %s, upgrade=%s, connection=%s, origin=%s",
            method,
            path,
            upgrade_header,
            connection_header,
            origin_header,
        )

        try:
            response = await call_next(request)
        except (RuntimeError, ConnectionError, ValueError, AttributeError) as e:
            logger.error(
                "🔬 [DIAGNOSTIC MIDDLEWARE] WebSocket exception: %s %s -> %s: %s",
                method,
                path,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise

        logger.info(
            "🔬 [DIAGNOSTIC MIDDLEWARE] WebSocket response: %s %s -> %s",
            method,
            path,
            response.status_code,
        )
        return response
3218
+
3219
+ parent_app.add_middleware(DiagnosticMiddleware)
3220
+ logger.debug("DiagnosticMiddleware added for parent app (outermost layer)")
3221
+
3222
+ # Add request scope middleware
2814
3223
  from ..di import ScopeManager
2815
3224
 
2816
3225
  class RequestScopeMiddleware(BaseHTTPMiddleware):
@@ -2834,14 +3243,28 @@ class MongoDBEngine:
2834
3243
  from ..auth.csrf import create_csrf_middleware
2835
3244
 
2836
3245
  # Create CSRF middleware with default config (will use parent app's CORS config)
2837
- # Exempt routes that don't need CSRF (health checks, public routes from child apps)
2838
- # all_public_routes includes base routes + child app public routes with path prefixes
3246
+ # Exempt routes that don't need CSRF (health checks, public routes
3247
+ # from child apps)
3248
+ # all_public_routes includes base routes + child app public routes
3249
+ # with path prefixes
3250
+ # Add WebSocket session and ticket endpoints to public routes
3251
+ # (they handle their own auth)
3252
+ public_routes_with_websocket_endpoints = list(all_public_routes) + [
3253
+ "/auth/websocket-session",
3254
+ "/auth/ticket",
3255
+ ]
2839
3256
  parent_csrf_config = {
2840
3257
  "csrf_protection": True,
2841
- "public_routes": all_public_routes,
3258
+ "public_routes": public_routes_with_websocket_endpoints,
2842
3259
  }
2843
3260
  csrf_middleware = create_csrf_middleware(parent_csrf_config)
2844
3261
  parent_app.add_middleware(csrf_middleware)
3262
+
3263
+ # NOTE: WebSocket ticket and session endpoint registrations are moved to lifespan
3264
+ # context manager (after engine.initialize()) because they're only available after
3265
+ # initialization. The CSRF middleware still needs to know about these routes, so
3266
+ # they're added to public_routes_with_websocket_endpoints above.
3267
+
2845
3268
  logger.info("CSRFMiddleware added to parent app for WebSocket origin validation")
2846
3269
 
2847
3270
  # Add shared CORS middleware if configured
@@ -2865,6 +3288,21 @@ class MongoDBEngine:
2865
3288
  # Read CORS config from app.state (may have been merged from child apps)
2866
3289
  cors_config = getattr(request.app.state, "cors_config", {})
2867
3290
 
3291
+ # CRITICAL: Log WebSocket upgrade requests to see if CORS
3292
+ # is intercepting them
3293
+ upgrade_header = request.headers.get("upgrade", "").lower()
3294
+ if upgrade_header == "websocket" or "websocket" in request.url.path.lower():
3295
+ import sys
3296
+
3297
+ print(
3298
+ f"🌐 [CORS MIDDLEWARE] WebSocket upgrade: "
3299
+ f"{request.method} {request.url.path}, "
3300
+ f"origin={request.headers.get('origin')}, "
3301
+ f"cors_enabled={cors_config.get('enabled', False)}",
3302
+ file=sys.stderr,
3303
+ flush=True,
3304
+ )
3305
+
2868
3306
  if not cors_config.get("enabled", False):
2869
3307
  # CORS not enabled, pass through
2870
3308
  return await call_next(request)
@@ -2941,6 +3379,46 @@ class MongoDBEngine:
2941
3379
  except ImportError:
2942
3380
  logger.warning("CORS middleware not available")
2943
3381
 
3382
+ # Wrap parent app in ASGI wrapper to intercept WebSocket connections at ASGI level
3383
+ # This must be done AFTER all middleware and routes are registered
3384
+ from starlette.types import ASGIApp, Receive, Scope, Send
3385
+
3386
class WebSocketASGIWrapper:
    """
    ASGI wrapper that logs WebSocket connections before delegating to the app.

    Every scope is forwarded unchanged to the wrapped application. Attribute
    reads that the wrapper cannot satisfy itself are delegated to the wrapped
    application at access time.
    """

    def __init__(self, app: "ASGIApp"):
        """
        Args:
            app: The ASGI application to wrap
        """
        # Only the reference is stored. Copying app.__dict__ into the wrapper
        # would freeze a snapshot that shadows __getattr__ and hides later
        # changes made on the wrapped application.
        self.app = app

    def __getattr__(self, name):
        """Delegate any attribute missing on the wrapper to the wrapped app."""
        # __getattr__ only runs when normal lookup fails. If "app" itself is
        # missing (e.g. the instance was created without __init__), looking up
        # self.app here would recurse without end.
        if name == "app":
            raise AttributeError(name)
        return getattr(self.app, name)

    async def __call__(self, scope: "Scope", receive: "Receive", send: "Send"):
        """
        Log WebSocket scopes, then call the wrapped application.

        Args:
            scope: ASGI connection scope
            receive: ASGI receive callable
            send: ASGI send callable
        """
        if scope["type"] == "websocket":
            path = scope.get("path", "unknown")
            # Header names/values arrive as raw bytes. latin-1 maps every byte
            # to a character, so logging can never fail with
            # UnicodeDecodeError on an unusual header value.
            headers_dict = {
                k.decode("latin-1"): v.decode("latin-1") for k, v in scope.get("headers", [])
            }
            origin = headers_dict.get("origin", "missing")
            upgrade = headers_dict.get("upgrade", "missing")
            connection = headers_dict.get("connection", "missing")
            logger.info("🌐 [ASGI WEBSOCKET] Intercepted: path=%s, origin=%s", path, origin)
            logger.debug(
                "🌐 [ASGI WEBSOCKET] Intercepted at ASGI level: path=%s, "
                "origin=%s, upgrade=%s, connection=%s",
                path,
                origin,
                upgrade,
                connection,
            )

        # Call the actual app
        await self.app(scope, receive, send)
3418
+
3419
+ # Wrap the app (but keep reference to original for internal use)
3420
+ WebSocketASGIWrapper(parent_app) # Wrapped for WebSocket support
3421
+
2944
3422
  # Add unified health check endpoint
2945
3423
  @parent_app.get("/health")
2946
3424
  async def health_check():
@@ -3227,7 +3705,163 @@ class MongoDBEngine:
3227
3705
 
3228
3706
  logger.info(f"Multi-app parent created with {len(apps)} app(s) configured")
3229
3707
 
3230
- return parent_app
3708
+ # CRITICAL: Wrap the FastAPI app in an ASGI wrapper to intercept WebSocket connections
3709
+ # BEFORE FastAPI's routing handles them. This will catch rejections at the framework level.
3710
+ from starlette.types import ASGIApp, Receive, Scope, Send
3711
+
3712
class WebSocketASGIInterceptor:
    """ASGI wrapper that traces WebSocket traffic before FastAPI routing.

    Non-WebSocket scopes are forwarded to the wrapped application untouched.
    For ``websocket`` scopes the handshake headers are logged and the
    ``receive`` / ``send`` callables are wrapped so every ASGI message passing
    in either direction is logged before being forwarded unchanged.

    The wrapper presents itself as the wrapped application: missing attributes
    and ``state`` are delegated at lookup time, and ``__class__`` reports the
    wrapped application's type so ``isinstance`` checks against that type
    succeed.

    Args:
        app: The ASGI application to wrap.
    """

    def __init__(self, app: "ASGIApp"):
        # Only the reference is stored. Copying ``app.__dict__`` onto the
        # wrapper would freeze a snapshot that goes stale whenever the
        # application rebinds an attribute; ``__getattr__`` delegates live.
        self.app = app

    @property
    def state(self):
        """Return the wrapped application's ``state`` (always current)."""
        return self.app.state

    def __getattr__(self, name):
        """Delegate attributes missing on the wrapper to the wrapped app."""
        # ``__getattr__`` only runs when normal lookup failed. If ``app``
        # itself is missing (instance created without ``__init__``, e.g. by
        # copy/pickle), looking it up again would recurse without bound.
        if name == "app":
            raise AttributeError(name)
        return getattr(self.app, name)

    @property
    def __class__(self):
        # Make isinstance() checks work by returning the underlying app's class
        return type(self.app)

    @staticmethod
    def _decode_headers(scope):
        """Return the scope's headers as a ``str -> str`` dict.

        Header bytes are decoded as latin-1, which maps every byte value, so
        a malformed header sent by a client cannot raise
        ``UnicodeDecodeError`` and abort the connection inside this
        tracing-only code path.
        """
        return {
            key.decode("latin-1"): value.decode("latin-1")
            for key, value in scope.get("headers", [])
        }

    async def __call__(self, scope: "Scope", receive: "Receive", send: "Send"):
        """Forward the ASGI call, tracing it when the scope is a WebSocket.

        Raises:
            Whatever the wrapped application raises. ``RuntimeError``,
            ``ConnectionError``, ``OSError``, ``ValueError`` and
            ``AttributeError`` are logged with a traceback and re-raised.
        """
        if scope["type"] != "websocket":
            # Non-WebSocket requests pass through normally
            await self.app(scope, receive, send)
            return

        import logging

        path = scope.get("path", "unknown")
        headers = self._decode_headers(scope)
        origin = headers.get("origin", "missing")

        # Verbose tracing goes through the logger at DEBUG level instead of
        # flushed stderr prints on every message.
        logger.debug(
            "🌐 [ASGI INTERCEPTOR] WebSocket connection at ASGI level: "
            "path=%s, origin=%s, upgrade=%s, connection=%s",
            path,
            origin,
            headers.get("upgrade", "missing"),
            headers.get("connection", "missing"),
        )
        logger.info(f"🌐 [ASGI INTERCEPTOR] WebSocket: path={path}, origin={origin}")

        async def intercepted_send(message):
            """Log an outgoing ASGI message, then forward it unchanged."""
            msg_type = message.get("type", "unknown")
            logger.debug(
                "🌐 [ASGI INTERCEPTOR] Message type: %s, keys: %s",
                msg_type,
                list(message.keys()),
            )

            if msg_type == "websocket.close":
                logger.warning(
                    f"🌐 [ASGI INTERCEPTOR] WebSocket closed: "
                    f"code={message.get('code')}, reason={message.get('reason')}"
                )
            elif msg_type == "websocket.accept":
                logger.info("🌐 [ASGI INTERCEPTOR] WebSocket ACCEPTED!")
            elif msg_type == "websocket.http.response.start":
                status = message.get("status", "unknown")
                logger.warning(f"🌐 [ASGI INTERCEPTOR] HTTP response: status={status}")

            await send(message)

        async def intercepted_receive():
            """Await the next incoming ASGI message, log it and return it."""
            msg = await receive()
            logger.debug(
                "🌐 [ASGI INTERCEPTOR] Received message: type=%s, keys: %s",
                msg.get("type", "unknown"),
                list(msg.keys()),
            )
            return msg

        # The route listing exists only for diagnostics, so it is skipped
        # unless DEBUG output is actually enabled.
        if logger.isEnabledFor(logging.DEBUG) and hasattr(self.app, "routes"):
            # Heuristic: any route whose path contains "ws" is listed.
            ws_routes = [
                r for r in self.app.routes if hasattr(r, "path") and "ws" in str(r.path)
            ]
            logger.debug(
                "🌐 [ASGI INTERCEPTOR] Found %d WebSocket route(s): %s",
                len(ws_routes),
                [r.path for r in ws_routes],
            )

        try:
            await self.app(scope, intercepted_receive, intercepted_send)
        except (
            RuntimeError,
            ConnectionError,
            OSError,
            ValueError,
            AttributeError,
        ) as e:
            logger.error(
                f"🌐 [ASGI INTERCEPTOR] Exception: {type(e).__name__}: {e}",
                exc_info=True,
            )
            raise
3861
+
3862
+ # Re-enable ASGI interceptor to debug WebSocket connection issues
3863
+ # This will show us exactly what's happening at the ASGI level
3864
+ return WebSocketASGIInterceptor(parent_app)
3231
3865
 
3232
3866
  def get_mounted_apps(self, app: Optional["FastAPI"] = None) -> list[dict[str, Any]]:
3233
3867
  """
@@ -3320,6 +3954,7 @@ class MongoDBEngine:
3320
3954
  self._shared_user_pool = SharedUserPool(
3321
3955
  self._connection_manager.mongo_db,
3322
3956
  allow_insecure_dev=is_dev,
3957
+ websocket_session_manager=self._websocket_session_manager,
3323
3958
  )
3324
3959
  await self._shared_user_pool.ensure_indexes()
3325
3960
  logger.info("SharedUserPool initialized")