mcp-proxy-oauth-dcr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,908 @@
1
+ """OAuth DCR Authentication Manager implementation.
2
+
3
+ This module implements RFC 7591 Dynamic Client Registration and manages
4
+ the OAuth client credentials lifecycle and token management.
5
+ """
6
+
7
+ import asyncio
8
+ import logging
9
+ from datetime import datetime, timedelta
10
+ from typing import Optional, Dict, Any, List
11
+ from urllib.parse import urljoin
12
+
13
+ import aiohttp
14
+
15
+ from ..exceptions import (
16
+ DcrError,
17
+ TokenError,
18
+ TokenRefreshError,
19
+ InvalidCredentialsError,
20
+ )
21
+ from ..models import (
22
+ AuthenticationState,
23
+ ClientCredentials,
24
+ OAuthTokenResponse,
25
+ ProxyConfig,
26
+ )
27
+ from ..logging_config import get_logger
28
+
29
# Module-level logger from the project's logging config; every call in this
# module passes structured keyword fields, e.g. logger.info("...", key=value).
logger = get_logger(__name__)
30
+
31
+
32
class AuthenticationManagerImpl:
    """Implementation of OAuth DCR and token management.

    This class handles:
    - OAuth Dynamic Client Registration (RFC 7591)
    - Client credentials storage and lifecycle management
    - Access token acquisition and refresh
    - Token validation and expiration handling
    - Failure tracking and diagnostic logging

    Public coroutines serialize access to the authentication state with an
    ``asyncio.Lock``. ``_perform_dcr_internal`` and ``_obtain_token`` do not
    take the lock themselves; their callers hold it.
    """

    # Thresholds for diagnostic warnings
    MAX_DCR_RETRIES_WARNING = 3
    MAX_TOKEN_FAILURES_WARNING = 5
    FAILURE_WINDOW_SECONDS = 300  # 5 minutes
    # Minimum spacing between two diagnostic dumps of the same category.
    DIAGNOSTIC_LOG_INTERVAL_SECONDS = 60

    # Loopback redirect registered via DCR and used by the authorization code
    # flow; the local callback server listens on the matching host/port.
    REDIRECT_URI = "http://localhost:8080/callback"
    CALLBACK_HOST = "localhost"
    CALLBACK_PORT = 8080
    # How long to wait for the browser to hit the redirect URI.
    AUTHORIZATION_TIMEOUT_SECONDS = 300

    # Back-off used by the background refresh loop after failed attempts so an
    # expired token does not cause a retry every 0.1 seconds.
    REFRESH_BACKOFF_BASE_SECONDS = 5.0
    REFRESH_BACKOFF_MAX_SECONDS = 300.0

    def __init__(self, config: ProxyConfig):
        """Initialize the authentication manager.

        Args:
            config: Proxy configuration containing OAuth provider details
        """
        self.config = config
        self._state = AuthenticationState()
        self._session: Optional[aiohttp.ClientSession] = None
        self._lock = asyncio.Lock()
        self._refresh_task: Optional[asyncio.Task] = None
        self._shutdown_event = asyncio.Event()

        # OAuth endpoints. All are resolved the same way: the leading "/" makes
        # urljoin replace any path on the provider URL with the endpoint path.
        provider_url = str(config.oauth_provider_url)
        self._dcr_endpoint = urljoin(provider_url, "/register")
        self._token_endpoint = urljoin(provider_url, "/token")
        self._authorize_endpoint = urljoin(provider_url, "/authorize")

        # Failure tracking for diagnostics (Requirement 6.4)
        self._dcr_failures: List[Dict[str, Any]] = []
        self._token_failures: List[Dict[str, Any]] = []
        # Time of the most recent diagnostic dump of any kind, plus
        # per-category timestamps used for throttling so one category cannot
        # suppress the other.
        self._last_diagnostic_log: Optional[datetime] = None
        self._last_dcr_diagnostic_log: Optional[datetime] = None
        self._last_token_diagnostic_log: Optional[datetime] = None

        logger.info(
            "Authentication manager initialized",
            oauth_provider=provider_url,
            dcr_endpoint=self._dcr_endpoint,
            token_endpoint=self._token_endpoint,
            client_name=config.client_name,
            scopes=config.scopes,
        )

    async def initialize(self) -> None:
        """Initialize authentication manager and perform DCR if needed.

        This method:
        1. Creates an HTTP session for OAuth requests
        2. Performs Dynamic Client Registration if no credentials exist
        3. Obtains an initial access token
        4. Starts automatic token refresh background task

        Raises:
            DcrError: If client registration fails
            TokenError: If initial token acquisition fails
        """
        async with self._lock:
            if self._session is None:
                self._session = aiohttp.ClientSession(
                    timeout=aiohttp.ClientTimeout(total=self.config.connection_timeout)
                )

            # Perform DCR if we don't have credentials
            if self._state.client_credentials is None:
                logger.info("No client credentials found, performing DCR")
                await self._perform_dcr_internal()

            # Get initial access token
            if not self._state.is_token_valid():
                logger.info("Obtaining initial access token")
                await self._obtain_token()

            # Start automatic token refresh background task
            if self._refresh_task is None or self._refresh_task.done():
                self._shutdown_event.clear()
                self._refresh_task = asyncio.create_task(
                    self._auto_refresh_loop(),
                    name="token-auto-refresh",
                )
                logger.info("Started automatic token refresh background task")

    async def get_access_token(self) -> str:
        """Get a valid access token, refreshing if necessary.

        This method ensures a valid token is always returned by:
        1. Checking if current token is valid
        2. Refreshing token if expired or near expiration
        3. Re-performing DCR if refresh fails

        Returns:
            Valid access token

        Raises:
            TokenError: If unable to obtain a valid token
            DcrError: If the recovery re-registration fails
        """
        async with self._lock:
            if self._state.is_token_valid():
                return self._state.access_token  # type: ignore

            logger.info("Access token expired or invalid, refreshing")
            try:
                await self._obtain_token()
                return self._state.access_token  # type: ignore
            except TokenError as e:
                # If token refresh fails, try re-performing DCR
                logger.warning(f"Token refresh failed: {e}, attempting DCR")
                await self._perform_dcr_internal()
                await self._obtain_token()
                return self._state.access_token  # type: ignore

    async def refresh_token(self) -> str:
        """Obtain a new access token, regardless of the current token's validity.

        Returns:
            New access token

        Raises:
            TokenError: If token acquisition fails
            InvalidCredentialsError: If no client credentials are registered
        """
        async with self._lock:
            await self._obtain_token()
            return self._state.access_token  # type: ignore

    async def perform_dcr(self) -> ClientCredentials:
        """Perform OAuth Dynamic Client Registration.

        Implements RFC 7591 Dynamic Client Registration Protocol.
        Sends a registration request to the OAuth provider and stores
        the returned client credentials.

        Returns:
            Client credentials from successful registration

        Raises:
            DcrError: If registration fails
        """
        async with self._lock:
            return await self._perform_dcr_internal()

    def is_token_valid(self) -> bool:
        """Check if current token is valid.

        Returns:
            True if token exists and is not expired
        """
        return self._state.is_token_valid()

    def get_state(self) -> AuthenticationState:
        """Get a deep copy of the current authentication state.

        Returns:
            Current authentication state (mutating it does not affect the manager)
        """
        return self._state.model_copy(deep=True)

    async def close(self) -> None:
        """Close the authentication manager and cleanup resources."""
        # Signal shutdown to background task
        self._shutdown_event.set()

        # Cancel and wait for refresh task
        if self._refresh_task and not self._refresh_task.done():
            self._refresh_task.cancel()
            try:
                await self._refresh_task
            except asyncio.CancelledError:
                logger.debug("Token refresh task cancelled")

        # Close HTTP session
        if self._session:
            await self._session.close()
            self._session = None

        logger.info(
            "Authentication manager closed",
            total_dcr_failures=len(self._dcr_failures),
            total_token_failures=len(self._token_failures),
        )

    # ========================================================================
    # Diagnostic Methods (Requirement 6.3, 6.4)
    # ========================================================================

    def _prune_failures(self, failures: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return only the failure records newer than FAILURE_WINDOW_SECONDS."""
        cutoff = datetime.now() - timedelta(seconds=self.FAILURE_WINDOW_SECONDS)
        return [f for f in failures if f["timestamp"] > cutoff]

    def _diagnostic_due(self, last_logged: Optional[datetime]) -> bool:
        """Return True when a diagnostic dump last emitted at *last_logged* may be repeated."""
        if last_logged is None:
            return True
        elapsed = (datetime.now() - last_logged).total_seconds()
        return elapsed >= self.DIAGNOSTIC_LOG_INTERVAL_SECONDS

    @staticmethod
    def _count_error_types(failures: List[Dict[str, Any]]) -> Dict[str, int]:
        """Count failure records per ``error_type``."""
        counts: Dict[str, int] = {}
        for failure in failures:
            error_type = failure["error_type"]
            counts[error_type] = counts.get(error_type, 0) + 1
        return counts

    def _record_dcr_failure(self, error: Exception, details: Optional[Dict[str, Any]] = None) -> None:
        """Record a DCR failure for diagnostic tracking.

        Args:
            error: The exception that occurred
            details: Additional failure details
        """
        self._dcr_failures.append({
            "timestamp": datetime.now(),
            "error_type": type(error).__name__,
            "error_message": str(error),
            "details": details or {},
            "retry_count": self._state.dcr_retry_count,
        })
        self._dcr_failures = self._prune_failures(self._dcr_failures)

        if len(self._dcr_failures) >= self.MAX_DCR_RETRIES_WARNING:
            self._log_dcr_diagnostics()

    def _record_token_failure(self, error: Exception, details: Optional[Dict[str, Any]] = None) -> None:
        """Record a token acquisition failure for diagnostic tracking.

        Args:
            error: The exception that occurred
            details: Additional failure details
        """
        self._token_failures.append({
            "timestamp": datetime.now(),
            "error_type": type(error).__name__,
            "error_message": str(error),
            "details": details or {},
        })
        self._token_failures = self._prune_failures(self._token_failures)

        if len(self._token_failures) >= self.MAX_TOKEN_FAILURES_WARNING:
            self._log_token_diagnostics()

    def _log_dcr_diagnostics(self) -> None:
        """Log comprehensive diagnostic information for repeated DCR failures.

        Implements Requirement 6.4: Provide clear diagnostic information for
        repeated failures. Throttled to once per DIAGNOSTIC_LOG_INTERVAL_SECONDS,
        independently of token diagnostics.
        """
        if not self._diagnostic_due(self._last_dcr_diagnostic_log):
            return

        now = datetime.now()
        self._last_dcr_diagnostic_log = now
        self._last_diagnostic_log = now

        error_types = self._count_error_types(self._dcr_failures)
        recent_failure = self._dcr_failures[-1] if self._dcr_failures else None

        logger.error(
            "DIAGNOSTIC: Repeated DCR failures detected",
            failure_count=len(self._dcr_failures),
            failure_window_seconds=self.FAILURE_WINDOW_SECONDS,
            retry_count=self._state.dcr_retry_count,
            error_types=error_types,
            most_recent_error=recent_failure["error_message"] if recent_failure else None,
            most_recent_details=recent_failure["details"] if recent_failure else None,
            dcr_endpoint=self._dcr_endpoint,
            oauth_provider=str(self.config.oauth_provider_url),
            client_name=self.config.client_name,
            scopes=self.config.scopes,
            last_dcr_attempt=self._state.last_dcr_attempt.isoformat() if self._state.last_dcr_attempt else None,
            has_credentials=self._state.client_credentials is not None,
            diagnostic_suggestions=[
                "Check OAuth provider availability and endpoint URLs",
                "Verify client_name and scopes are valid for the OAuth provider",
                "Check network connectivity to OAuth provider",
                "Review OAuth provider logs for registration errors",
                "Verify OAuth provider supports RFC 7591 Dynamic Client Registration",
            ],
        )

    def _log_token_diagnostics(self) -> None:
        """Log comprehensive diagnostic information for repeated token failures.

        Implements Requirement 6.4: Provide clear diagnostic information for
        repeated failures. Throttled to once per DIAGNOSTIC_LOG_INTERVAL_SECONDS,
        independently of DCR diagnostics.
        """
        if not self._diagnostic_due(self._last_token_diagnostic_log):
            return

        now = datetime.now()
        self._last_token_diagnostic_log = now
        self._last_diagnostic_log = now

        error_types = self._count_error_types(self._token_failures)
        recent_failure = self._token_failures[-1] if self._token_failures else None

        credentials = self._state.client_credentials
        credentials_expired = credentials.is_expired() if credentials else None

        logger.error(
            "DIAGNOSTIC: Repeated token acquisition failures detected",
            failure_count=len(self._token_failures),
            failure_window_seconds=self.FAILURE_WINDOW_SECONDS,
            error_types=error_types,
            most_recent_error=recent_failure["error_message"] if recent_failure else None,
            most_recent_details=recent_failure["details"] if recent_failure else None,
            token_endpoint=self._token_endpoint,
            oauth_provider=str(self.config.oauth_provider_url),
            has_credentials=credentials is not None,
            credentials_expired=credentials_expired,
            client_id=credentials.client_id if credentials else None,
            scopes=self.config.scopes,
            token_expires_at=self._state.token_expires_at.isoformat() if self._state.token_expires_at else None,
            is_authenticated=self._state.is_authenticated,
            diagnostic_suggestions=[
                "Check if client credentials are valid and not expired",
                "Verify token endpoint URL is correct",
                "Check network connectivity to OAuth provider",
                "Review OAuth provider logs for token errors",
                "Verify requested scopes are authorized for the client",
                "Consider re-performing DCR to obtain fresh credentials",
            ],
        )

    def get_diagnostic_info(self) -> Dict[str, Any]:
        """Get diagnostic information about authentication state.

        Returns:
            Dictionary with diagnostic information
        """
        credentials = self._state.client_credentials

        return {
            "is_authenticated": self._state.is_authenticated,
            "is_token_valid": self._state.is_token_valid(),
            "token_valid": self._state.is_token_valid(),  # Backward compatibility
            "has_credentials": credentials is not None,
            "client_id": credentials.client_id if credentials else None,
            "credentials_expired": credentials.is_expired() if credentials else None,
            "dcr_retry_count": self._state.dcr_retry_count,
            "last_dcr_attempt": self._state.last_dcr_attempt.isoformat() if self._state.last_dcr_attempt else None,
            "token_expires_at": self._state.token_expires_at.isoformat() if self._state.token_expires_at else None,
            "recent_dcr_failures": len(self._dcr_failures),
            "recent_token_failures": len(self._token_failures),
            "dcr_endpoint": self._dcr_endpoint,
            "token_endpoint": self._token_endpoint,
            "oauth_provider": str(self.config.oauth_provider_url),
        }

    # ========================================================================
    # Internal Methods
    # ========================================================================

    @staticmethod
    async def _read_json(response: aiohttp.ClientResponse) -> Dict[str, Any]:
        """Parse a response body as a JSON object, tolerating bad payloads.

        The content-type check is disabled so error pages served with another
        content type still get a useful status-based error upstream. Bodies that
        are empty, not valid JSON, or not a JSON object yield ``{}``.
        """
        try:
            data = await response.json(content_type=None)
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            return {}
        return data if isinstance(data, dict) else {}

    async def _wait_for_shutdown(self, timeout: float) -> bool:
        """Wait up to *timeout* seconds for the shutdown event; return True if it was set."""
        try:
            await asyncio.wait_for(self._shutdown_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            return False
        return True

    def _refresh_backoff(self, consecutive_failures: int) -> float:
        """Exponential back-off delay for the refresh loop after N consecutive failures."""
        exponent = min(max(consecutive_failures - 1, 0), 16)  # cap the exponent
        return min(
            self.REFRESH_BACKOFF_BASE_SECONDS * (2 ** exponent),
            self.REFRESH_BACKOFF_MAX_SECONDS,
        )

    async def _auto_refresh_loop(self) -> None:
        """Background task that automatically refreshes tokens before expiration.

        This task runs continuously and:
        1. Checks token expiration status periodically
        2. Refreshes tokens proactively before they expire
        3. Backs off exponentially after failed refresh attempts
        4. Respects shutdown signals and cancellation
        """
        logger.info("Token auto-refresh loop started")
        consecutive_failures = 0

        try:
            while not self._shutdown_event.is_set():
                try:
                    sleep_duration = self._calculate_refresh_interval()
                    if consecutive_failures:
                        # Without this an expired token yields a 0.1 s interval
                        # and failures would be retried in a tight loop.
                        sleep_duration = max(
                            sleep_duration, self._refresh_backoff(consecutive_failures)
                        )

                    if await self._wait_for_shutdown(sleep_duration):
                        break

                    if not self._state.needs_token_refresh():
                        continue

                    logger.info("Token needs refresh, refreshing proactively")
                    refreshed = False
                    try:
                        async with self._lock:
                            # Another caller may have refreshed while we waited.
                            if self._state.needs_token_refresh():
                                await self._obtain_token()
                                refreshed = True
                    except Exception as e:
                        consecutive_failures += 1
                        logger.error(f"Failed to refresh token in background: {e}")
                    else:
                        consecutive_failures = 0
                        if refreshed:
                            logger.info("Token refreshed successfully")

                except asyncio.CancelledError:
                    logger.info("Token auto-refresh loop cancelled")
                    raise
                except Exception as e:
                    logger.error(f"Unexpected error in token auto-refresh loop: {e}")
                    # Sleep briefly before retrying to avoid tight loop on persistent errors
                    await asyncio.sleep(5)
        finally:
            logger.info("Token auto-refresh loop stopped")

    def _calculate_refresh_interval(self) -> float:
        """Calculate how long to wait before next token refresh check.

        Returns:
            Sleep duration in seconds
        """
        if not self._state.token_expires_at:
            # No token yet, check again soon
            return 10.0

        time_until_expiry = (
            self._state.token_expires_at - datetime.now()
        ).total_seconds()

        if time_until_expiry <= 60:
            # Token expires soon or already expired, refresh immediately
            return 0.1
        if time_until_expiry <= 120:
            # Less than 2 minutes, check frequently
            return 10.0
        # Sleep until 60 seconds remain, but check at least every 30 seconds
        return min(time_until_expiry - 60, 30.0)

    @staticmethod
    def _parse_secret_expiry(raw_value: Any) -> Optional[datetime]:
        """Convert RFC 7591 ``client_secret_expires_at`` into a datetime.

        Absent, null, non-numeric or non-positive values (0 means "never
        expires") yield None.
        """
        if raw_value is None or isinstance(raw_value, bool):
            return None
        try:
            timestamp = float(raw_value)
        except (TypeError, ValueError):
            return None
        if timestamp <= 0:
            return None
        return datetime.fromtimestamp(timestamp)

    def _dcr_failed(self, error: DcrError, record_details: Dict[str, Any]) -> DcrError:
        """Apply retry bookkeeping for a failed DCR attempt and return *error* for raising."""
        self._state.dcr_retry_count += 1
        self._state.last_dcr_attempt = datetime.now()
        self._record_dcr_failure(error, record_details)
        return error

    async def _perform_dcr_internal(self) -> ClientCredentials:
        """Internal method to perform DCR without lock (lock must be held).

        Implements RFC 7591 Dynamic Client Registration:
        1. Sends POST request to registration endpoint
        2. Includes client metadata (name, scopes, grant types)
        3. Parses response and stores credentials

        Every failure path increments ``dcr_retry_count``, updates
        ``last_dcr_attempt`` and records a diagnostic failure.

        Returns:
            Client credentials from registration

        Raises:
            DcrError: If registration fails
        """
        if self._session is None:
            raise DcrError("Session not initialized")

        # Build DCR request according to RFC 7591
        dcr_request = {
            "client_name": self.config.client_name,
            "grant_types": ["authorization_code", "refresh_token"],
            "token_endpoint_auth_method": "client_secret_post",
            "scope": " ".join(self.config.scopes),
            "redirect_uris": [self.REDIRECT_URI],
        }

        logger.info(
            "Performing DCR",
            endpoint=self._dcr_endpoint,
            client_name=self.config.client_name,
            scopes=self.config.scopes,
            retry_count=self._state.dcr_retry_count,
        )

        try:
            async with self._session.post(
                self._dcr_endpoint,
                json=dcr_request,
                headers={"Content-Type": "application/json"},
            ) as response:
                response_data = await self._read_json(response)

                if response.status != 201:
                    error_msg = response_data.get("error_description", "DCR failed")
                    error_code = response_data.get("error", "unknown_error")

                    logger.error(
                        "DCR failed",
                        status=response.status,
                        error_code=error_code,
                        error_description=error_msg,
                        response_data=response_data,
                        endpoint=self._dcr_endpoint,
                        retry_count=self._state.dcr_retry_count,
                    )
                    raise self._dcr_failed(
                        DcrError(
                            f"DCR failed with status {response.status}: {error_msg}",
                            details=response_data,
                        ),
                        {
                            "status": response.status,
                            "error_code": error_code,
                            "endpoint": self._dcr_endpoint,
                        },
                    )

                client_id = response_data.get("client_id")
                client_secret = response_data.get("client_secret")
                if not client_id or not client_secret:
                    raise self._dcr_failed(
                        DcrError(
                            "Invalid DCR response: missing client_id or client_secret",
                            details=response_data,
                        ),
                        {
                            "status": response.status,
                            "response_data": response_data,
                        },
                    )

                expires_at = self._parse_secret_expiry(
                    response_data.get("client_secret_expires_at")
                )
                credentials = ClientCredentials(
                    client_id=client_id,
                    client_secret=client_secret,
                    expires_at=expires_at,
                )

        except DcrError:
            raise
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.error(
                "Network error during DCR",
                error=str(e),
                error_type=type(e).__name__,
                endpoint=self._dcr_endpoint,
                retry_count=self._state.dcr_retry_count,
            )
            raise self._dcr_failed(
                DcrError(f"Network error during DCR: {e}"),
                {"error_type": type(e).__name__, "endpoint": self._dcr_endpoint},
            ) from e
        except Exception as e:
            logger.error(
                "Unexpected error during DCR",
                error=str(e),
                error_type=type(e).__name__,
                endpoint=self._dcr_endpoint,
                retry_count=self._state.dcr_retry_count,
                exc_info=True,
            )
            raise self._dcr_failed(
                DcrError(f"Unexpected error during DCR: {e}"),
                {"error_type": type(e).__name__, "endpoint": self._dcr_endpoint},
            ) from e

        previous = self._state.client_credentials
        self._state.client_credentials = credentials
        self._state.last_dcr_attempt = datetime.now()
        self._state.dcr_retry_count = 0
        if previous is not None and previous.client_id != credentials.client_id:
            # Refresh tokens are bound to the client that obtained them
            # (RFC 6749 section 6); one from the old client cannot be reused.
            self._state.refresh_token = None

        logger.info(
            "DCR successful",
            client_id=credentials.client_id,
            expires_at=credentials.expires_at.isoformat() if credentials.expires_at else "never",
            endpoint=self._dcr_endpoint,
        )
        return credentials

    async def _perform_authorization_code_flow(self) -> None:
        """Perform OAuth authorization code flow with browser-based login.

        This method:
        1. Generates a PKCE code verifier/challenge and a random ``state`` value
        2. Opens a browser to the authorization URL
        3. Serves the loopback callback from a worker thread (so the event loop
           keeps running) until a matching callback arrives, the timeout
           elapses, or the coroutine is cancelled
        4. Exchanges the authorization code for tokens

        Raises:
            TokenError: If authorization fails
        """
        import base64
        import hashlib
        import html
        import secrets
        import sys
        import threading
        import time
        import webbrowser
        from http.server import BaseHTTPRequestHandler, HTTPServer
        from urllib.parse import parse_qs, urlencode, urlparse

        if self._session is None:
            raise TokenError("Session not initialized")
        if self._state.client_credentials is None:
            raise TokenError("No client credentials available for authorization")

        credentials = self._state.client_credentials
        redirect_uri = self.REDIRECT_URI

        # Generate PKCE code verifier and challenge (S256)
        code_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode("utf-8").rstrip("=")
        code_challenge = base64.urlsafe_b64encode(
            hashlib.sha256(code_verifier.encode("utf-8")).digest()
        ).decode("utf-8").rstrip("=")
        # Opaque value the provider must echo back; callbacks without it are
        # rejected so other local senders cannot inject an authorization code.
        expected_state = secrets.token_urlsafe(32)

        # Build authorization URL with resource parameter per RFC 8707
        auth_params = {
            "client_id": credentials.client_id,
            "response_type": "code",
            "redirect_uri": redirect_uri,
            "scope": " ".join(self.config.scopes),
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "state": expected_state,
            "resource": str(self.config.mcp_server_url),  # RFC 8707 resource indicator
        }
        auth_url = f"{self._authorize_endpoint}?{urlencode(auth_params)}"

        logger.info(
            "Starting authorization code flow",
            auth_url=auth_url,
            redirect_uri=redirect_uri,
        )

        # Written by the handler (worker thread), read here after it finishes.
        callback_result: Dict[str, Optional[str]] = {"code": None, "error": None}

        class CallbackHandler(BaseHTTPRequestHandler):
            # Per-connection socket timeout so an idle client cannot stall us.
            timeout = 10

            def _reply(self, status: int, body: bytes) -> None:
                self.send_response(status)
                self.send_header("Content-type", "text/html")
                self.end_headers()
                self.wfile.write(body)

            def do_GET(self):
                query = parse_qs(urlparse(self.path).query)
                is_oauth_reply = "code" in query or "error" in query

                if is_oauth_reply and query.get("state", [None])[0] != expected_state:
                    # Not our authorization request: ignore and keep waiting.
                    self._reply(400, b"<html><body><h1>Invalid callback</h1></body></html>")
                elif "code" in query:
                    callback_result["code"] = query["code"][0]
                    self._reply(200, b"<html><body><h1>Authorization successful!</h1><p>You can close this window and return to the application.</p></body></html>")
                elif "error" in query:
                    message = query.get("error_description", ["Unknown error"])[0]
                    callback_result["error"] = message
                    # The description comes from the query string: escape it.
                    self._reply(
                        400,
                        f"<html><body><h1>Authorization failed</h1><p>{html.escape(message)}</p></body></html>".encode(),
                    )
                else:
                    self._reply(400, b"<html><body><h1>Invalid callback</h1></body></html>")

            def log_message(self, format, *args):
                # Suppress HTTP server logs
                pass

        try:
            server = HTTPServer((self.CALLBACK_HOST, self.CALLBACK_PORT), CallbackHandler)
        except OSError as e:
            error = TokenError(
                f"Could not start OAuth callback server on "
                f"{self.CALLBACK_HOST}:{self.CALLBACK_PORT}: {e}"
            )
            self._record_token_failure(error)
            raise error from e
        # handle_request() returns after at most this many seconds without
        # traffic, so the wait loop can honour its deadline and stop flag.
        server.timeout = 1.0

        # NOTE(review): prompts go to stderr because stdout may carry protocol
        # traffic (e.g. an MCP stdio transport) — confirm the proxy's transport.
        banner = "=" * 60
        print(f"\n{banner}", file=sys.stderr)
        print("Opening browser for authentication...", file=sys.stderr)
        print("If the browser doesn't open automatically, visit:", file=sys.stderr)
        print(f" {auth_url}", file=sys.stderr)
        print(f"{banner}\n", file=sys.stderr)

        timeout_seconds = self.AUTHORIZATION_TIMEOUT_SECONDS
        deadline = time.monotonic() + timeout_seconds
        stop_requested = threading.Event()

        def serve_until_callback() -> None:
            """Blocking wait loop run in a worker thread; always releases the port."""
            try:
                while callback_result["code"] is None and callback_result["error"] is None:
                    if stop_requested.is_set() or time.monotonic() > deadline:
                        return
                    server.handle_request()
            finally:
                server.server_close()

        try:
            webbrowser.open(auth_url)
        except Exception as e:  # best effort: the URL was printed above
            logger.warning(f"Could not open browser automatically: {e}")

        loop = asyncio.get_running_loop()
        try:
            await loop.run_in_executor(None, serve_until_callback)
        finally:
            # On cancellation this lets the worker exit and close the socket.
            stop_requested.set()

        if callback_result["error"] is not None:
            error = TokenError(f"Authorization failed: {callback_result['error']}")
            self._record_token_failure(error)
            raise error

        auth_code = callback_result["code"]
        if not auth_code:
            error = TokenError(
                f"Authorization timeout - no response received within "
                f"{timeout_seconds // 60} minutes"
            )
            self._record_token_failure(error)
            raise error

        logger.info("Authorization code received, exchanging for tokens")

        # Exchange authorization code for tokens (client_secret_post, RFC 8707 resource)
        token_request = {
            "grant_type": "authorization_code",
            "code": auth_code,
            "redirect_uri": redirect_uri,
            "code_verifier": code_verifier,
            "client_id": credentials.client_id,
            "client_secret": credentials.client_secret,
            "resource": str(self.config.mcp_server_url),
        }

        try:
            async with self._session.post(
                self._token_endpoint,
                data=token_request,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            ) as response:
                response_data = await self._read_json(response)

                if response.status != 200:
                    error_msg = response_data.get("error_description", "Token exchange failed")
                    raise TokenError(f"Token exchange failed: {error_msg}")

                access_token = response_data.get("access_token")
                if not access_token:
                    raise TokenError("No access token in response")

                expires_in = response_data.get("expires_in", 3600)
                expires_at = datetime.now() + timedelta(seconds=expires_in)

                self._state.access_token = access_token
                self._state.token_expires_at = expires_at
                self._state.refresh_token = response_data.get("refresh_token")
                self._state.is_authenticated = True

                logger.info(
                    "Access token obtained successfully",
                    expires_in=expires_in,
                    has_refresh_token=self._state.refresh_token is not None,
                )
        except TokenError as error:
            # Already descriptive; record without wrapping it a second time.
            self._record_token_failure(error)
            raise
        except Exception as e:
            error = TokenError(f"Token exchange failed: {e}")
            self._record_token_failure(error)
            raise error from e

    async def _obtain_token(self) -> None:
        """Internal method to obtain access token (lock must be held).

        Re-registers first if the client secret has expired, then tries the
        refresh token if one is stored, otherwise falls back to the
        authorization code flow with PKCE for user authentication.

        Raises:
            TokenError: If token acquisition fails
            InvalidCredentialsError: If client credentials are missing or invalid
            DcrError: If re-registration of expired credentials fails
        """
        if self._session is None:
            raise TokenError("Session not initialized")

        credentials = self._state.client_credentials
        if credentials is None:
            raise InvalidCredentialsError("No client credentials available")

        if credentials.is_expired():
            logger.warning(
                "Client credentials expired, performing DCR",
                expires_at=credentials.expires_at.isoformat() if credentials.expires_at else None,
            )
            await self._perform_dcr_internal()
            if self._state.client_credentials is None:
                raise InvalidCredentialsError("Failed to obtain valid credentials")

        if self._state.refresh_token:
            logger.info("Attempting to use refresh token")
            try:
                await self._use_refresh_token()
                return
            except TokenError as e:
                logger.warning(f"Refresh token failed: {e}, falling back to authorization code flow")
                # Clear the invalid refresh token
                self._state.refresh_token = None

        # Fall back to authorization code flow with browser-based login
        await self._perform_authorization_code_flow()

    async def _use_refresh_token(self) -> None:
        """Use refresh token to obtain a new access token.

        Any failure during the exchange (HTTP error status, missing token,
        network error, timeout, malformed payload) is recorded and surfaced as
        ``TokenError`` so ``_obtain_token`` can fall back to the browser flow.

        Raises:
            TokenError: If refresh token exchange fails
        """
        if self._session is None:
            raise TokenError("Session not initialized")
        if not self._state.refresh_token:
            raise TokenError("No refresh token available")
        if not self._state.client_credentials:
            raise TokenError("No client credentials available")

        credentials = self._state.client_credentials

        logger.info("Exchanging refresh token for new access token")

        # client_secret_post authentication (credentials in body, not Basic auth)
        token_request = {
            "grant_type": "refresh_token",
            "refresh_token": self._state.refresh_token,
            "client_id": credentials.client_id,
            "client_secret": credentials.client_secret,
            "resource": str(self.config.mcp_server_url),  # RFC 8707 - bind token to MCP server
        }

        try:
            async with self._session.post(
                self._token_endpoint,
                data=token_request,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            ) as response:
                response_data = await self._read_json(response)

                if response.status != 200:
                    error_msg = response_data.get("error_description", "Refresh token exchange failed")
                    raise TokenError(f"Refresh token exchange failed: {error_msg}")

                access_token = response_data.get("access_token")
                if not access_token:
                    raise TokenError("No access token in refresh response")

                expires_in = response_data.get("expires_in", 3600)
                expires_at = datetime.now() + timedelta(seconds=expires_in)

                self._state.access_token = access_token
                self._state.token_expires_at = expires_at
                self._state.is_authenticated = True

                # Rotate the refresh token only if the provider issued a new one
                new_refresh_token = response_data.get("refresh_token")
                if new_refresh_token:
                    self._state.refresh_token = new_refresh_token

                logger.info(
                    "Access token refreshed successfully",
                    expires_in=expires_in,
                    new_refresh_token_provided=new_refresh_token is not None,
                )
        except TokenError as error:
            self._record_token_failure(error)
            raise
        except Exception as e:
            error = TokenError(f"Refresh token exchange failed: {e}")
            self._record_token_failure(error)
            raise error from e

    def __repr__(self) -> str:
        """String representation of the authentication manager."""
        return (
            f"AuthenticationManagerImpl("
            f"authenticated={self._state.is_authenticated}, "
            f"token_valid={self._state.is_token_valid()})"
        )