mdb-engine 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mdb_engine/auth/csrf.py CHANGED
@@ -102,14 +102,16 @@ def validate_csrf_token(
102
102
  try:
103
103
  parts = token.split(":")
104
104
  if len(parts) != 3:
105
+ logger.debug("CSRF token has wrong format (expected 3 parts)")
105
106
  return False
106
107
 
107
108
  raw_token, timestamp_str, signature = parts
108
109
  timestamp = int(timestamp_str)
109
110
 
110
111
  # Check age
111
- if time.time() - timestamp > max_age:
112
- logger.debug("CSRF token expired")
112
+ age = time.time() - timestamp
113
+ if age > max_age:
114
+ logger.debug(f"CSRF token expired (age: {age:.0f}s, max: {max_age}s)")
113
115
  return False
114
116
 
115
117
  # Verify signature
@@ -119,7 +121,10 @@ def validate_csrf_token(
119
121
  ]
120
122
 
121
123
  if not hmac.compare_digest(signature, expected_sig):
122
- logger.warning("CSRF token signature mismatch")
124
+ logger.warning(
125
+ f"CSRF token signature mismatch. "
126
+ f"Token format: signed, Has secret: {bool(secret)}"
127
+ )
123
128
  return False
124
129
 
125
130
  return True
@@ -128,7 +133,10 @@ def validate_csrf_token(
128
133
  return False
129
134
 
130
135
  # Simple token validation (just check it exists and has reasonable length)
131
- return len(token) >= CSRF_TOKEN_LENGTH
136
+ is_valid = len(token) >= CSRF_TOKEN_LENGTH
137
+ if not is_valid:
138
+ logger.debug(f"CSRF token too short (length: {len(token)}, required: {CSRF_TOKEN_LENGTH})")
139
+ return is_valid
132
140
 
133
141
 
134
142
  class CSRFMiddleware(BaseHTTPMiddleware):
@@ -197,10 +205,218 @@ class CSRFMiddleware(BaseHTTPMiddleware):
197
205
  return True
198
206
  return False
199
207
 
def _validate_csrf_token(self, token: str, request: Request) -> bool:
    """
    Check a CSRF token against this middleware's configuration.

    Thin adapter over the module-level ``validate_csrf_token`` helper: it supplies
    ``self.secret`` as the signing secret and ``self.token_ttl`` as the maximum token
    age in seconds.

    Args:
        token: CSRF token string (e.g. the value of the CSRF cookie).
        request: Incoming request. Not consulted; accepted so the signature stays
            stable for callers.

    Returns:
        The result of ``validate_csrf_token`` for these arguments (True when accepted).
    """
    signing_secret = self.secret
    max_age_seconds = self.token_ttl
    return validate_csrf_token(token, signing_secret, max_age_seconds)
220
+
221
+ def _websocket_requires_csrf(self, request: Request, path: str) -> bool:
222
+ """
223
+ Check if WebSocket endpoint requires CSRF validation.
224
+
225
+ Defaults to True (security by default). Can be disabled per-endpoint via manifest.json:
226
+ websockets.{endpoint}.auth.csrf_required = false
227
+
228
+ Args:
229
+ request: FastAPI request
230
+ path: WebSocket path (e.g., "/chat-app/ws")
231
+
232
+ Returns:
233
+ True if CSRF validation is required, False otherwise
234
+ """
235
+ # Try parent app first (where websocket_configs should be stored)
236
+ websocket_configs = getattr(request.app.state, "websocket_configs", None)
237
+
238
+ # If not found, try to traverse up to find parent app
239
+ if not websocket_configs:
240
+ logger.debug(f"No websocket_configs found on request.app.state for path '{path}'")
241
+ app = request.app
242
+ apps_checked = []
243
+ while app:
244
+ app_title = getattr(app, "title", "unknown")
245
+ apps_checked.append(app_title)
246
+ websocket_configs = getattr(app.state, "websocket_configs", None)
247
+ if websocket_configs:
248
+ logger.debug(
249
+ f"Found websocket_configs on app '{app_title}' "
250
+ f"(checked: {apps_checked})"
251
+ )
252
+ break
253
+ parent_app = getattr(app, "app", None)
254
+ if parent_app is app: # Prevent infinite loop
255
+ break
256
+ app = parent_app
257
+
258
+ if not websocket_configs:
259
+ # No WebSocket configs found - use default (CSRF required for security by default)
260
+ logger.debug(
261
+ f"No websocket_configs found anywhere for path '{path}' - "
262
+ f"using default csrf_required=true"
263
+ )
264
+ return True
265
+
266
+ # Normalize path for matching (handle trailing slashes)
267
+ normalized_path = path.rstrip("/")
268
+ logger.debug(
269
+ f"Checking CSRF requirement for path '{normalized_path}' "
270
+ f"against {len(websocket_configs)} app config(s)"
271
+ )
272
+
273
+ # Try to find matching app config
274
+ # WebSocket paths are registered as /app-slug/endpoint-path
275
+ # e.g., /chat-app/ws where app_slug="chat-app" and endpoint_path="/ws"
276
+ for app_slug, config in websocket_configs.items():
277
+ logger.debug(f"Checking app '{app_slug}' config with {len(config)} endpoint(s)")
278
+ # Check each endpoint in this app's config
279
+ for endpoint_name, endpoint_config in config.items():
280
+ endpoint_path = endpoint_config.get("path", "")
281
+ # Normalize endpoint path
282
+ normalized_endpoint = endpoint_path.rstrip("/")
283
+
284
+ # Build expected full path: /app-slug/endpoint-path
285
+ if normalized_endpoint.startswith("/"):
286
+ expected_full_path = f"/{app_slug}{normalized_endpoint}"
287
+ else:
288
+ expected_full_path = f"/{app_slug}/{normalized_endpoint}"
289
+
290
+ # Match patterns:
291
+ # 1. Full path match: /chat-app/ws == /chat-app/ws
292
+ # 2. Endpoint-only match: /ws (if path ends with endpoint)
293
+ # 3. Path contains endpoint: /chat-app/ws contains /ws
294
+ matches = (
295
+ normalized_path == expected_full_path
296
+ or normalized_path == normalized_endpoint
297
+ or normalized_path.endswith(normalized_endpoint)
298
+ or normalized_path.endswith(f"/{app_slug}{normalized_endpoint}")
299
+ )
300
+
301
+ if matches:
302
+ auth_config = endpoint_config.get("auth", {})
303
+ if isinstance(auth_config, dict):
304
+ # Return csrf_required setting (defaults to True - security by default)
305
+ csrf_required = auth_config.get("csrf_required", True)
306
+ logger.info(
307
+ f"✅ WebSocket '{normalized_path}' csrf_required={csrf_required} "
308
+ f"(from app='{app_slug}', endpoint='{endpoint_name}', "
309
+ f"endpoint_path='{normalized_endpoint}')"
310
+ )
311
+ return csrf_required
312
+ else:
313
+ logger.debug(
314
+ f"WebSocket '{normalized_path}' auth_config is not a dict: "
315
+ f"{type(auth_config)}"
316
+ )
317
+
318
+ # No matching config found - use default (CSRF required for security by default)
319
+ logger.debug(
320
+ f"❌ No WebSocket config match for '{normalized_path}' "
321
+ f"(checked {len(websocket_configs)} app(s)) - using default csrf_required=true"
322
+ )
323
+ return True
324
+
200
325
  def _is_websocket_upgrade(self, request: Request) -> bool:
201
326
  """Check if request is a WebSocket upgrade request."""
202
327
  upgrade_header = request.headers.get("upgrade", "").lower()
203
- return upgrade_header == "websocket"
328
+ connection_header = request.headers.get("connection", "").lower()
329
+ path = request.url.path
330
+
331
+ # CRITICAL: Enhanced logging for WebSocket detection
332
+ import sys
333
+
334
+ print(
335
+ f"🔍 [_is_websocket_upgrade] Path: {path}, "
336
+ f"upgrade='{upgrade_header}', connection='{connection_header}'",
337
+ file=sys.stderr,
338
+ flush=True,
339
+ )
340
+ logger.info(
341
+ f"_is_websocket_upgrade check: upgrade='{upgrade_header}', "
342
+ f"connection='{connection_header}', path='{path}'"
343
+ )
344
+
345
+ # Primary check: WebSocket upgrade requires both Upgrade: websocket
346
+ # and Connection: Upgrade headers
347
+ has_upgrade_header = upgrade_header == "websocket"
348
+ has_connection_upgrade = "upgrade" in connection_header or "websocket" in connection_header
349
+
350
+ print(
351
+ f"🔍 [_is_websocket_upgrade] has_upgrade={has_upgrade_header}, "
352
+ f"has_connection={has_connection_upgrade}",
353
+ file=sys.stderr,
354
+ flush=True,
355
+ )
356
+
357
+ # Secondary check: If upgrade header is present but connection is
358
+ # overridden (e.g., by TestClient), check if path matches a known
359
+ # WebSocket route pattern
360
+ path_matches_websocket_route = False
361
+ if has_upgrade_header and not has_connection_upgrade:
362
+ # Check if path matches any configured WebSocket route
363
+ websocket_configs = getattr(request.app.state, "websocket_configs", None)
364
+ if websocket_configs:
365
+ path = request.url.path.rstrip("/") or "/"
366
+ for app_slug, config in websocket_configs.items():
367
+ for _endpoint_name, endpoint_config in config.items():
368
+ endpoint_path = endpoint_config.get("path", "").rstrip("/") or "/"
369
+ # Try various path matching patterns
370
+ expected_full_path = (
371
+ f"/{app_slug}{endpoint_path}"
372
+ if endpoint_path != "/"
373
+ else f"/{app_slug}"
374
+ )
375
+ # Match patterns:
376
+ # 1. Exact match with app prefix: /app-slug/endpoint-path
377
+ # 2. Endpoint-only match: /endpoint-path (if path ends with endpoint)
378
+ # 3. Root match: / matches / or /app-slug
379
+ if (
380
+ path == expected_full_path
381
+ or path.endswith(endpoint_path)
382
+ or path == endpoint_path
383
+ or (path == "/" and endpoint_path == "/")
384
+ or (path == f"/{app_slug}" and endpoint_path == "/")
385
+ ):
386
+ path_matches_websocket_route = True
387
+ break
388
+ if path_matches_websocket_route:
389
+ break
390
+
391
+ is_websocket = has_upgrade_header and (
392
+ has_connection_upgrade or path_matches_websocket_route
393
+ )
394
+
395
+ # CRITICAL: Enhanced logging
396
+ import sys
397
+
398
+ if is_websocket:
399
+ print(
400
+ f"✅ [_is_websocket_upgrade] WebSocket detected: path={path}, "
401
+ f"upgrade={upgrade_header}, connection={connection_header}, "
402
+ f"path_match={path_matches_websocket_route}, result={is_websocket}",
403
+ file=sys.stderr,
404
+ flush=True,
405
+ )
406
+ logger.info(
407
+ f"WebSocket upgrade detected: path={path}, "
408
+ f"upgrade={upgrade_header}, connection={connection_header}, "
409
+ f"path_match={path_matches_websocket_route}"
410
+ )
411
+ else:
412
+ print(
413
+ f"❌ [_is_websocket_upgrade] NOT a WebSocket: path={path}, "
414
+ f"upgrade={upgrade_header}, connection={connection_header}, "
415
+ f"has_upgrade={has_upgrade_header}, has_connection={has_connection_upgrade}",
416
+ file=sys.stderr,
417
+ flush=True,
418
+ )
419
+ return is_websocket
204
420
 
205
421
  def _get_allowed_origins(self, request: Request) -> list[str]:
206
422
  """
@@ -239,33 +455,100 @@ class CSRFMiddleware(BaseHTTPMiddleware):
239
455
  except (AttributeError, TypeError, KeyError):
240
456
  pass
241
457
 
242
- # Final fallback: Use request host
458
+ # Final fallback: Use request host (normalize localhost variants)
243
459
  try:
244
460
  host = request.url.hostname
245
461
  scheme = request.url.scheme
246
462
  port = request.url.port
463
+
464
+ # Normalize localhost variants - return all common variants for development
465
+ # This handles cases where server binds to 0.0.0.0 but browser sends localhost
466
+ if host in ["localhost", "0.0.0.0", "127.0.0.1", "::1"]:
467
+ origins = []
468
+ for localhost_variant in ["localhost", "127.0.0.1"]:
469
+ if port and port not in [80, 443]:
470
+ origins.append(f"{scheme}://{localhost_variant}:{port}")
471
+ else:
472
+ origins.append(f"{scheme}://{localhost_variant}")
473
+ logger.debug(f"Generated localhost variant origins for host '{host}': {origins}")
474
+ return origins
475
+
476
+ # For other hosts, use the actual hostname
247
477
  if port and port not in [80, 443]:
248
478
  origin = f"{scheme}://{host}:{port}"
249
479
  else:
250
480
  origin = f"{scheme}://{host}"
251
481
  return [origin]
252
- except (AttributeError, TypeError):
482
+ except (AttributeError, TypeError) as e:
483
+ logger.debug(f"Could not determine origin from request: {e}")
253
484
  # Return empty list if we can't determine origin (will reject)
254
485
  return []
255
486
 
487
+ def _normalize_origin(self, origin: str) -> str:
488
+ """
489
+ Normalize origin for comparison (handles localhost/0.0.0.0/127.0.0.1/::1 equivalency).
490
+
491
+ In development, localhost, 0.0.0.0, 127.0.0.1, and ::1 should be treated as equivalent.
492
+ Also normalizes ports (80/443 vs explicit ports).
493
+ """
494
+ if not origin:
495
+ return origin
496
+
497
+ import re
498
+
499
+ # Normalize localhost variants - replace all variants with localhost
500
+ # Handle IPv4: 0.0.0.0, 127.0.0.1
501
+ # Handle IPv6: ::1
502
+ # Handle hostname: localhost
503
+ normalized = re.sub(
504
+ r"://(0\.0\.0\.0|127\.0\.0\.1|localhost|::1)",
505
+ "://localhost",
506
+ origin.lower(),
507
+ flags=re.IGNORECASE,
508
+ )
509
+
510
+ # Normalize ports: remove :80 for http and :443 for https
511
+ normalized = re.sub(r":80$", "", normalized)
512
+ normalized = re.sub(r":443$", "", normalized)
513
+
514
+ return normalized.rstrip("/")
515
+
516
+ def _is_development_mode(self) -> bool:
517
+ """Check if running in development mode."""
518
+ import os
519
+
520
+ env = os.getenv("ENVIRONMENT", "").lower()
521
+ g_nome_env = os.getenv("G_NOME_ENV", "").lower()
522
+ return env in ["development", "dev"] or g_nome_env in ["development", "dev"]
523
+
256
524
  def _validate_websocket_origin(self, request: Request) -> bool:
257
525
  """
258
526
  Validate Origin header for WebSocket upgrade requests.
259
527
 
260
528
  Primary defense against Cross-Site WebSocket Hijacking (CSWSH).
261
529
  Returns True if Origin is valid, False otherwise.
530
+
531
+ In development mode, allows connections without Origin header (with warning).
262
532
  """
263
533
  origin = request.headers.get("origin")
264
534
  if not origin:
265
- logger.warning(f"WebSocket upgrade missing Origin header: {request.url.path}")
266
- return False
535
+ if self._is_development_mode():
536
+ logger.warning(
537
+ f"WebSocket upgrade missing Origin header in development mode: "
538
+ f"{request.url.path} - allowing connection"
539
+ )
540
+ return True
541
+ else:
542
+ logger.warning(f"WebSocket upgrade missing Origin header: {request.url.path}")
543
+ return False
267
544
 
268
545
  allowed_origins = self._get_allowed_origins(request)
546
+ normalized_origin = self._normalize_origin(origin)
547
+
548
+ logger.debug(
549
+ f"Validating WebSocket origin: {origin} (normalized: {normalized_origin}) "
550
+ f"against allowed: {allowed_origins}"
551
+ )
269
552
 
270
553
  for allowed in allowed_origins:
271
554
  if allowed == "*":
@@ -274,14 +557,23 @@ class CSRFMiddleware(BaseHTTPMiddleware):
274
557
  "not recommended for production"
275
558
  )
276
559
  return True
277
- if origin == allowed or origin.rstrip("/") == allowed.rstrip("/"):
560
+
561
+ normalized_allowed = self._normalize_origin(allowed)
562
+ if normalized_origin == normalized_allowed:
563
+ logger.debug(
564
+ f"✅ WebSocket origin validated: {origin} matches {allowed} "
565
+ f"(normalized: {normalized_origin} == {normalized_allowed})"
566
+ )
278
567
  return True
279
568
 
280
569
  cors_config = getattr(request.app.state, "cors_config", None)
281
570
  cors_enabled = cors_config.get("enabled", False) if cors_config else False
571
+ normalized_allowed_list = [self._normalize_origin(a) for a in allowed_origins]
282
572
  logger.warning(
283
- f"WebSocket upgrade rejected - invalid Origin: {origin} "
284
- f"(allowed: {allowed_origins}, app: {getattr(request.app, 'title', 'unknown')}, "
573
+ f"WebSocket upgrade rejected - invalid Origin: {origin} "
574
+ f"(normalized: {normalized_origin}, allowed: {allowed_origins}, "
575
+ f"normalized_allowed: {normalized_allowed_list}, "
576
+ f"app: {getattr(request.app, 'title', 'unknown')}, "
285
577
  f"path: {request.url.path}, CORS enabled: {cors_enabled}, "
286
578
  f"has_cors_config: {hasattr(request.app.state, 'cors_config')})"
287
579
  )
@@ -295,79 +587,369 @@ class CSRFMiddleware(BaseHTTPMiddleware):
295
587
  """
296
588
  Process request through CSRF middleware.
297
589
  """
590
+ # CRITICAL: Log EVERY request immediately to catch WebSocket upgrades
298
591
  path = request.url.path
299
592
  method = request.method
593
+ upgrade_header = request.headers.get("upgrade", "").lower()
594
+ connection_header = request.headers.get("connection", "").lower()
595
+ origin_header = request.headers.get("origin")
596
+
597
+ # Log ALL WebSocket-related requests IMMEDIATELY (before any processing)
598
+ if upgrade_header or "websocket" in path.lower() or connection_header == "upgrade":
599
+ import sys
600
+
601
+ print(
602
+ f"🚨 [CSRF MIDDLEWARE ENTRY] {method} {path}, "
603
+ f"upgrade={upgrade_header}, connection={connection_header}, "
604
+ f"origin={origin_header}",
605
+ file=sys.stderr,
606
+ flush=True,
607
+ )
608
+ logger.info(
609
+ f"🚨 [CSRF MIDDLEWARE ENTRY] {method} {path}, "
610
+ f"upgrade={upgrade_header}, connection={connection_header}, "
611
+ f"origin={origin_header}"
612
+ )
300
613
 
301
- # CRITICAL: Handle WebSocket upgrade requests BEFORE other CSRF checks
302
- # WebSocket upgrades use cookie-based authentication and require CSRF validation
303
- if self._is_websocket_upgrade(request):
304
- # Always validate origin for WebSocket connections (CSWSH protection)
305
- if not self._validate_websocket_origin(request):
306
- logger.warning(
307
- f"WebSocket origin validation failed for {path}: "
308
- f"origin={request.headers.get('origin')}, "
309
- f"allowed={self._get_allowed_origins(request)}"
614
+ try:
615
+ # CRITICAL: Log ALL requests to verify middleware is running
616
+ # Always log WebSocket-related requests
617
+ if upgrade_header or "websocket" in path.lower() or connection_header == "upgrade":
618
+ logger.info(
619
+ f"🔍 CSRF middleware INTERCEPTED: {method} {path}, "
620
+ f"upgrade={upgrade_header}, connection={connection_header}, "
621
+ f"origin={origin_header}"
310
622
  )
311
- return JSONResponse(
312
- status_code=status.HTTP_403_FORBIDDEN,
313
- content={"detail": "Invalid origin for WebSocket connection"},
623
+ import sys
624
+
625
+ print(
626
+ f"🔍 [CSRF MIDDLEWARE] {method} {path}, "
627
+ f"upgrade={upgrade_header}, origin={origin_header}",
628
+ file=sys.stderr,
629
+ flush=True,
314
630
  )
315
631
 
316
- # Cookie-based authentication requires CSRF protection
317
- # Check if authentication token cookie is present
318
- # Use same cookie name as SharedAuthMiddleware for consistency
319
- from .shared_middleware import AUTH_COOKIE_NAME
320
-
321
- auth_token_cookie = request.cookies.get(AUTH_COOKIE_NAME)
322
- if auth_token_cookie:
323
- # For WebSocket upgrades, CSRF protection relies on:
324
- # 1. Origin validation (already done above) - primary defense
325
- # 2. SameSite cookies - prevents cross-site cookie sending
326
- # 3. CSRF cookie presence - ensures session is established
327
- #
328
- # Note: JavaScript WebSocket API cannot set custom headers,
329
- # so we cannot use double-submit cookie pattern (cookie + header).
330
- # Instead, we rely on Origin validation + SameSite cookies for CSRF protection.
331
- csrf_cookie_token = request.cookies.get(self.cookie_name)
332
- if not csrf_cookie_token:
333
- logger.warning(f"WebSocket upgrade missing CSRF cookie for {path}")
632
+ # CRITICAL: Handle WebSocket upgrade requests BEFORE other CSRF checks
633
+ # WebSocket upgrades use cookie-based authentication and require CSRF validation
634
+ is_ws_upgrade = self._is_websocket_upgrade(request)
635
+ logger.info(f"🔍 WebSocket upgrade detection for {path}: is_websocket={is_ws_upgrade}")
636
+
637
+ if is_ws_upgrade:
638
+ logger.info(
639
+ f"🔌 CSRF middleware processing WebSocket upgrade: {path}, "
640
+ f"origin: {request.headers.get('origin')}"
641
+ )
642
+ # Always validate origin for WebSocket connections (CSWSH protection)
643
+ origin_valid = self._validate_websocket_origin(request)
644
+ logger.info(
645
+ f"🔍 WebSocket origin validation for {path}: "
646
+ f"origin={request.headers.get('origin')}, "
647
+ f"allowed={self._get_allowed_origins(request)}, "
648
+ f"valid={origin_valid}"
649
+ )
650
+ if not origin_valid:
651
+ logger.warning(
652
+ f"❌ WebSocket origin validation failed for {path}: "
653
+ f"origin={request.headers.get('origin')}, "
654
+ f"allowed={self._get_allowed_origins(request)}"
655
+ )
334
656
  return JSONResponse(
335
657
  status_code=status.HTTP_403_FORBIDDEN,
336
- content={"detail": "CSRF token missing for WebSocket authentication"},
658
+ content={"detail": "Invalid origin for WebSocket connection"},
337
659
  )
338
660
 
339
- # Validate CSRF token signature if secret is used
340
- if self.secret and not validate_csrf_token(
341
- csrf_cookie_token, self.secret, self.token_ttl
342
- ):
343
- logger.warning(f"WebSocket CSRF token validation failed for {path}")
344
- return JSONResponse(
345
- status_code=status.HTTP_403_FORBIDDEN,
346
- content={
347
- "detail": "CSRF token expired or invalid for WebSocket connection"
348
- },
661
+ # Cookie-based authentication requires CSRF protection
662
+ # Check if authentication token cookie is present
663
+ # Use same cookie name as SharedAuthMiddleware for consistency
664
+ from .shared_middleware import AUTH_COOKIE_NAME
665
+
666
+ auth_token_cookie = request.cookies.get(AUTH_COOKIE_NAME)
667
+ logger.info(
668
+ f"🔍 WebSocket auth check for {path}: "
669
+ f"auth_cookie={'present' if auth_token_cookie else 'missing'}"
670
+ )
671
+
672
+ # Check if ticket/session key authentication is required
673
+ # (csrf_required flag controls whether WebSocket needs ticket/session key)
674
+ # If csrf_required=false, we skip ticket validation entirely
675
+ csrf_required = self._websocket_requires_csrf(request, path)
676
+ logger.info(
677
+ f"🔍 WebSocket auth check for {path}: "
678
+ f"ticket/session_key_required={csrf_required}"
679
+ )
680
+
681
+ # Only validate ticket/session key if:
682
+ # 1. Auth cookie is present (user is authenticated)
683
+ # 2. Ticket/session key is required for this endpoint
684
+ if auth_token_cookie and csrf_required:
685
+ # WebSocket Authentication (NOT CSRF):
686
+ # WebSockets use JWT → Ticket → WebSocket flow for authentication.
687
+ # CSRF protection comes from Origin validation + SameSite cookies.
688
+ #
689
+ # Authentication Methods (in order of preference):
690
+ # 1. Ticket (JWT → Ticket exchange) - preferred for single-app
691
+ # - Client: POST /auth/ticket (sends JWT cookie)
692
+ # - Server: Validates JWT, generates ticket (UUID)
693
+ # - Client: ws://host/app/ws?ticket=<uuid>
694
+ # - Server: Validates & consumes ticket (single-use)
695
+ # 2. Session key - preferred for multi-app SSO
696
+ # - Generated via /auth/websocket-session endpoint
697
+ # - Encrypted, database-backed, long TTL (24h)
698
+ # 3. CSRF cookie - backward compatibility only
699
+ #
700
+ # CSRF Protection (separate from authentication):
701
+ # - Origin validation (already done above) - primary CSRF defense
702
+ # - SameSite cookies - prevents cross-site cookie sending
703
+ #
704
+ # Ticket flow: JWT (httpOnly cookie) → POST /auth/ticket → Ticket (UUID)
705
+ # → WebSocket connection with ticket → Validated & consumed
706
+
707
+ # Check for session key first (preferred for multi-app setups)
708
+ session_key = request.query_params.get("session_key") or request.headers.get(
709
+ "X-WebSocket-Session-Key"
710
+ )
711
+
712
+ # Check for ticket (preferred for single-app setups)
713
+ ticket = request.query_params.get("ticket") or request.headers.get(
714
+ "X-WebSocket-Ticket"
349
715
  )
350
716
 
351
- # Optional: If CSRF header is provided, validate it matches cookie
352
- # (Some clients may send it, but it's not required for WebSocket upgrades)
353
- header_token = request.headers.get(self.header_name)
354
- if header_token:
355
- # If header is provided, validate it matches cookie (double-submit pattern)
356
- if not hmac.compare_digest(csrf_cookie_token, header_token):
357
- logger.warning(f"WebSocket CSRF token mismatch for {path}")
358
- return JSONResponse(
359
- status_code=status.HTTP_403_FORBIDDEN,
360
- content={"detail": "CSRF token invalid for WebSocket connection"},
717
+ if session_key:
718
+ # Session key authentication (bypasses CSRF - encrypted and secure)
719
+ # For WebSocket upgrades, let the handler validate session keys
720
+ # This allows TestClient to catch WebSocketDisconnect exceptions properly
721
+ # The handler will validate and raise WebSocketDisconnect if invalid
722
+ websocket_session_manager = None
723
+ app = request.app
724
+ apps_checked = []
725
+ while app:
726
+ app_title = getattr(app, "title", "unknown")
727
+ apps_checked.append(app_title)
728
+ websocket_session_manager = getattr(
729
+ app.state, "websocket_session_manager", None
730
+ )
731
+ if websocket_session_manager:
732
+ logger.debug(
733
+ f"Found websocket_session_manager on app '{app_title}' "
734
+ f"for WebSocket path '{path}' (checked: {apps_checked})"
735
+ )
736
+ break
737
+ parent_app = getattr(app, "app", None)
738
+ if parent_app is app: # Prevent infinite loop
739
+ break
740
+ app = parent_app
741
+
742
+ if not websocket_session_manager:
743
+ logger.error(
744
+ f"❌ WebSocket session key provided for {path} but "
745
+ "websocket_session_manager not found"
746
+ )
747
+ return JSONResponse(
748
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
749
+ content={
750
+ "detail": (
751
+ "WebSocket session manager not available. "
752
+ "Server configuration error."
753
+ )
754
+ },
755
+ )
756
+
757
+ # For WebSocket upgrades, let the handler validate the session key
758
+ # This ensures TestClient can catch WebSocketDisconnect exceptions
759
+ # The handler will validate and raise WebSocketDisconnect if invalid
760
+ logger.info(
761
+ f"✅ WebSocket session key provided for {path} - "
762
+ "CSRF validation bypassed (session key will be validated in handler)"
763
+ )
764
+ elif ticket:
765
+ # Ticket-based authentication (preferred)
766
+ # Get WebSocket ticket store
767
+ from ..routing.websockets import _global_websocket_ticket_store
768
+
769
+ websocket_ticket_store = _global_websocket_ticket_store
770
+
771
+ # Fallback: Try to get from app state (for backward compatibility)
772
+ if not websocket_ticket_store:
773
+ app = request.app
774
+ apps_checked = []
775
+ while app:
776
+ app_title = getattr(app, "title", "unknown")
777
+ apps_checked.append(app_title)
778
+
779
+ # Get ticket store
780
+ websocket_ticket_store = getattr(
781
+ app.state, "websocket_ticket_store", None
782
+ )
783
+ if websocket_ticket_store:
784
+ logger.debug(
785
+ f"Found websocket_ticket_store on app '{app_title}' "
786
+ f"for WebSocket path '{path}' (checked: {apps_checked})"
787
+ )
788
+ break
789
+
790
+ # Try to get parent app
791
+ parent_app = getattr(app, "app", None)
792
+ if parent_app is app: # Prevent infinite loop
793
+ break
794
+ app = parent_app
795
+
796
+ if not websocket_ticket_store:
797
+ logger.error(
798
+ f"❌ WebSocket ticket store not available for {path}. "
799
+ "Ticket authentication requires websocket_ticket_store "
800
+ "to be initialized."
801
+ )
802
+ return JSONResponse(
803
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
804
+ content={
805
+ "detail": (
806
+ "WebSocket ticket store not available. "
807
+ "Server configuration error."
808
+ )
809
+ },
810
+ )
811
+
812
+ # Validate and consume ticket (atomic operation - single-use)
813
+ try:
814
+ logger.info(
815
+ f"🔍 Validating WebSocket ticket for {path}: "
816
+ f"ticket={ticket[:16]}... (truncated)"
817
+ )
818
+ ticket_data = await websocket_ticket_store.validate_and_consume_ticket(
819
+ ticket
820
+ )
821
+ if not ticket_data:
822
+ logger.error(
823
+ f"❌ WebSocket ticket validation failed for {path}. "
824
+ f"Ticket: {ticket[:16]}... "
825
+ "Ticket may be expired, invalid, or already used."
826
+ )
827
+ return JSONResponse(
828
+ status_code=status.HTTP_403_FORBIDDEN,
829
+ content={
830
+ "detail": (
831
+ "WebSocket ticket expired or invalid. "
832
+ "Generate a new ticket via /auth/ticket endpoint."
833
+ )
834
+ },
835
+ )
836
+
837
+ # Store ticket data in request state for WebSocket handler
838
+ request.state.websocket_session = ticket_data
839
+ logger.info(
840
+ f"✅ WebSocket ticket validated for {path} "
841
+ f"(user_id: {ticket_data.get('user_id')}, "
842
+ f"user_email: {ticket_data.get('user_email')})"
843
+ )
844
+ except (
845
+ ValueError,
846
+ TypeError,
847
+ AttributeError,
848
+ RuntimeError,
849
+ ) as e:
850
+ logger.error(
851
+ f"❌ Error validating WebSocket ticket for {path}: {e}",
852
+ exc_info=True,
853
+ )
854
+ return JSONResponse(
855
+ status_code=status.HTTP_403_FORBIDDEN,
856
+ content={
857
+ "detail": "WebSocket ticket validation error. "
858
+ "Generate a new ticket."
859
+ },
860
+ )
861
+ else:
862
+ # Fallback to CSRF cookie validation (backward compatibility)
863
+ # For WebSocket, CSRF header is optional (JS can't set headers on upgrade)
864
+ # but if provided, it must match the cookie
865
+ cookie_token = request.cookies.get(self.cookie_name)
866
+ header_token = request.headers.get(self.header_name)
867
+
868
+ if not cookie_token:
869
+ logger.error(
870
+ f"❌ WebSocket upgrade missing CSRF cookie for {path}. "
871
+ "CSRF protection is required. "
872
+ "Generate ticket via /auth/ticket endpoint or include CSRF cookie."
873
+ )
874
+ return JSONResponse(
875
+ status_code=status.HTTP_403_FORBIDDEN,
876
+ content={
877
+ "detail": (
878
+ "CSRF token missing. "
879
+ "Generate ticket via /auth/ticket endpoint "
880
+ "or include CSRF cookie/token."
881
+ )
882
+ },
883
+ )
884
+
885
+ # If header is provided, validate it matches the cookie
886
+ if header_token:
887
+ if not hmac.compare_digest(cookie_token, header_token):
888
+ logger.error(
889
+ f"❌ WebSocket CSRF token mismatch for {path}. "
890
+ "Header token does not match cookie token."
891
+ )
892
+ return JSONResponse(
893
+ status_code=status.HTTP_403_FORBIDDEN,
894
+ content={"detail": "CSRF token invalid."},
895
+ )
896
+
897
+ # Validate CSRF token (check signature if secret is used)
898
+ if not self._validate_csrf_token(cookie_token, request):
899
+ logger.error(f"❌ WebSocket CSRF token validation failed for {path}")
900
+ return JSONResponse(
901
+ status_code=status.HTTP_403_FORBIDDEN,
902
+ content={"detail": "CSRF token validation failed."},
903
+ )
904
+
905
+ logger.info(
906
+ f"✅ WebSocket CSRF cookie validated for {path} "
907
+ "(backward compatibility mode)"
361
908
  )
909
+ elif auth_token_cookie and not csrf_required:
910
+ logger.info(
911
+ f"✅ WebSocket CSRF validation skipped for {path} "
912
+ f"(csrf_required=false) - only origin validation performed"
913
+ )
914
+ elif not auth_token_cookie:
915
+ logger.info(
916
+ f"✅ WebSocket connection allowed for {path} "
917
+ f"(no auth cookie - WebSocket handler will authenticate)"
918
+ )
362
919
 
363
- logger.debug(
364
- f"WebSocket upgrade CSRF validation passed for {path} "
365
- f"(Origin validated, CSRF cookie present)"
920
+ validation_status = (
921
+ "CSRF/ticket validated"
922
+ if auth_token_cookie and csrf_required
923
+ else "CSRF skipped"
924
+ )
925
+ logger.info(
926
+ f"✅ WebSocket upgrade CSRF validation passed for {path} "
927
+ f"(Origin validated, {validation_status})"
366
928
  )
367
929
 
368
- # Origin validated (and CSRF validated if authenticated)
369
- # Allow WebSocket upgrade to proceed
370
- return await call_next(request)
930
+ # Origin validated (and CSRF/ticket validated if authenticated
931
+ # and csrf_required=true)
932
+ # Allow WebSocket upgrade to proceed to WebSocket handler
933
+ logger.debug(f"✅ WebSocket upgrade request allowed to proceed: {path}")
934
+ return await call_next(request)
935
+
936
+ except (
937
+ AttributeError,
938
+ KeyError,
939
+ RuntimeError,
940
+ ValueError,
941
+ TypeError,
942
+ ConnectionError,
943
+ ) as e:
944
+ # Catch exceptions in WebSocket handling to see what's failing
945
+ logger.error(
946
+ f"❌ CRITICAL: Exception in CSRF middleware WebSocket handling: {e}", exc_info=True
947
+ )
948
+ import sys
949
+
950
+ print(f"❌ [CSRF MIDDLEWARE EXCEPTION] {e}", file=sys.stderr, flush=True)
951
+ # Re-raise to see the full error
952
+ raise
371
953
 
372
954
  if self._is_exempt(path):
373
955
  return await call_next(request)