mdb-engine 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mdb_engine/auth/csrf.py CHANGED
@@ -205,6 +205,19 @@ class CSRFMiddleware(BaseHTTPMiddleware):
205
205
  return True
206
206
  return False
207
207
 
208
+ def _validate_csrf_token(self, token: str, request: Request) -> bool:
209
+ """
210
+ Validate a CSRF token using the middleware's secret and TTL.
211
+
212
+ Args:
213
+ token: CSRF token to validate
214
+ request: FastAPI request (unused, kept for API consistency)
215
+
216
+ Returns:
217
+ True if token is valid, False otherwise
218
+ """
219
+ return validate_csrf_token(token, self.secret, self.token_ttl)
220
+
208
221
  def _websocket_requires_csrf(self, request: Request, path: str) -> bool:
209
222
  """
210
223
  Check if WebSocket endpoint requires CSRF validation.
@@ -214,63 +227,133 @@ class CSRFMiddleware(BaseHTTPMiddleware):
214
227
 
215
228
  Args:
216
229
  request: FastAPI request
217
- path: WebSocket path (e.g., "/app-3/ws")
230
+ path: WebSocket path (e.g., "/chat-app/ws")
218
231
 
219
232
  Returns:
220
233
  True if CSRF validation is required, False otherwise
221
234
  """
222
- # Check parent app state for WebSocket configs
235
+ # Try parent app first (where websocket_configs should be stored)
223
236
  websocket_configs = getattr(request.app.state, "websocket_configs", None)
237
+
238
+ # If not found, try to traverse up to find parent app
239
+ if not websocket_configs:
240
+ logger.debug(f"No websocket_configs found on request.app.state for path '{path}'")
241
+ app = request.app
242
+ apps_checked = []
243
+ while app:
244
+ app_title = getattr(app, "title", "unknown")
245
+ apps_checked.append(app_title)
246
+ websocket_configs = getattr(app.state, "websocket_configs", None)
247
+ if websocket_configs:
248
+ logger.debug(
249
+ f"Found websocket_configs on app '{app_title}' "
250
+ f"(checked: {apps_checked})"
251
+ )
252
+ break
253
+ parent_app = getattr(app, "app", None)
254
+ if parent_app is app: # Prevent infinite loop
255
+ break
256
+ app = parent_app
257
+
224
258
  if not websocket_configs:
225
259
  # No WebSocket configs found - use default (CSRF required for security by default)
260
+ logger.debug(
261
+ f"No websocket_configs found anywhere for path '{path}' - "
262
+ f"using default csrf_required=true"
263
+ )
226
264
  return True
227
265
 
228
- # Normalize path for matching
266
+ # Normalize path for matching (handle trailing slashes)
229
267
  normalized_path = path.rstrip("/")
268
+ logger.debug(
269
+ f"Checking CSRF requirement for path '{normalized_path}' "
270
+ f"against {len(websocket_configs)} app config(s)"
271
+ )
230
272
 
231
273
  # Try to find matching app config
232
274
  # WebSocket paths are registered as /app-slug/endpoint-path
233
- # e.g., /app-3/ws where app_slug="app-3" and endpoint_path="/ws"
275
+ # e.g., /chat-app/ws where app_slug="chat-app" and endpoint_path="/ws"
234
276
  for app_slug, config in websocket_configs.items():
277
+ logger.debug(f"Checking app '{app_slug}' config with {len(config)} endpoint(s)")
235
278
  # Check each endpoint in this app's config
236
279
  for endpoint_name, endpoint_config in config.items():
237
280
  endpoint_path = endpoint_config.get("path", "")
238
281
  # Normalize endpoint path
239
282
  normalized_endpoint = endpoint_path.rstrip("/")
240
283
 
284
+ # Build expected full path: /app-slug/endpoint-path
285
+ if normalized_endpoint.startswith("/"):
286
+ expected_full_path = f"/{app_slug}{normalized_endpoint}"
287
+ else:
288
+ expected_full_path = f"/{app_slug}/{normalized_endpoint}"
289
+
241
290
  # Match patterns:
242
- # 1. Full path match: /app-slug/endpoint-path
243
- # 2. Endpoint-only match: /endpoint-path (if path starts with endpoint)
244
- expected_full_path = f"/{app_slug}{normalized_endpoint}"
245
- if (
291
+ # 1. Full path match: /chat-app/ws == /chat-app/ws
292
+ # 2. Endpoint-only match: /ws (if path ends with endpoint)
293
+ # 3. Path contains endpoint: /chat-app/ws contains /ws
294
+ matches = (
246
295
  normalized_path == expected_full_path
247
- or normalized_path.endswith(normalized_endpoint)
248
296
  or normalized_path == normalized_endpoint
249
- ):
297
+ or normalized_path.endswith(normalized_endpoint)
298
+ or normalized_path.endswith(f"/{app_slug}{normalized_endpoint}")
299
+ )
300
+
301
+ if matches:
250
302
  auth_config = endpoint_config.get("auth", {})
251
303
  if isinstance(auth_config, dict):
252
304
  # Return csrf_required setting (defaults to True - security by default)
253
305
  csrf_required = auth_config.get("csrf_required", True)
254
- logger.debug(
255
- f"WebSocket {path} csrf_required={csrf_required} "
256
- f"(from app={app_slug}, endpoint={endpoint_name})"
306
+ logger.info(
307
+ f"WebSocket '{normalized_path}' csrf_required={csrf_required} "
308
+ f"(from app='{app_slug}', endpoint='{endpoint_name}', "
309
+ f"endpoint_path='{normalized_endpoint}')"
257
310
  )
258
311
  return csrf_required
312
+ else:
313
+ logger.debug(
314
+ f"WebSocket '{normalized_path}' auth_config is not a dict: "
315
+ f"{type(auth_config)}"
316
+ )
259
317
 
260
318
  # No matching config found - use default (CSRF required for security by default)
261
- logger.debug(f"No WebSocket config match for {path}, using default csrf_required=true")
319
+ logger.debug(
320
+ f"❌ No WebSocket config match for '{normalized_path}' "
321
+ f"(checked {len(websocket_configs)} app(s)) - using default csrf_required=true"
322
+ )
262
323
  return True
263
324
 
264
325
  def _is_websocket_upgrade(self, request: Request) -> bool:
265
326
  """Check if request is a WebSocket upgrade request."""
266
327
  upgrade_header = request.headers.get("upgrade", "").lower()
267
328
  connection_header = request.headers.get("connection", "").lower()
329
+ path = request.url.path
330
+
331
+ # CRITICAL: Enhanced logging for WebSocket detection
332
+ import sys
333
+
334
+ print(
335
+ f"🔍 [_is_websocket_upgrade] Path: {path}, "
336
+ f"upgrade='{upgrade_header}', connection='{connection_header}'",
337
+ file=sys.stderr,
338
+ flush=True,
339
+ )
340
+ logger.info(
341
+ f"_is_websocket_upgrade check: upgrade='{upgrade_header}', "
342
+ f"connection='{connection_header}', path='{path}'"
343
+ )
268
344
 
269
345
  # Primary check: WebSocket upgrade requires both Upgrade: websocket
270
346
  # and Connection: Upgrade headers
271
347
  has_upgrade_header = upgrade_header == "websocket"
272
348
  has_connection_upgrade = "upgrade" in connection_header or "websocket" in connection_header
273
349
 
350
+ print(
351
+ f"🔍 [_is_websocket_upgrade] has_upgrade={has_upgrade_header}, "
352
+ f"has_connection={has_connection_upgrade}",
353
+ file=sys.stderr,
354
+ flush=True,
355
+ )
356
+
274
357
  # Secondary check: If upgrade header is present but connection is
275
358
  # overridden (e.g., by TestClient), check if path matches a known
276
359
  # WebSocket route pattern
@@ -308,12 +391,31 @@ class CSRFMiddleware(BaseHTTPMiddleware):
308
391
  is_websocket = has_upgrade_header and (
309
392
  has_connection_upgrade or path_matches_websocket_route
310
393
  )
394
+
395
+ # CRITICAL: Enhanced logging
396
+ import sys
397
+
311
398
  if is_websocket:
312
- logger.debug(
313
- f"WebSocket upgrade detected: path={request.url.path}, "
399
+ print(
400
+ f"✅ [_is_websocket_upgrade] WebSocket detected: path={path}, "
401
+ f"upgrade={upgrade_header}, connection={connection_header}, "
402
+ f"path_match={path_matches_websocket_route}, result={is_websocket}",
403
+ file=sys.stderr,
404
+ flush=True,
405
+ )
406
+ logger.info(
407
+ f"WebSocket upgrade detected: path={path}, "
314
408
  f"upgrade={upgrade_header}, connection={connection_header}, "
315
409
  f"path_match={path_matches_websocket_route}"
316
410
  )
411
+ else:
412
+ print(
413
+ f"❌ [_is_websocket_upgrade] NOT a WebSocket: path={path}, "
414
+ f"upgrade={upgrade_header}, connection={connection_header}, "
415
+ f"has_upgrade={has_upgrade_header}, has_connection={has_connection_upgrade}",
416
+ file=sys.stderr,
417
+ flush=True,
418
+ )
317
419
  return is_websocket
318
420
 
319
421
  def _get_allowed_origins(self, request: Request) -> list[str]:
@@ -353,33 +455,100 @@ class CSRFMiddleware(BaseHTTPMiddleware):
353
455
  except (AttributeError, TypeError, KeyError):
354
456
  pass
355
457
 
356
- # Final fallback: Use request host
458
+ # Final fallback: Use request host (normalize localhost variants)
357
459
  try:
358
460
  host = request.url.hostname
359
461
  scheme = request.url.scheme
360
462
  port = request.url.port
463
+
464
+ # Normalize localhost variants - return all common variants for development
465
+ # This handles cases where server binds to 0.0.0.0 but browser sends localhost
466
+ if host in ["localhost", "0.0.0.0", "127.0.0.1", "::1"]:
467
+ origins = []
468
+ for localhost_variant in ["localhost", "127.0.0.1"]:
469
+ if port and port not in [80, 443]:
470
+ origins.append(f"{scheme}://{localhost_variant}:{port}")
471
+ else:
472
+ origins.append(f"{scheme}://{localhost_variant}")
473
+ logger.debug(f"Generated localhost variant origins for host '{host}': {origins}")
474
+ return origins
475
+
476
+ # For other hosts, use the actual hostname
361
477
  if port and port not in [80, 443]:
362
478
  origin = f"{scheme}://{host}:{port}"
363
479
  else:
364
480
  origin = f"{scheme}://{host}"
365
481
  return [origin]
366
- except (AttributeError, TypeError):
482
+ except (AttributeError, TypeError) as e:
483
+ logger.debug(f"Could not determine origin from request: {e}")
367
484
  # Return empty list if we can't determine origin (will reject)
368
485
  return []
369
486
 
487
+ def _normalize_origin(self, origin: str) -> str:
488
+ """
489
+ Normalize origin for comparison (handles localhost/0.0.0.0/127.0.0.1/::1 equivalency).
490
+
491
+ In development, localhost, 0.0.0.0, 127.0.0.1, and ::1 should be treated as equivalent.
492
+ Also normalizes ports (80/443 vs explicit ports).
493
+ """
494
+ if not origin:
495
+ return origin
496
+
497
+ import re
498
+
499
+ # Normalize localhost variants - replace all variants with localhost
500
+ # Handle IPv4: 0.0.0.0, 127.0.0.1
501
+ # Handle IPv6: ::1
502
+ # Handle hostname: localhost
503
+ normalized = re.sub(
504
+ r"://(0\.0\.0\.0|127\.0\.0\.1|localhost|::1)",
505
+ "://localhost",
506
+ origin.lower(),
507
+ flags=re.IGNORECASE,
508
+ )
509
+
510
+ # Normalize ports: remove :80 for http and :443 for https
511
+ normalized = re.sub(r":80$", "", normalized)
512
+ normalized = re.sub(r":443$", "", normalized)
513
+
514
+ return normalized.rstrip("/")
515
+
516
+ def _is_development_mode(self) -> bool:
517
+ """Check if running in development mode."""
518
+ import os
519
+
520
+ env = os.getenv("ENVIRONMENT", "").lower()
521
+ g_nome_env = os.getenv("G_NOME_ENV", "").lower()
522
+ return env in ["development", "dev"] or g_nome_env in ["development", "dev"]
523
+
370
524
  def _validate_websocket_origin(self, request: Request) -> bool:
371
525
  """
372
526
  Validate Origin header for WebSocket upgrade requests.
373
527
 
374
528
  Primary defense against Cross-Site WebSocket Hijacking (CSWSH).
375
529
  Returns True if Origin is valid, False otherwise.
530
+
531
+ In development mode, allows connections without Origin header (with warning).
376
532
  """
377
533
  origin = request.headers.get("origin")
378
534
  if not origin:
379
- logger.warning(f"WebSocket upgrade missing Origin header: {request.url.path}")
380
- return False
535
+ if self._is_development_mode():
536
+ logger.warning(
537
+ f"WebSocket upgrade missing Origin header in development mode: "
538
+ f"{request.url.path} - allowing connection"
539
+ )
540
+ return True
541
+ else:
542
+ logger.warning(f"WebSocket upgrade missing Origin header: {request.url.path}")
543
+ return False
381
544
 
382
545
  allowed_origins = self._get_allowed_origins(request)
546
+ normalized_origin = self._normalize_origin(origin)
547
+
548
+ logger.debug(
549
+ f"Validating WebSocket origin: {origin} (normalized: {normalized_origin}) "
550
+ f"against allowed: {allowed_origins}"
551
+ )
383
552
 
384
553
  for allowed in allowed_origins:
385
554
  if allowed == "*":
@@ -388,14 +557,23 @@ class CSRFMiddleware(BaseHTTPMiddleware):
388
557
  "not recommended for production"
389
558
  )
390
559
  return True
391
- if origin == allowed or origin.rstrip("/") == allowed.rstrip("/"):
560
+
561
+ normalized_allowed = self._normalize_origin(allowed)
562
+ if normalized_origin == normalized_allowed:
563
+ logger.debug(
564
+ f"✅ WebSocket origin validated: {origin} matches {allowed} "
565
+ f"(normalized: {normalized_origin} == {normalized_allowed})"
566
+ )
392
567
  return True
393
568
 
394
569
  cors_config = getattr(request.app.state, "cors_config", None)
395
570
  cors_enabled = cors_config.get("enabled", False) if cors_config else False
571
+ normalized_allowed_list = [self._normalize_origin(a) for a in allowed_origins]
396
572
  logger.warning(
397
- f"WebSocket upgrade rejected - invalid Origin: {origin} "
398
- f"(allowed: {allowed_origins}, app: {getattr(request.app, 'title', 'unknown')}, "
573
+ f"WebSocket upgrade rejected - invalid Origin: {origin} "
574
+ f"(normalized: {normalized_origin}, allowed: {allowed_origins}, "
575
+ f"normalized_allowed: {normalized_allowed_list}, "
576
+ f"app: {getattr(request.app, 'title', 'unknown')}, "
399
577
  f"path: {request.url.path}, CORS enabled: {cors_enabled}, "
400
578
  f"has_cors_config: {hasattr(request.app.state, 'cors_config')})"
401
579
  )
@@ -409,198 +587,369 @@ class CSRFMiddleware(BaseHTTPMiddleware):
409
587
  """
410
588
  Process request through CSRF middleware.
411
589
  """
590
+ # CRITICAL: Log EVERY request immediately to catch WebSocket upgrades
412
591
  path = request.url.path
413
592
  method = request.method
414
-
415
- # Debug: Log all requests to see what's happening
416
593
  upgrade_header = request.headers.get("upgrade", "").lower()
417
594
  connection_header = request.headers.get("connection", "").lower()
418
- if upgrade_header or "websocket" in path.lower():
419
- logger.info(
420
- f"🔍 CSRF middleware: {method} {path}, "
421
- f"upgrade={upgrade_header}, connection={connection_header}"
422
- )
595
+ origin_header = request.headers.get("origin")
596
+
597
+ # Log ALL WebSocket-related requests IMMEDIATELY (before any processing)
598
+ if upgrade_header or "websocket" in path.lower() or connection_header == "upgrade":
599
+ import sys
423
600
 
424
- # CRITICAL: Handle WebSocket upgrade requests BEFORE other CSRF checks
425
- # WebSocket upgrades use cookie-based authentication and require CSRF validation
426
- if self._is_websocket_upgrade(request):
601
+ print(
602
+ f"🚨 [CSRF MIDDLEWARE ENTRY] {method} {path}, "
603
+ f"upgrade={upgrade_header}, connection={connection_header}, "
604
+ f"origin={origin_header}",
605
+ file=sys.stderr,
606
+ flush=True,
607
+ )
427
608
  logger.info(
428
- f"🔌 CSRF middleware processing WebSocket upgrade: {path}, "
429
- f"origin: {request.headers.get('origin')}"
609
+ f"🚨 [CSRF MIDDLEWARE ENTRY] {method} {path}, "
610
+ f"upgrade={upgrade_header}, connection={connection_header}, "
611
+ f"origin={origin_header}"
430
612
  )
431
- # Always validate origin for WebSocket connections (CSWSH protection)
432
- if not self._validate_websocket_origin(request):
433
- logger.warning(
434
- f"WebSocket origin validation failed for {path}: "
613
+
614
+ try:
615
+ # CRITICAL: Log ALL requests to verify middleware is running
616
+ # Always log WebSocket-related requests
617
+ if upgrade_header or "websocket" in path.lower() or connection_header == "upgrade":
618
+ logger.info(
619
+ f"🔍 CSRF middleware INTERCEPTED: {method} {path}, "
620
+ f"upgrade={upgrade_header}, connection={connection_header}, "
621
+ f"origin={origin_header}"
622
+ )
623
+ import sys
624
+
625
+ print(
626
+ f"🔍 [CSRF MIDDLEWARE] {method} {path}, "
627
+ f"upgrade={upgrade_header}, origin={origin_header}",
628
+ file=sys.stderr,
629
+ flush=True,
630
+ )
631
+
632
+ # CRITICAL: Handle WebSocket upgrade requests BEFORE other CSRF checks
633
+ # WebSocket upgrades use cookie-based authentication and require CSRF validation
634
+ is_ws_upgrade = self._is_websocket_upgrade(request)
635
+ logger.info(f"🔍 WebSocket upgrade detection for {path}: is_websocket={is_ws_upgrade}")
636
+
637
+ if is_ws_upgrade:
638
+ logger.info(
639
+ f"🔌 CSRF middleware processing WebSocket upgrade: {path}, "
640
+ f"origin: {request.headers.get('origin')}"
641
+ )
642
+ # Always validate origin for WebSocket connections (CSWSH protection)
643
+ origin_valid = self._validate_websocket_origin(request)
644
+ logger.info(
645
+ f"🔍 WebSocket origin validation for {path}: "
435
646
  f"origin={request.headers.get('origin')}, "
436
- f"allowed={self._get_allowed_origins(request)}"
647
+ f"allowed={self._get_allowed_origins(request)}, "
648
+ f"valid={origin_valid}"
437
649
  )
438
- return JSONResponse(
439
- status_code=status.HTTP_403_FORBIDDEN,
440
- content={"detail": "Invalid origin for WebSocket connection"},
650
+ if not origin_valid:
651
+ logger.warning(
652
+ f" WebSocket origin validation failed for {path}: "
653
+ f"origin={request.headers.get('origin')}, "
654
+ f"allowed={self._get_allowed_origins(request)}"
655
+ )
656
+ return JSONResponse(
657
+ status_code=status.HTTP_403_FORBIDDEN,
658
+ content={"detail": "Invalid origin for WebSocket connection"},
659
+ )
660
+
661
+ # Cookie-based authentication requires CSRF protection
662
+ # Check if authentication token cookie is present
663
+ # Use same cookie name as SharedAuthMiddleware for consistency
664
+ from .shared_middleware import AUTH_COOKIE_NAME
665
+
666
+ auth_token_cookie = request.cookies.get(AUTH_COOKIE_NAME)
667
+ logger.info(
668
+ f"🔍 WebSocket auth check for {path}: "
669
+ f"auth_cookie={'present' if auth_token_cookie else 'missing'}"
441
670
  )
442
671
 
443
- # Cookie-based authentication requires CSRF protection
444
- # Check if authentication token cookie is present
445
- # Use same cookie name as SharedAuthMiddleware for consistency
446
- from .shared_middleware import AUTH_COOKIE_NAME
447
-
448
- auth_token_cookie = request.cookies.get(AUTH_COOKIE_NAME)
449
- if auth_token_cookie:
450
- # SECURITY BY DEFAULT: WebSocket CSRF protection uses encrypted session keys
451
- # stored in private collection via envelope encryption.
452
- #
453
- # Security Model:
454
- # 1. Origin validation (already done above) - primary defense
455
- # 2. Encrypted session key validation - CSRF protection via database
456
- # 3. SameSite cookies - prevents cross-site cookie sending
457
- #
458
- # Session keys are:
459
- # - Generated on authentication
460
- # - Encrypted using envelope encryption (same as app secrets)
461
- # - Stored in _mdb_engine_websocket_sessions private collection
462
- # - Validated during WebSocket upgrade
463
-
464
- # Check if this WebSocket endpoint requires CSRF validation
672
+ # Check if ticket/session key authentication is required
673
+ # (csrf_required flag controls whether WebSocket needs ticket/session key)
674
+ # If csrf_required=false, we skip ticket validation entirely
465
675
  csrf_required = self._websocket_requires_csrf(request, path)
676
+ logger.info(
677
+ f"🔍 WebSocket auth check for {path}: "
678
+ f"ticket/session_key_required={csrf_required}"
679
+ )
466
680
 
467
- if csrf_required:
468
- # Try to get WebSocket session manager from app state
469
- websocket_session_manager = getattr(
470
- request.app.state, "websocket_session_manager", None
681
+ # Only validate ticket/session key if:
682
+ # 1. Auth cookie is present (user is authenticated)
683
+ # 2. Ticket/session key is required for this endpoint
684
+ if auth_token_cookie and csrf_required:
685
+ # WebSocket Authentication (NOT CSRF):
686
+ # WebSockets use JWT → Ticket → WebSocket flow for authentication.
687
+ # CSRF protection comes from Origin validation + SameSite cookies.
688
+ #
689
+ # Authentication Methods (in order of preference):
690
+ # 1. Ticket (JWT → Ticket exchange) - preferred for single-app
691
+ # - Client: POST /auth/ticket (sends JWT cookie)
692
+ # - Server: Validates JWT, generates ticket (UUID)
693
+ # - Client: ws://host/app/ws?ticket=<uuid>
694
+ # - Server: Validates & consumes ticket (single-use)
695
+ # 2. Session key - preferred for multi-app SSO
696
+ # - Generated via /auth/websocket-session endpoint
697
+ # - Encrypted, database-backed, long TTL (24h)
698
+ # 3. CSRF cookie - backward compatibility only
699
+ #
700
+ # CSRF Protection (separate from authentication):
701
+ # - Origin validation (already done above) - primary CSRF defense
702
+ # - SameSite cookies - prevents cross-site cookie sending
703
+ #
704
+ # Ticket flow: JWT (httpOnly cookie) → POST /auth/ticket → Ticket (UUID)
705
+ # → WebSocket connection with ticket → Validated & consumed
706
+
707
+ # Check for session key first (preferred for multi-app setups)
708
+ session_key = request.query_params.get("session_key") or request.headers.get(
709
+ "X-WebSocket-Session-Key"
471
710
  )
472
711
 
473
- if websocket_session_manager:
474
- # Use encrypted session key validation (secure-by-default)
475
- session_key = request.query_params.get(
476
- "session_key"
477
- ) or request.headers.get("X-WebSocket-Session-Key")
712
+ # Check for ticket (preferred for single-app setups)
713
+ ticket = request.query_params.get("ticket") or request.headers.get(
714
+ "X-WebSocket-Ticket"
715
+ )
478
716
 
479
- if not session_key:
717
+ if session_key:
718
+ # Session key authentication (bypasses CSRF - encrypted and secure)
719
+ # For WebSocket upgrades, let the handler validate session keys
720
+ # This allows TestClient to catch WebSocketDisconnect exceptions properly
721
+ # The handler will validate and raise WebSocketDisconnect if invalid
722
+ websocket_session_manager = None
723
+ app = request.app
724
+ apps_checked = []
725
+ while app:
726
+ app_title = getattr(app, "title", "unknown")
727
+ apps_checked.append(app_title)
728
+ websocket_session_manager = getattr(
729
+ app.state, "websocket_session_manager", None
730
+ )
731
+ if websocket_session_manager:
732
+ logger.debug(
733
+ f"Found websocket_session_manager on app '{app_title}' "
734
+ f"for WebSocket path '{path}' (checked: {apps_checked})"
735
+ )
736
+ break
737
+ parent_app = getattr(app, "app", None)
738
+ if parent_app is app: # Prevent infinite loop
739
+ break
740
+ app = parent_app
741
+
742
+ if not websocket_session_manager:
480
743
  logger.error(
481
- f"❌ WebSocket upgrade missing session key for {path}. "
482
- f"Auth cookie present: {bool(auth_token_cookie)}. "
483
- f"Tip: Generate session key via /auth/websocket-session endpoint."
744
+ f"❌ WebSocket session key provided for {path} but "
745
+ "websocket_session_manager not found"
484
746
  )
485
747
  return JSONResponse(
486
- status_code=status.HTTP_403_FORBIDDEN,
748
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
749
+ content={
750
+ "detail": (
751
+ "WebSocket session manager not available. "
752
+ "Server configuration error."
753
+ )
754
+ },
755
+ )
756
+
757
+ # For WebSocket upgrades, let the handler validate the session key
758
+ # This ensures TestClient can catch WebSocketDisconnect exceptions
759
+ # The handler will validate and raise WebSocketDisconnect if invalid
760
+ logger.info(
761
+ f"✅ WebSocket session key provided for {path} - "
762
+ "CSRF validation bypassed (session key will be validated in handler)"
763
+ )
764
+ elif ticket:
765
+ # Ticket-based authentication (preferred)
766
+ # Get WebSocket ticket store
767
+ from ..routing.websockets import _global_websocket_ticket_store
768
+
769
+ websocket_ticket_store = _global_websocket_ticket_store
770
+
771
+ # Fallback: Try to get from app state (for backward compatibility)
772
+ if not websocket_ticket_store:
773
+ app = request.app
774
+ apps_checked = []
775
+ while app:
776
+ app_title = getattr(app, "title", "unknown")
777
+ apps_checked.append(app_title)
778
+
779
+ # Get ticket store
780
+ websocket_ticket_store = getattr(
781
+ app.state, "websocket_ticket_store", None
782
+ )
783
+ if websocket_ticket_store:
784
+ logger.debug(
785
+ f"Found websocket_ticket_store on app '{app_title}' "
786
+ f"for WebSocket path '{path}' (checked: {apps_checked})"
787
+ )
788
+ break
789
+
790
+ # Try to get parent app
791
+ parent_app = getattr(app, "app", None)
792
+ if parent_app is app: # Prevent infinite loop
793
+ break
794
+ app = parent_app
795
+
796
+ if not websocket_ticket_store:
797
+ logger.error(
798
+ f"❌ WebSocket ticket store not available for {path}. "
799
+ "Ticket authentication requires websocket_ticket_store "
800
+ "to be initialized."
801
+ )
802
+ return JSONResponse(
803
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
487
804
  content={
488
805
  "detail": (
489
- "WebSocket session key missing. "
490
- "Generate session key via /auth/websocket-session endpoint."
806
+ "WebSocket ticket store not available. "
807
+ "Server configuration error."
491
808
  )
492
809
  },
493
810
  )
494
811
 
495
- # Validate session key against encrypted storage
812
+ # Validate and consume ticket (atomic operation - single-use)
496
813
  try:
497
- session_data = await websocket_session_manager.validate_session(
498
- session_key
814
+ logger.info(
815
+ f"🔍 Validating WebSocket ticket for {path}: "
816
+ f"ticket={ticket[:16]}... (truncated)"
817
+ )
818
+ ticket_data = await websocket_ticket_store.validate_and_consume_ticket(
819
+ ticket
499
820
  )
500
- if not session_data:
821
+ if not ticket_data:
501
822
  logger.error(
502
- f"❌ WebSocket session key validation failed for {path}. "
503
- f"Session key: {session_key[:16]}..."
823
+ f"❌ WebSocket ticket validation failed for {path}. "
824
+ f"Ticket: {ticket[:16]}... "
825
+ "Ticket may be expired, invalid, or already used."
504
826
  )
505
827
  return JSONResponse(
506
828
  status_code=status.HTTP_403_FORBIDDEN,
507
829
  content={
508
830
  "detail": (
509
- "WebSocket session key expired or invalid. "
510
- "Generate a new session key."
831
+ "WebSocket ticket expired or invalid. "
832
+ "Generate a new ticket via /auth/ticket endpoint."
511
833
  )
512
834
  },
513
835
  )
514
836
 
515
- # Store session data in request state for WebSocket handler
516
- request.state.websocket_session = session_data
517
- logger.debug(
518
- f"✅ WebSocket session key validated for {path} "
519
- f"(user: {session_data.get('user_id')})"
837
+ # Store ticket data in request state for WebSocket handler
838
+ request.state.websocket_session = ticket_data
839
+ logger.info(
840
+ f"✅ WebSocket ticket validated for {path} "
841
+ f"(user_id: {ticket_data.get('user_id')}, "
842
+ f"user_email: {ticket_data.get('user_email')})"
520
843
  )
521
844
  except (
522
845
  ValueError,
523
846
  TypeError,
524
847
  AttributeError,
525
848
  RuntimeError,
526
- ):
527
- logger.exception("Error validating WebSocket session key")
528
- return JSONResponse(
529
- status_code=status.HTTP_403_FORBIDDEN,
530
- content={"detail": "WebSocket session validation error"},
531
- )
532
- else:
533
- # Fallback to cookie-based CSRF (backward compatibility)
534
- csrf_cookie_token = request.cookies.get(self.cookie_name)
535
- if not csrf_cookie_token:
849
+ ) as e:
536
850
  logger.error(
537
- f"❌ WebSocket upgrade missing CSRF cookie for {path}. "
538
- f"Auth cookie present: {bool(auth_token_cookie)}, "
539
- f"CSRF cookie name: {self.cookie_name}, "
540
- f"Available cookies: {list(request.cookies.keys())}. "
541
- f"Tip: Make a GET request first to receive CSRF cookie."
851
+ f"❌ Error validating WebSocket ticket for {path}: {e}",
852
+ exc_info=True,
542
853
  )
543
854
  return JSONResponse(
544
855
  status_code=status.HTTP_403_FORBIDDEN,
545
856
  content={
546
- "detail": (
547
- "CSRF token missing for WebSocket authentication. "
548
- "Make a GET request first to receive the CSRF cookie."
549
- )
857
+ "detail": "WebSocket ticket validation error. "
858
+ "Generate a new ticket."
550
859
  },
551
860
  )
861
+ else:
862
+ # Fallback to CSRF cookie validation (backward compatibility)
863
+ # For WebSocket, CSRF header is optional (JS can't set headers on upgrade)
864
+ # but if provided, it must match the cookie
865
+ cookie_token = request.cookies.get(self.cookie_name)
866
+ header_token = request.headers.get(self.header_name)
552
867
 
553
- # Validate CSRF token signature if secret is used
554
- if self.secret and not validate_csrf_token(
555
- csrf_cookie_token, self.secret, self.token_ttl
556
- ):
557
- logger.error(f" WebSocket CSRF token validation failed for {path}.")
868
+ if not cookie_token:
869
+ logger.error(
870
+ f"❌ WebSocket upgrade missing CSRF cookie for {path}. "
871
+ "CSRF protection is required. "
872
+ "Generate ticket via /auth/ticket endpoint or include CSRF cookie."
873
+ )
558
874
  return JSONResponse(
559
875
  status_code=status.HTTP_403_FORBIDDEN,
560
876
  content={
561
877
  "detail": (
562
- "CSRF token expired or invalid for WebSocket connection"
878
+ "CSRF token missing. "
879
+ "Generate ticket via /auth/ticket endpoint "
880
+ "or include CSRF cookie/token."
563
881
  )
564
882
  },
565
883
  )
566
884
 
567
- # If CSRF header is provided, validate it matches the cookie
568
- # (Header is optional for WebSocket, but if present, must match cookie)
569
- csrf_header_token = request.headers.get(self.header_name)
570
- if csrf_header_token:
571
- if not hmac.compare_digest(csrf_cookie_token, csrf_header_token):
885
+ # If header is provided, validate it matches the cookie
886
+ if header_token:
887
+ if not hmac.compare_digest(cookie_token, header_token):
572
888
  logger.error(
573
- f"❌ WebSocket CSRF header mismatch for {path}. "
574
- f"Cookie token and header token do not match."
889
+ f"❌ WebSocket CSRF token mismatch for {path}. "
890
+ "Header token does not match cookie token."
575
891
  )
576
892
  return JSONResponse(
577
893
  status_code=status.HTTP_403_FORBIDDEN,
578
- content={
579
- "detail": (
580
- "CSRF token mismatch: header token does not "
581
- "match cookie token"
582
- )
583
- },
894
+ content={"detail": "CSRF token invalid."},
584
895
  )
585
- logger.debug(
586
- f"✅ CSRF header validated and matches cookie for WebSocket {path}"
896
+
897
+ # Validate CSRF token (check signature if secret is used)
898
+ if not self._validate_csrf_token(cookie_token, request):
899
+ logger.error(f"❌ WebSocket CSRF token validation failed for {path}")
900
+ return JSONResponse(
901
+ status_code=status.HTTP_403_FORBIDDEN,
902
+ content={"detail": "CSRF token validation failed."},
587
903
  )
588
904
 
589
- logger.debug(f"✅ CSRF cookie validation passed for WebSocket {path}")
590
- else:
591
- logger.debug(
905
+ logger.info(
906
+ f"✅ WebSocket CSRF cookie validated for {path} "
907
+ "(backward compatibility mode)"
908
+ )
909
+ elif auth_token_cookie and not csrf_required:
910
+ logger.info(
592
911
  f"✅ WebSocket CSRF validation skipped for {path} "
593
- f"(csrf_required=false, Origin validation sufficient)"
912
+ f"(csrf_required=false) - only origin validation performed"
913
+ )
914
+ elif not auth_token_cookie:
915
+ logger.info(
916
+ f"✅ WebSocket connection allowed for {path} "
917
+ f"(no auth cookie - WebSocket handler will authenticate)"
594
918
  )
595
919
 
596
- logger.debug(
597
- f"WebSocket upgrade CSRF validation passed for {path} "
598
- f"(Origin validated, CSRF validated)"
920
+ validation_status = (
921
+ "CSRF/ticket validated"
922
+ if auth_token_cookie and csrf_required
923
+ else "CSRF skipped"
924
+ )
925
+ logger.info(
926
+ f"✅ WebSocket upgrade CSRF validation passed for {path} "
927
+ f"(Origin validated, {validation_status})"
599
928
  )
600
929
 
601
- # Origin validated (and CSRF validated if authenticated)
602
- # Allow WebSocket upgrade to proceed
603
- return await call_next(request)
930
+ # Origin validated (and CSRF/ticket validated if authenticated
931
+ # and csrf_required=true)
932
+ # Allow WebSocket upgrade to proceed to WebSocket handler
933
+ logger.debug(f"✅ WebSocket upgrade request allowed to proceed: {path}")
934
+ return await call_next(request)
935
+
936
+ except (
937
+ AttributeError,
938
+ KeyError,
939
+ RuntimeError,
940
+ ValueError,
941
+ TypeError,
942
+ ConnectionError,
943
+ ) as e:
944
+ # Catch exceptions in WebSocket handling to see what's failing
945
+ logger.error(
946
+ f"❌ CRITICAL: Exception in CSRF middleware WebSocket handling: {e}", exc_info=True
947
+ )
948
+ import sys
949
+
950
+ print(f"❌ [CSRF MIDDLEWARE EXCEPTION] {e}", file=sys.stderr, flush=True)
951
+ # Re-raise to see the full error
952
+ raise
604
953
 
605
954
  if self._is_exempt(path):
606
955
  return await call_next(request)