workspace-mcp 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,426 @@
1
+ """
2
+ Authentication Middleware
3
+
4
+ Middleware to bind requests to sessions and handle OAuth 2.1 authentication.
5
+ Integrates token validation, session management, and request context.
6
+ """
7
+
8
import hmac
import logging
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Callable

from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from .discovery import AuthorizationServerDiscovery
from .tokens import TokenValidator, TokenValidationError
from .sessions import SessionStore, Session
from .http import HTTPAuthHandler
20
+
21
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
22
+
23
+
24
@dataclass
class AuthContext:
    """Authentication context attached to requests.

    Instances are stored on ``request.state.auth`` by the authentication
    middleware and read back by helpers such as ``get_auth_context``.
    """

    # True once a session or Bearer token has been accepted for the request.
    authenticated: bool = False
    user_id: Optional[str] = None
    session_id: Optional[str] = None
    session: Optional[Session] = None
    # Raw token/session information the context was built from.
    token_info: Optional[Dict[str, Any]] = None
    # Granted scopes; a fresh list per instance (never shared between contexts).
    scopes: List[str] = field(default_factory=list)
    # OAuth-style error code (e.g. "missing_token", "invalid_token") when not authenticated.
    error: Optional[str] = None
    error_description: Optional[str] = None

    def __post_init__(self):
        # Callers may still pass scopes=None explicitly; normalize so that
        # ``scopes`` is always a list.
        if self.scopes is None:
            self.scopes = []
+
41
+
42
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Middleware to bind requests to sessions and handle authentication.

    For every request whose path is not exempt, the middleware extracts the
    Bearer token and first tries to authenticate it against the session named
    by the request (``mcp-session-id`` / ``x-session-id`` header or the
    ``session_id`` query parameter). Otherwise, when Bearer passthrough is
    enabled, it validates the token directly and creates a new session for it.
    The resulting ``AuthContext`` is stored on ``request.state.auth``;
    unauthenticated requests receive a JSON error response with a
    ``WWW-Authenticate`` header.
    """

    def __init__(
        self,
        app,
        session_store: SessionStore,
        token_validator: Optional[TokenValidator] = None,
        discovery_service: Optional[AuthorizationServerDiscovery] = None,
        http_auth_handler: Optional[HTTPAuthHandler] = None,
        required_scopes: Optional[List[str]] = None,
        exempt_paths: Optional[List[str]] = None,
        authorization_server_url: Optional[str] = None,
        expected_audience: Optional[str] = None,
        enable_bearer_passthrough: bool = True,
    ):
        """
        Initialize authentication middleware.

        Args:
            app: FastAPI application
            session_store: Session store instance
            token_validator: Token validator instance (built from the
                discovery service when omitted)
            discovery_service: Authorization server discovery service
            http_auth_handler: HTTP authentication handler
            required_scopes: Default required scopes
            exempt_paths: Path prefixes exempt from authentication
            authorization_server_url: Default authorization server URL
            expected_audience: Expected token audience
            enable_bearer_passthrough: Enable Bearer token passthrough mode
        """
        super().__init__(app)

        self.session_store = session_store
        # Resolve the discovery service first so that a default-constructed
        # validator shares this instance instead of receiving ``None``.
        self.discovery = discovery_service or AuthorizationServerDiscovery()
        self.token_validator = token_validator or TokenValidator(self.discovery)
        self.http_auth = http_auth_handler or HTTPAuthHandler()
        self.required_scopes = required_scopes or []
        # Matched as path prefixes (see _is_exempt_path).
        self.exempt_paths = set(exempt_paths or ["/health", "/oauth2callback", "/.well-known/"])
        self.authorization_server_url = authorization_server_url
        self.expected_audience = expected_audience
        self.enable_bearer_passthrough = enable_bearer_passthrough

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        """Process request through authentication middleware."""

        # Exempt paths get an empty (unauthenticated) context and skip checks.
        if self._is_exempt_path(request.url.path):
            request.state.auth = AuthContext()
            return await call_next(request)

        auth_context = await self.authenticate_request(request)
        request.state.auth = auth_context

        if not auth_context.authenticated and self._requires_authentication(request):
            return self._create_auth_error_response(auth_context)

        return await call_next(request)

    async def authenticate_request(self, request: Request) -> AuthContext:
        """
        Validate token and resolve session for request.

        Args:
            request: HTTP request

        Returns:
            Authentication context. On failure, ``authenticated`` is False and
            ``error`` / ``error_description`` describe the reason; this method
            does not raise.
        """
        auth_context = AuthContext()

        try:
            token_info = self.http_auth.get_token_info_from_headers(dict(request.headers))

            if not token_info["has_bearer_token"]:
                auth_context.error = "missing_token"
                auth_context.error_description = "No Bearer token provided"
                return auth_context

            if not token_info["valid_format"]:
                auth_context.error = "invalid_token"
                auth_context.error_description = "Invalid Bearer token format"
                return auth_context

            token = token_info["token"]

            # Try session-based authentication first.
            session_id = self._extract_session_id(request)
            if session_id:
                session = self.session_store.get_session(session_id)
                if session:
                    if await self._validate_session_token(session, token):
                        return self._create_session_auth_context(session)
                    logger.warning(f"Token mismatch for session {session_id}")

            # Fall back to direct token validation.
            if self.enable_bearer_passthrough:
                auth_context = await self._validate_bearer_token(token, request)

                # Only create a session when the client did not name one; a
                # client-supplied id that failed to resolve is not reused.
                if auth_context.authenticated and not session_id:
                    auth_context.session_id = self._create_session_from_token(auth_context)

                return auth_context

            auth_context.error = "invalid_session"
            auth_context.error_description = "Valid session required"
            return auth_context

        except Exception as e:
            logger.exception(f"Authentication error: {e}")
            auth_context.error = "server_error"
            auth_context.error_description = "Internal authentication error"
            return auth_context

    async def _validate_bearer_token(self, token: str, request: Request) -> AuthContext:
        """Validate a Bearer token directly and build an ``AuthContext`` from the result."""
        auth_context = AuthContext()

        try:
            token_result = await self.token_validator.validate_token(
                token=token,
                expected_audience=self.expected_audience,
                required_scopes=self.required_scopes,
                authorization_server_url=self.authorization_server_url,
            )

            if token_result["valid"]:
                auth_context.authenticated = True
                auth_context.user_id = token_result["user_identity"]
                auth_context.token_info = token_result
                # Guard against an explicit ``"scopes": None`` in the result.
                auth_context.scopes = token_result.get("scopes") or []

                logger.debug(f"Successfully validated Bearer token for user {auth_context.user_id}")
            else:
                auth_context.error = "invalid_token"
                auth_context.error_description = "Token validation failed"

        except TokenValidationError as e:
            auth_context.error = e.error_code
            auth_context.error_description = str(e)
            logger.warning(f"Token validation failed: {e}")
        except Exception as e:
            auth_context.error = "server_error"
            auth_context.error_description = "Token validation error"
            logger.error(f"Token validation error: {e}")

        return auth_context

    async def _validate_session_token(self, session: Session, token: str) -> bool:
        """Return True if ``token`` may authenticate the caller as ``session``.

        The token is accepted when either of these holds:

        * It equals the access token stored on the session.
        * It is a JWT that passes full validation via ``token_validator``,
          belongs to ``session.user_id``, and does not contradict the key
          identity claims recorded on the session.

        Any error results in ``False``.
        """
        try:
            stored_info = session.token_info or {}
            session_token = stored_info.get("access_token")
            if not session_token:
                return False

            # Constant-time comparison so the stored token cannot be probed
            # through response timing.
            if hmac.compare_digest(str(session_token).encode(), token.encode()):
                return True

            if not self.token_validator._is_jwt_format(token):
                return False

            # A different JWT (e.g. a refreshed token) must be validated before
            # any of its claims are trusted. The session id is caller-supplied,
            # so comparing only decoded claims would let an unvalidated token
            # with matching (or, if the session stores no claims, arbitrary)
            # claims take over the session.
            try:
                token_result = await self.token_validator.validate_token(
                    token=token,
                    expected_audience=self.expected_audience,
                    required_scopes=self.required_scopes,
                    authorization_server_url=self.authorization_server_url,
                )
            except TokenValidationError as e:
                logger.warning(f"Session token validation failed: {e}")
                return False

            if not token_result.get("valid"):
                return False

            # Sessions created from tokens use ``user_identity`` as their
            # user_id (see _create_session_from_token). NOTE(review): sessions
            # created by other flows are assumed to use the same identity format.
            if token_result.get("user_identity") != session.user_id:
                return False

            # Extra consistency check against the claims recorded on the session.
            token_payload = self.token_validator.decode_jwt_payload(token)
            session_payload = stored_info.get("claims") or {}
            for claim in ("sub", "email", "aud", "iss"):
                if (
                    claim in token_payload
                    and claim in session_payload
                    and token_payload[claim] != session_payload[claim]
                ):
                    return False

            return True

        except Exception as e:
            logger.error(f"Session token validation error: {e}")
            return False

    def _create_session_auth_context(self, session: Session) -> AuthContext:
        """Create an authenticated context from an existing session."""
        return AuthContext(
            authenticated=True,
            user_id=session.user_id,
            session_id=session.session_id,
            session=session,
            token_info=session.token_info,
            scopes=session.scopes,
        )

    def _create_session_from_token(self, auth_context: AuthContext) -> Optional[str]:
        """Create a new session from a validated token; return its id, or None on failure."""
        if not auth_context.authenticated or not auth_context.user_id:
            return None

        token_info = auth_context.token_info or {}
        try:
            session_id = self.session_store.create_session(
                user_id=auth_context.user_id,
                token_info=auth_context.token_info,
                scopes=auth_context.scopes,
                authorization_server=token_info.get("issuer"),
                metadata={
                    "created_via": "bearer_token",
                    "token_type": token_info.get("token_type"),
                },
            )

            logger.info(f"Created session {session_id} for user {auth_context.user_id}")
            return session_id

        except Exception as e:
            logger.error(f"Failed to create session: {e}")
            return None

    def _extract_session_id(self, request: Request) -> Optional[str]:
        """Extract the session ID from the request, or None if absent.

        Sources are checked in order: the ``mcp-session-id`` header (primary),
        the ``x-session-id`` header, then the ``session_id`` query parameter.
        Empty values are treated as absent.
        """
        headers = request.headers
        # Starlette header lookups are case-insensitive, so one spelling per
        # header is enough.
        return (
            headers.get("mcp-session-id")
            or headers.get("x-session-id")
            or request.query_params.get("session_id")
            or None
        )

    def _is_exempt_path(self, path: str) -> bool:
        """Return True if ``path`` starts with any configured exempt prefix."""
        return any(path.startswith(prefix) for prefix in self.exempt_paths)

    def _requires_authentication(self, request: Request) -> bool:
        """Check if request requires authentication."""
        # For now, all non-exempt paths require authentication.
        # This could be extended with more sophisticated rules.
        return True

    def _create_auth_error_response(self, auth_context: AuthContext) -> JSONResponse:
        """Build the JSON error response (with WWW-Authenticate) for a failed request."""
        # insufficient_scope -> 403 Forbidden; every other failure -> 401.
        status_code = 403 if auth_context.error == "insufficient_scope" else 401

        error_data = {
            "error": auth_context.error or "unauthorized",
            "error_description": auth_context.error_description or "Authentication required",
        }

        www_auth_header = self.http_auth.build_www_authenticate_header(
            realm="mcp-server",
            error=auth_context.error,
            error_description=auth_context.error_description,
            scope=" ".join(self.required_scopes) if self.required_scopes else None,
        )

        headers = {
            "WWW-Authenticate": www_auth_header,
            "Cache-Control": "no-store",
            "Pragma": "no-cache",
        }

        return JSONResponse(
            status_code=status_code,
            content=error_data,
            headers=headers,
        )

    def attach_session_to_request(self, request: Request, session: Session) -> None:
        """
        Attach session context to request.

        Args:
            request: HTTP request
            session: Session to attach
        """
        request.state.auth = self._create_session_auth_context(session)

    async def close(self):
        """Release resources held by the token validator and discovery service.

        The discovery service is closed even if closing the validator raises.
        """
        try:
            await self.token_validator.close()
        finally:
            await self.discovery.close()
+
361
+
362
def get_auth_context(request: Request) -> AuthContext:
    """
    Return the authentication context stored on ``request.state.auth``.

    Args:
        request: HTTP request

    Returns:
        The stored context, or a fresh unauthenticated ``AuthContext`` when
        none has been attached.
    """
    state = request.state
    try:
        return state.auth
    except AttributeError:
        # No middleware has attached a context to this request.
        return AuthContext()
373
+
374
+
375
def require_auth(request: Request) -> AuthContext:
    """
    Return the request's authentication context, rejecting anonymous callers.

    Args:
        request: HTTP request

    Returns:
        The authenticated context.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            context is not authenticated.
    """
    ctx = get_auth_context(request)
    if ctx.authenticated:
        return ctx

    raise HTTPException(
        status_code=401,
        detail="Authentication required",
        headers={"WWW-Authenticate": "Bearer"},
    )
398
+
399
+
400
def require_scopes(request: Request, required_scopes: List[str]) -> AuthContext:
    """
    Require specific scopes and return context.

    Args:
        request: HTTP request
        required_scopes: Required OAuth scopes

    Returns:
        Authentication context

    Raises:
        HTTPException: 401 if not authenticated (via ``require_auth``); 403 with
            an ``insufficient_scope`` WWW-Authenticate header if any required
            scope is missing.
    """
    auth_context = require_auth(request)

    granted = set(auth_context.scopes or [])
    # Keep the caller's order (deduplicated) so the error message is
    # deterministic; iterating a set difference would not be.
    missing_scopes = [scope for scope in dict.fromkeys(required_scopes) if scope not in granted]
    if missing_scopes:
        raise HTTPException(
            status_code=403,
            detail=f"Insufficient scope. Missing: {', '.join(missing_scopes)}",
            headers={
                "WWW-Authenticate": f'Bearer scope="{" ".join(required_scopes)}", error="insufficient_scope"'
            },
        )

    return auth_context