workspace-mcp 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,440 @@
1
+ """
2
+ OAuth 2.1 Handler
3
+
4
+ Main OAuth 2.1 authentication handler that integrates all components.
5
+ Provides a unified interface for OAuth 2.1 functionality.
6
+ """
7
+
8
+ import logging
9
+ from typing import Dict, Any, Optional, List, Tuple
10
+
11
+ from .config import OAuth2Config
12
+ from .discovery import AuthorizationServerDiscovery
13
+ from .oauth2 import OAuth2AuthorizationFlow
14
+ from .tokens import TokenValidator, TokenValidationError
15
+ from .sessions import SessionStore, Session
16
+ from .middleware import AuthenticationMiddleware, AuthContext
17
+ from .http import HTTPAuthHandler
18
+
19
logger = logging.getLogger(__name__)


class OAuth2Handler:
    """
    Main OAuth 2.1 authentication handler.

    Builds and owns the discovery service, the authorization-code flow
    handler, the token validator, the session store and the HTTP header
    helper, and exposes a single facade over them. Use it as an async
    context manager (or call ``start()`` / ``stop()``) so the session-store
    cleanup task is started, and the flow handler and token validator are
    closed on shutdown.
    """

    def __init__(self, config: OAuth2Config):
        """
        Initialize OAuth 2.1 handler.

        Args:
            config: OAuth 2.1 configuration
        """
        self.config = config

        # Initialize components
        self.discovery = AuthorizationServerDiscovery(
            resource_url=config.resource_url,
            cache_ttl=config.discovery_cache_ttl,
        )

        self.flow_handler = OAuth2AuthorizationFlow(
            client_id=config.client_id,
            client_secret=config.client_secret,
            discovery_service=self.discovery,
        )

        self.token_validator = TokenValidator(
            discovery_service=self.discovery,
            cache_ttl=config.jwks_cache_ttl,
        )

        # Resolve the persistence path once rather than calling the getter twice.
        persistence_path = config.get_session_persistence_path()
        self.session_store = SessionStore(
            default_session_timeout=config.session_timeout,
            max_sessions_per_user=config.max_sessions_per_user,
            cleanup_interval=config.session_cleanup_interval,
            enable_persistence=config.enable_session_persistence,
            persistence_file=str(persistence_path) if persistence_path else None,
        )

        self.http_auth = HTTPAuthHandler()

        # Setup debug logging if enabled
        if config.enable_debug_logging:
            logging.getLogger("auth.oauth21").setLevel(logging.DEBUG)

    async def start(self):
        """Start the OAuth 2.1 handler and the session-store cleanup task."""
        await self.session_store.start_cleanup_task()
        logger.info("OAuth 2.1 handler started")

    async def stop(self):
        """
        Stop the OAuth 2.1 handler and clean up resources.

        Each shutdown step runs even if an earlier step raised, so a failing
        cleanup task cannot leak the flow handler's or token validator's
        resources. An exception from any step is still propagated after the
        remaining steps have run.
        """
        try:
            await self.session_store.stop_cleanup_task()
        finally:
            try:
                await self.flow_handler.close()
            finally:
                await self.token_validator.close()
        logger.info("OAuth 2.1 handler stopped")

    async def create_authorization_url(
        self,
        redirect_uri: str,
        scopes: List[str],
        state: Optional[str] = None,
        session_id: Optional[str] = None,
        additional_params: Optional[Dict[str, str]] = None,
    ) -> Tuple[str, str, str]:
        """
        Create OAuth 2.1 authorization URL.

        Args:
            redirect_uri: OAuth redirect URI
            scopes: Requested scopes
            state: State parameter (generated if not provided)
            session_id: Optional session ID to associate with the state.
                NOTE: the association is currently not persisted (see
                ``_store_authorization_state``).
            additional_params: Additional authorization parameters

        Returns:
            Tuple of (authorization_url, state, code_verifier). The caller must
            keep ``code_verifier`` for the later code exchange.

        Raises:
            ValueError: If no authorization server URL is configured
        """
        if not self.config.authorization_server_url:
            raise ValueError("Authorization server URL not configured")

        # Build authorization URL
        auth_url, final_state, code_verifier = await self.flow_handler.build_authorization_url(
            authorization_server_url=self.config.authorization_server_url,
            redirect_uri=redirect_uri,
            scopes=scopes,
            state=state,
            resource=self.config.resource_url,
            additional_params=additional_params,
        )

        # Store session association if provided
        if session_id:
            self._store_authorization_state(final_state, session_id, code_verifier)

        logger.info("Created authorization URL for scopes: %s", scopes)
        return auth_url, final_state, code_verifier

    async def exchange_code_for_session(
        self,
        authorization_code: str,
        code_verifier: str,
        redirect_uri: str,
        state: Optional[str] = None,
    ) -> Tuple[str, Session]:
        """
        Exchange an authorization code for tokens and create a session.

        Args:
            authorization_code: Authorization code from callback
            code_verifier: PKCE code verifier
            redirect_uri: OAuth redirect URI
            state: State parameter from authorization (currently unused here)

        Returns:
            Tuple of (session_id, session)

        Raises:
            ValueError: If code exchange fails or the token response carries
                no access token
            TokenValidationError: If token validation fails
        """
        # Exchange code for tokens
        token_response = await self.flow_handler.exchange_code_for_token(
            authorization_server_url=self.config.authorization_server_url,
            authorization_code=authorization_code,
            code_verifier=code_verifier,
            redirect_uri=redirect_uri,
            resource=self.config.resource_url,
        )

        # A response without an access token is a failed exchange; surface it
        # as the documented ValueError rather than a bare KeyError.
        access_token = token_response.get("access_token")
        if not access_token:
            raise ValueError("Token response did not include an access_token")

        # Validate the received token
        token_info = await self.token_validator.validate_token(
            token=access_token,
            expected_audience=self.config.expected_audience,
            required_scopes=self.config.required_scopes,
            authorization_server_url=self.config.authorization_server_url,
        )

        if not token_info["valid"]:
            raise TokenValidationError("Received token is invalid")

        # Extract user identity
        user_id = token_info["user_identity"]

        # Create session; the raw token response is kept so a refresh token
        # (if any) is available to refresh_session_token later.
        session_id = self.session_store.create_session(
            user_id=user_id,
            token_info={
                **token_response,
                "validation_info": token_info,
                "claims": token_info.get("claims", {}),
            },
            scopes=token_info.get("scopes", []),
            authorization_server=self.config.authorization_server_url,
            client_id=self.config.client_id,
            metadata={
                "auth_method": "oauth2_authorization_code",
                "token_type": token_info.get("token_type"),
                "created_via": "code_exchange",
            },
        )

        session = self.session_store.get_session(session_id)
        # NOTE(review): session IDs are logged at INFO; consider redacting if
        # they are used as bearer credentials.
        logger.info("Created session %s for user %s", session_id, user_id)

        return session_id, session

    async def authenticate_bearer_token(
        self,
        token: str,
        required_scopes: Optional[List[str]] = None,
        create_session: bool = True,
    ) -> AuthContext:
        """
        Authenticate Bearer token and optionally create session.

        Validation failures reported via ``TokenValidationError`` are caught
        and recorded on the returned context (``error`` /
        ``error_description``) instead of propagating.

        Args:
            token: Bearer token to authenticate
            required_scopes: Required scopes. Falls back to the configured
                default when omitted; note that an empty list also falls back.
            create_session: Whether to create a session for valid tokens

        Returns:
            Authentication context
        """
        auth_context = AuthContext()

        try:
            # Validate token
            scopes_to_check = required_scopes or self.config.required_scopes
            token_info = await self.token_validator.validate_token(
                token=token,
                expected_audience=self.config.expected_audience,
                required_scopes=scopes_to_check,
                authorization_server_url=self.config.authorization_server_url,
            )

            if token_info["valid"]:
                auth_context.authenticated = True
                auth_context.user_id = token_info["user_identity"]
                auth_context.token_info = token_info
                auth_context.scopes = token_info.get("scopes", [])

                # Create session if requested
                if create_session:
                    session_id = self.session_store.create_session(
                        user_id=auth_context.user_id,
                        token_info=token_info,
                        scopes=auth_context.scopes,
                        authorization_server=self.config.authorization_server_url,
                        client_id=self.config.client_id,
                        metadata={
                            "auth_method": "bearer_token",
                            "created_via": "token_passthrough",
                        },
                    )

                    auth_context.session_id = session_id
                    auth_context.session = self.session_store.get_session(session_id)

                logger.debug("Authenticated Bearer token for user %s", auth_context.user_id)
            else:
                auth_context.error = "invalid_token"
                auth_context.error_description = "Token validation failed"

        except TokenValidationError as e:
            auth_context.error = e.error_code
            auth_context.error_description = str(e)
            logger.warning("Bearer token validation failed: %s", e)

        return auth_context

    async def refresh_session_token(self, session_id: str) -> bool:
        """
        Refresh tokens for a session.

        The new token response is merged over the stored token info, so a
        refresh token that the server does not rotate is kept.

        Args:
            session_id: Session identifier

        Returns:
            True if the session store accepted the update

        Raises:
            ValueError: If the session is not found, has no refresh token, or
                the refresh fails (the original error is chained as __cause__)
        """
        session = self.session_store.get_session(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        refresh_token = session.token_info.get("refresh_token")
        if not refresh_token:
            raise ValueError("Session has no refresh token")

        try:
            # Refresh the token
            token_response = await self.flow_handler.refresh_access_token(
                authorization_server_url=self.config.authorization_server_url,
                refresh_token=refresh_token,
                scopes=session.scopes,
                resource=self.config.resource_url,
            )

            # Update session with new tokens
            updated_token_info = {**session.token_info, **token_response}
            success = self.session_store.update_session(
                session_id=session_id,
                token_info=updated_token_info,
                extend_expiration=True,
            )

            if success:
                logger.info("Refreshed tokens for session %s", session_id)

            return success

        except Exception as e:
            logger.error("Failed to refresh token for session %s: %s", session_id, e)
            # Chain the cause so the underlying network/protocol error is not lost.
            raise ValueError(f"Token refresh failed: {e}") from e

    def create_middleware(self) -> AuthenticationMiddleware:
        """
        Create authentication middleware wired to this handler's components.

        Returns:
            Configured authentication middleware (``app`` is left as None and
            is expected to be set when the middleware is added to an app)
        """
        return AuthenticationMiddleware(
            app=None,  # Will be set when middleware is added to app
            session_store=self.session_store,
            token_validator=self.token_validator,
            discovery_service=self.discovery,
            http_auth_handler=self.http_auth,
            required_scopes=self.config.required_scopes,
            exempt_paths=self.config.exempt_paths,
            authorization_server_url=self.config.authorization_server_url,
            expected_audience=self.config.expected_audience,
            enable_bearer_passthrough=self.config.enable_bearer_passthrough,
        )

    def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get session information.

        Args:
            session_id: Session identifier

        Returns:
            Session information dictionary (timestamps as ISO-8601 strings),
            or None if the session does not exist
        """
        session = self.session_store.get_session(session_id)
        if not session:
            return None

        return {
            "session_id": session.session_id,
            "user_id": session.user_id,
            "scopes": session.scopes,
            "created_at": session.created_at.isoformat(),
            "last_accessed": session.last_accessed.isoformat(),
            "expires_at": session.expires_at.isoformat() if session.expires_at else None,
            "authorization_server": session.authorization_server,
            "metadata": session.metadata,
            # Same truthiness rule refresh_session_token uses, so an empty or
            # null refresh_token is not reported as refreshable.
            "has_refresh_token": bool(session.token_info.get("refresh_token")),
        }

    def get_user_sessions(self, user_id: str) -> List[Dict[str, Any]]:
        """
        Get all sessions for a user.

        Args:
            user_id: User identifier

        Returns:
            List of session information dictionaries
        """
        sessions = self.session_store.get_user_sessions(user_id)
        return [
            {
                "session_id": session.session_id,
                "created_at": session.created_at.isoformat(),
                "last_accessed": session.last_accessed.isoformat(),
                "expires_at": session.expires_at.isoformat() if session.expires_at else None,
                "scopes": session.scopes,
                "metadata": session.metadata,
            }
            for session in sessions
        ]

    def revoke_session(self, session_id: str) -> bool:
        """
        Revoke a session.

        Args:
            session_id: Session identifier

        Returns:
            True if session was revoked
        """
        success = self.session_store.remove_session(session_id)
        if success:
            logger.info("Revoked session %s", session_id)
        return success

    def revoke_user_sessions(self, user_id: str) -> int:
        """
        Revoke all sessions for a user.

        Args:
            user_id: User identifier

        Returns:
            Number of sessions revoked
        """
        count = self.session_store.remove_user_sessions(user_id)
        logger.info("Revoked %s sessions for user %s", count, user_id)
        return count

    def get_handler_stats(self) -> Dict[str, Any]:
        """
        Get OAuth 2.1 handler statistics.

        Returns:
            Dictionary with selected config values, session-store stats and
            the sizes of the discovery, validation and JWKS caches
        """
        session_stats = self.session_store.get_session_stats()

        return {
            "config": {
                "authorization_server": self.config.authorization_server_url,
                "client_id": self.config.client_id,
                "session_timeout": self.config.session_timeout,
                "bearer_passthrough": self.config.enable_bearer_passthrough,
            },
            "sessions": session_stats,
            "components": {
                "discovery_cache_size": len(self.discovery.cache),
                "token_validation_cache_size": len(self.token_validator.validation_cache),
                "jwks_cache_size": len(self.token_validator.jwks_cache),
            },
        }

    def _store_authorization_state(self, state: str, session_id: str, code_verifier: str):
        """
        Placeholder for associating an authorization ``state`` with a session.

        NOTE: not implemented. ``session_id`` and ``code_verifier`` are
        discarded, so callers must keep the code verifier themselves (it is
        returned by ``create_authorization_url``). A debug message records
        the dropped association so it is not entirely silent.
        """
        logger.debug(
            "Authorization state association for session %s is not persisted",
            session_id,
        )

    async def __aenter__(self):
        """Async context manager entry: start the handler and return it."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: stop the handler (exceptions are not suppressed)."""
        await self.stop()
auth/oauth21/http.py ADDED
@@ -0,0 +1,270 @@
1
+ """
2
+ HTTP Authentication Handler
3
+
4
+ Handles HTTP authentication headers and responses per RFC6750 (Bearer Token Usage)
5
+ and OAuth 2.1 specifications.
6
+ """
7
+
8
+ import logging
9
+ import re
10
+ from typing import Optional, Dict, Any
11
+
12
logger = logging.getLogger(__name__)


class HTTPAuthHandler:
    """
    Handles HTTP authentication headers and responses.

    Parses ``Authorization: Bearer`` request headers and builds
    ``WWW-Authenticate`` challenge headers in the RFC 6750 format.
    """

    # RFC 6750 §2.1: "Authorization: Bearer <token>". The scheme is matched
    # case-insensitively and the token must be one run of non-whitespace
    # characters. Compiled once instead of on every call.
    _BEARER_PATTERN = re.compile(r'^Bearer\s+([^\s]+)$', re.IGNORECASE)

    # Characters rejected by validate_token_format: C0 controls (0x00-0x1F),
    # space (0x20), DEL (0x7F), double quote and backslash.
    _INVALID_TOKEN_CHARS = frozenset(
        [chr(code) for code in range(0x21)] + ['\x7f', '"', '\\']
    )

    # str.translate table mapping control characters (including CR/LF) to a
    # space. Header values must never carry raw CR/LF, otherwise an
    # attacker-influenced error_description could inject extra headers.
    _HEADER_CONTROL_CHARS = {code: " " for code in [*range(0x20), 0x7f]}

    def __init__(self, resource_metadata_url: Optional[str] = None):
        """
        Initialize the HTTP authentication handler.

        Args:
            resource_metadata_url: URL advertised by
                ``build_resource_metadata_header``; defaults to
                "/.well-known/oauth-authorization-server"
        """
        self.resource_metadata_url = resource_metadata_url or "/.well-known/oauth-authorization-server"

    def parse_authorization_header(self, header: str) -> Optional[str]:
        """
        Extract Bearer token from Authorization header per RFC6750.

        Args:
            header: Authorization header value

        Returns:
            Bearer token string or None if not found/invalid

        Examples:
            >>> handler = HTTPAuthHandler()
            >>> handler.parse_authorization_header("Bearer abc123")
            'abc123'
            >>> handler.parse_authorization_header("Basic abc123") is None
            True
        """
        if not header:
            return None

        match = self._BEARER_PATTERN.match(header.strip())
        if match is None:
            # Do not log any part of the header: for other schemes (e.g.
            # Basic) it is credential material.
            logger.debug("Authorization header does not contain a valid Bearer token")
            return None

        # The pattern's "+" quantifier guarantees the captured token is non-empty.
        logger.debug("Successfully extracted Bearer token from Authorization header")
        return match.group(1)

    def build_www_authenticate_header(
        self,
        realm: Optional[str] = None,
        scope: Optional[str] = None,
        error: Optional[str] = None,
        error_description: Optional[str] = None,
        error_uri: Optional[str] = None,
    ) -> str:
        """
        Build WWW-Authenticate header for 401 responses per RFC6750.

        Parameters with falsy values are omitted. When several are present
        they are emitted as a comma-separated auth-param list, as the HTTP
        challenge syntax requires.

        Args:
            realm: Authentication realm
            scope: Required scope(s)
            error: Error code (invalid_request, invalid_token, insufficient_scope)
            error_description: Human-readable error description
            error_uri: URI with error information

        Returns:
            WWW-Authenticate header value

        Examples:
            >>> handler = HTTPAuthHandler()
            >>> handler.build_www_authenticate_header(realm="api")
            'Bearer realm="api"'
            >>> handler.build_www_authenticate_header(error="invalid_token")
            'Bearer error="invalid_token"'
            >>> handler.build_www_authenticate_header(realm="api", error="invalid_token")
            'Bearer realm="api", error="invalid_token"'
        """
        candidates = (
            ("realm", realm),
            ("scope", scope),
            ("error", error),
            ("error_description", error_description),
            ("error_uri", error_uri),
        )
        params = [
            f'{name}="{self._quote_attribute_value(value)}"'
            for name, value in candidates
            if value
        ]

        if not params:
            return "Bearer"

        # RFC 6750 §3 / RFC 9110 §11.6.1: auth-params are comma-separated.
        return "Bearer " + ", ".join(params)

    def build_resource_metadata_header(self) -> str:
        """
        Build WWW-Authenticate header with a metadata URL for discovery.

        NOTE(review): the parameter name ``AS_metadata_url`` is kept for
        compatibility; RFC 9728 defines ``resource_metadata`` for this
        purpose. Confirm what clients expect before changing it.

        Returns:
            WWW-Authenticate header with AS_metadata_url parameter
        """
        url = self._quote_attribute_value(self.resource_metadata_url)
        return f'Bearer AS_metadata_url="{url}"'

    def _quote_attribute_value(self, value: str) -> str:
        """
        Escape a value for use inside a quoted-string in an HTTP header.

        Control characters (including CR and LF) are replaced with spaces to
        prevent header injection. Backslashes and double quotes are then
        backslash-escaped. The surrounding quotes are NOT added here.

        Args:
            value: Attribute value to escape

        Returns:
            Escaped value, safe to place between double quotes
        """
        sanitized = value.translate(self._HEADER_CONTROL_CHARS)
        return sanitized.replace('\\', '\\\\').replace('"', '\\"')

    def extract_bearer_token_from_request(self, headers: Dict[str, str]) -> Optional[str]:
        """
        Extract Bearer token from HTTP request headers.

        Args:
            headers: HTTP request headers; the Authorization key is matched
                case-insensitively, so a plain dict works too

        Returns:
            Bearer token, or None if headers are empty/None or carry no
            valid Bearer Authorization header
        """
        if not headers:
            return None

        # First header whose name is "authorization" in any letter case.
        authorization = next(
            (value for key, value in headers.items() if key.lower() == "authorization"),
            None,
        )

        return self.parse_authorization_header(authorization) if authorization else None

    def is_bearer_token_request(self, headers: Dict[str, str]) -> bool:
        """
        Check if request contains Bearer token authentication.

        Args:
            headers: HTTP request headers

        Returns:
            True if request has Bearer token
        """
        return self.extract_bearer_token_from_request(headers) is not None

    def build_error_response_headers(
        self,
        error: str,
        error_description: Optional[str] = None,
        realm: Optional[str] = None,
        scope: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Build complete error response headers for 401/403 responses.

        Args:
            error: OAuth error code
            error_description: Human-readable error description
            realm: Authentication realm
            scope: Required scope

        Returns:
            Dictionary with WWW-Authenticate plus no-cache headers
        """
        return {
            "WWW-Authenticate": self.build_www_authenticate_header(
                realm=realm,
                scope=scope,
                error=error,
                error_description=error_description,
            ),
            "Cache-Control": "no-store",
            "Pragma": "no-cache",
        }

    def validate_token_format(self, token: str) -> bool:
        """
        Validate Bearer token format.

        Rejects empty tokens, tokens containing control characters, space,
        DEL, double quote or backslash, and non-ASCII tokens.

        Args:
            token: Bearer token to validate

        Returns:
            True if token format is valid
        """
        if not token:
            return False

        if not self._INVALID_TOKEN_CHARS.isdisjoint(token):
            logger.warning("Bearer token contains invalid characters")
            return False

        # Token should be ASCII
        try:
            token.encode('ascii')
        except UnicodeEncodeError:
            logger.warning("Bearer token contains non-ASCII characters")
            return False

        return True

    def get_token_info_from_headers(self, headers: Dict[str, str]) -> Dict[str, Any]:
        """
        Extract and validate token information from request headers.

        Args:
            headers: HTTP request headers

        Returns:
            Dict with keys ``has_bearer_token``, ``token``, ``valid_format``
            and ``error``. ``error`` is "invalid_token" for a badly formatted
            token and "missing_token" when no Bearer token was found
            (NOTE: "missing_token" is not an RFC 6750 error code; callers
            presumably map it).
        """
        result = {
            "has_bearer_token": False,
            "token": None,
            "valid_format": False,
            "error": None,
        }

        token = self.extract_bearer_token_from_request(headers)

        if not token:
            result["error"] = "missing_token"
            return result

        result["has_bearer_token"] = True
        result["token"] = token
        result["valid_format"] = self.validate_token_format(token)
        if not result["valid_format"]:
            result["error"] = "invalid_token"

        return result