fastmcp 2.12.1__py3-none-any.whl → 2.13.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. fastmcp/__init__.py +2 -2
  2. fastmcp/cli/cli.py +56 -36
  3. fastmcp/cli/install/__init__.py +2 -0
  4. fastmcp/cli/install/claude_code.py +7 -16
  5. fastmcp/cli/install/claude_desktop.py +4 -12
  6. fastmcp/cli/install/cursor.py +20 -30
  7. fastmcp/cli/install/gemini_cli.py +241 -0
  8. fastmcp/cli/install/mcp_json.py +4 -12
  9. fastmcp/cli/run.py +15 -94
  10. fastmcp/client/__init__.py +9 -9
  11. fastmcp/client/auth/oauth.py +117 -206
  12. fastmcp/client/client.py +123 -47
  13. fastmcp/client/elicitation.py +6 -1
  14. fastmcp/client/logging.py +18 -14
  15. fastmcp/client/oauth_callback.py +85 -171
  16. fastmcp/client/sampling.py +1 -1
  17. fastmcp/client/transports.py +81 -26
  18. fastmcp/contrib/component_manager/__init__.py +1 -1
  19. fastmcp/contrib/component_manager/component_manager.py +2 -2
  20. fastmcp/contrib/component_manager/component_service.py +7 -7
  21. fastmcp/contrib/mcp_mixin/README.md +35 -4
  22. fastmcp/contrib/mcp_mixin/__init__.py +2 -2
  23. fastmcp/contrib/mcp_mixin/mcp_mixin.py +54 -7
  24. fastmcp/experimental/sampling/handlers/openai.py +2 -2
  25. fastmcp/experimental/server/openapi/__init__.py +5 -8
  26. fastmcp/experimental/server/openapi/components.py +11 -7
  27. fastmcp/experimental/server/openapi/routing.py +2 -2
  28. fastmcp/experimental/utilities/openapi/__init__.py +10 -15
  29. fastmcp/experimental/utilities/openapi/director.py +16 -10
  30. fastmcp/experimental/utilities/openapi/json_schema_converter.py +6 -2
  31. fastmcp/experimental/utilities/openapi/models.py +3 -3
  32. fastmcp/experimental/utilities/openapi/parser.py +37 -16
  33. fastmcp/experimental/utilities/openapi/schemas.py +33 -7
  34. fastmcp/mcp_config.py +3 -4
  35. fastmcp/prompts/__init__.py +1 -1
  36. fastmcp/prompts/prompt.py +32 -27
  37. fastmcp/prompts/prompt_manager.py +16 -101
  38. fastmcp/resources/__init__.py +5 -5
  39. fastmcp/resources/resource.py +28 -20
  40. fastmcp/resources/resource_manager.py +9 -168
  41. fastmcp/resources/template.py +119 -27
  42. fastmcp/resources/types.py +30 -24
  43. fastmcp/server/__init__.py +1 -1
  44. fastmcp/server/auth/__init__.py +9 -5
  45. fastmcp/server/auth/auth.py +80 -47
  46. fastmcp/server/auth/handlers/authorize.py +326 -0
  47. fastmcp/server/auth/jwt_issuer.py +236 -0
  48. fastmcp/server/auth/middleware.py +96 -0
  49. fastmcp/server/auth/oauth_proxy.py +1556 -265
  50. fastmcp/server/auth/oidc_proxy.py +412 -0
  51. fastmcp/server/auth/providers/auth0.py +193 -0
  52. fastmcp/server/auth/providers/aws.py +263 -0
  53. fastmcp/server/auth/providers/azure.py +314 -129
  54. fastmcp/server/auth/providers/bearer.py +1 -1
  55. fastmcp/server/auth/providers/debug.py +114 -0
  56. fastmcp/server/auth/providers/descope.py +229 -0
  57. fastmcp/server/auth/providers/discord.py +308 -0
  58. fastmcp/server/auth/providers/github.py +31 -6
  59. fastmcp/server/auth/providers/google.py +50 -7
  60. fastmcp/server/auth/providers/in_memory.py +27 -3
  61. fastmcp/server/auth/providers/introspection.py +281 -0
  62. fastmcp/server/auth/providers/jwt.py +48 -31
  63. fastmcp/server/auth/providers/oci.py +233 -0
  64. fastmcp/server/auth/providers/scalekit.py +238 -0
  65. fastmcp/server/auth/providers/supabase.py +188 -0
  66. fastmcp/server/auth/providers/workos.py +37 -15
  67. fastmcp/server/context.py +194 -67
  68. fastmcp/server/dependencies.py +56 -16
  69. fastmcp/server/elicitation.py +1 -1
  70. fastmcp/server/http.py +57 -18
  71. fastmcp/server/low_level.py +121 -2
  72. fastmcp/server/middleware/__init__.py +1 -1
  73. fastmcp/server/middleware/caching.py +476 -0
  74. fastmcp/server/middleware/error_handling.py +14 -10
  75. fastmcp/server/middleware/logging.py +158 -116
  76. fastmcp/server/middleware/middleware.py +30 -16
  77. fastmcp/server/middleware/rate_limiting.py +3 -3
  78. fastmcp/server/middleware/tool_injection.py +116 -0
  79. fastmcp/server/openapi.py +15 -7
  80. fastmcp/server/proxy.py +22 -11
  81. fastmcp/server/server.py +744 -254
  82. fastmcp/settings.py +65 -15
  83. fastmcp/tools/__init__.py +1 -1
  84. fastmcp/tools/tool.py +173 -108
  85. fastmcp/tools/tool_manager.py +30 -112
  86. fastmcp/tools/tool_transform.py +13 -11
  87. fastmcp/utilities/cli.py +67 -28
  88. fastmcp/utilities/components.py +7 -2
  89. fastmcp/utilities/inspect.py +79 -23
  90. fastmcp/utilities/json_schema.py +21 -4
  91. fastmcp/utilities/json_schema_type.py +4 -4
  92. fastmcp/utilities/logging.py +182 -10
  93. fastmcp/utilities/mcp_server_config/__init__.py +3 -3
  94. fastmcp/utilities/mcp_server_config/v1/environments/base.py +1 -2
  95. fastmcp/utilities/mcp_server_config/v1/environments/uv.py +10 -45
  96. fastmcp/utilities/mcp_server_config/v1/mcp_server_config.py +8 -7
  97. fastmcp/utilities/mcp_server_config/v1/schema.json +5 -1
  98. fastmcp/utilities/mcp_server_config/v1/sources/base.py +0 -1
  99. fastmcp/utilities/openapi.py +11 -11
  100. fastmcp/utilities/tests.py +93 -10
  101. fastmcp/utilities/types.py +87 -21
  102. fastmcp/utilities/ui.py +626 -0
  103. {fastmcp-2.12.1.dist-info → fastmcp-2.13.2.dist-info}/METADATA +141 -60
  104. fastmcp-2.13.2.dist-info/RECORD +144 -0
  105. {fastmcp-2.12.1.dist-info → fastmcp-2.13.2.dist-info}/WHEEL +1 -1
  106. fastmcp/cli/claude.py +0 -144
  107. fastmcp-2.12.1.dist-info/RECORD +0 -128
  108. {fastmcp-2.12.1.dist-info → fastmcp-2.13.2.dist-info}/entry_points.txt +0 -0
  109. {fastmcp-2.12.1.dist-info → fastmcp-2.13.2.dist-info}/licenses/LICENSE +0 -0
@@ -18,36 +18,70 @@ production use with enterprise identity providers.
18
18
 
19
19
  from __future__ import annotations
20
20
 
21
+ import base64
21
22
  import hashlib
23
+ import hmac
24
+ import json
22
25
  import secrets
23
26
  import time
24
27
  from base64 import urlsafe_b64encode
25
28
  from typing import TYPE_CHECKING, Any, Final
26
- from urllib.parse import urlencode
29
+ from urllib.parse import urlencode, urlparse
27
30
 
28
31
  import httpx
29
32
  from authlib.common.security import generate_token
30
33
  from authlib.integrations.httpx_client import AsyncOAuth2Client
34
+ from cryptography.fernet import Fernet
35
+ from key_value.aio.adapters.pydantic import PydanticAdapter
36
+ from key_value.aio.protocols import AsyncKeyValue
37
+ from key_value.aio.stores.disk import DiskStore
38
+ from key_value.aio.wrappers.encryption import FernetEncryptionWrapper
39
+ from mcp.server.auth.handlers.token import TokenErrorResponse, TokenSuccessResponse
40
+ from mcp.server.auth.handlers.token import TokenHandler as _SDKTokenHandler
41
+ from mcp.server.auth.json_response import PydanticJSONResponse
42
+ from mcp.server.auth.middleware.client_auth import ClientAuthenticator
31
43
  from mcp.server.auth.provider import (
32
44
  AccessToken,
33
45
  AuthorizationCode,
34
46
  AuthorizationParams,
47
+ AuthorizeError,
35
48
  RefreshToken,
36
49
  TokenError,
37
50
  )
51
+ from mcp.server.auth.routes import cors_middleware
38
52
  from mcp.server.auth.settings import (
39
53
  ClientRegistrationOptions,
40
54
  RevocationOptions,
41
55
  )
42
56
  from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
43
- from pydantic import AnyHttpUrl, AnyUrl, SecretStr
57
+ from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, SecretStr
44
58
  from starlette.requests import Request
45
- from starlette.responses import RedirectResponse
59
+ from starlette.responses import HTMLResponse, RedirectResponse
46
60
  from starlette.routing import Route
61
+ from typing_extensions import override
47
62
 
63
+ from fastmcp import settings
48
64
  from fastmcp.server.auth.auth import OAuthProvider, TokenVerifier
49
- from fastmcp.server.auth.redirect_validation import validate_redirect_uri
65
+ from fastmcp.server.auth.handlers.authorize import AuthorizationHandler
66
+ from fastmcp.server.auth.jwt_issuer import (
67
+ JWTIssuer,
68
+ derive_jwt_key,
69
+ )
70
+ from fastmcp.server.auth.redirect_validation import (
71
+ validate_redirect_uri,
72
+ )
50
73
  from fastmcp.utilities.logging import get_logger
74
+ from fastmcp.utilities.ui import (
75
+ BUTTON_STYLES,
76
+ DETAIL_BOX_STYLES,
77
+ DETAILS_STYLES,
78
+ INFO_BOX_STYLES,
79
+ REDIRECT_SECTION_STYLES,
80
+ TOOLTIP_STYLES,
81
+ create_logo,
82
+ create_page,
83
+ create_secure_html_response,
84
+ )
51
85
 
52
86
  if TYPE_CHECKING:
53
87
  pass
@@ -55,6 +89,118 @@ if TYPE_CHECKING:
55
89
  logger = get_logger(__name__)
56
90
 
57
91
 
92
# -------------------------------------------------------------------------
# Constants
# -------------------------------------------------------------------------

# Default token lifetimes, expressed in seconds.
DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS: Final[int] = 3_600  # one hour
DEFAULT_AUTH_CODE_EXPIRY_SECONDS: Final[int] = 300  # five minutes

# Timeout, in seconds, used for the HTTP client.
HTTP_TIMEOUT_SECONDS: Final[int] = 30
102
+
103
+
104
+ # -------------------------------------------------------------------------
105
+ # Pydantic Models
106
+ # -------------------------------------------------------------------------
107
+
108
+
109
class OAuthTransaction(BaseModel):
    """OAuth transaction state for consent flow.

    Stored server-side to track active authorization flows with client context.
    Includes CSRF tokens for consent protection per MCP security best practices.

    The first eight fields (``txn_id`` through ``created_at``) are required;
    ``resource``, ``proxy_code_verifier``, ``csrf_token`` and
    ``csrf_expires_at`` are optional and default to None.
    """

    txn_id: str  # Transaction identifier (presumably what the consent form posts back as "txn_id")
    client_id: str  # MCP client that started this authorization flow
    client_redirect_uri: str  # Redirect URI supplied by the MCP client
    client_state: str  # Client's `state` value (presumably echoed back on redirect — confirm)
    code_challenge: str | None  # Client's PKCE challenge; None when the client sent none
    code_challenge_method: str  # Method that goes with code_challenge
    scopes: list[str]  # Scopes requested in this flow
    created_at: float  # Creation time (presumably a Unix timestamp — confirm)
    resource: str | None = None  # Optional resource value from the request (NOTE(review): likely RFC 8707 resource indicator)
    proxy_code_verifier: str | None = None  # PKCE verifier for the proxy-to-upstream leg, if PKCE is forwarded
    csrf_token: str | None = None  # CSRF token protecting the consent form submission
    csrf_expires_at: float | None = None  # When csrf_token stops being valid (presumably a Unix timestamp)
128
+
129
+
130
class ClientCode(BaseModel):
    """Client authorization code with PKCE and upstream tokens.

    Stored server-side after upstream IdP callback. Contains the upstream
    tokens bound to the client's PKCE challenge for secure token exchange.

    Every field is required (none has a default).
    """

    code: str  # Authorization code value handed to the MCP client
    client_id: str  # MCP client the code was issued to
    redirect_uri: str  # Client redirect URI associated with this code
    code_challenge: str | None  # Client's PKCE challenge; None if the client sent none
    code_challenge_method: str  # Method that goes with code_challenge
    scopes: list[str]  # Scopes associated with this code
    idp_tokens: dict[str, Any]  # Token data obtained from the upstream IdP (schema not shown here — presumably the raw token response)
    expires_at: float  # Expiry time (presumably a Unix timestamp; cf. DEFAULT_AUTH_CODE_EXPIRY_SECONDS)
    created_at: float  # Creation time (presumably a Unix timestamp)
146
+
147
+
148
class UpstreamTokenSet(BaseModel):
    """Stored upstream OAuth tokens from identity provider.

    These tokens are obtained from the upstream provider (Google, GitHub, etc.)
    and stored in plaintext within this model. Encryption is handled transparently
    at the storage layer via FernetEncryptionWrapper. Tokens are never exposed to MCP clients.

    All fields except ``raw_token_data`` are required. ``raw_token_data``
    defaults to a fresh empty dict per instance (``default_factory``), so
    instances never share the same dict object.
    """

    upstream_token_id: str  # Unique ID for this token set
    access_token: str  # Upstream access token
    refresh_token: str | None  # Upstream refresh token
    refresh_token_expires_at: (
        float | None
    )  # Unix timestamp when refresh token expires (if known)
    expires_at: float  # Unix timestamp when access token expires
    token_type: str  # Usually "Bearer"
    scope: str  # Space-separated scopes
    client_id: str  # MCP client this is bound to
    created_at: float  # Unix timestamp
    raw_token_data: dict[str, Any] = Field(default_factory=dict)  # Full token response
168
+
169
+
170
class JTIMapping(BaseModel):
    """Maps FastMCP token JTI to upstream token ID.

    This allows stateless JWT validation while still being able to look up
    the corresponding upstream token when tools need to access upstream APIs.

    All three fields are required.
    """

    jti: str  # JWT ID from FastMCP-issued token
    upstream_token_id: str  # References UpstreamTokenSet.upstream_token_id
    created_at: float  # Unix timestamp
180
+
181
+
182
class RefreshTokenMetadata(BaseModel):
    """Metadata for a refresh token, stored keyed by token hash.

    We store only metadata (not the token itself) for security - if storage
    is compromised, attackers get hashes they can't reverse into usable tokens.

    ``client_id``, ``scopes`` and ``created_at`` are required; ``expires_at``
    is optional and defaults to None.
    """

    client_id: str  # MCP client the refresh token was issued to
    scopes: list[str]  # Scopes associated with the refresh token
    # NOTE(review): typed int, unlike the float timestamps on the other models;
    # None presumably means "no known expiry" — confirm against callers.
    expires_at: int | None = None
    created_at: float  # Creation time (presumably a Unix timestamp)
193
+
194
+
195
+ def _hash_token(token: str) -> str:
196
+ """Hash a token for secure storage lookup.
197
+
198
+ Uses SHA-256 to create a one-way hash. The original token cannot be
199
+ recovered from the hash, providing defense in depth if storage is compromised.
200
+ """
201
+ return hashlib.sha256(token.encode()).hexdigest()
202
+
203
+
58
204
  class ProxyDCRClient(OAuthClientInformationFull):
59
205
  """Client for DCR proxy with configurable redirect URI validation.
60
206
 
@@ -81,18 +227,8 @@ class ProxyDCRClient(OAuthClientInformationFull):
81
227
  arise from accepting arbitrary redirect URIs.
82
228
  """
83
229
 
84
- def __init__(
85
- self, *args, allowed_redirect_uri_patterns: list[str] | None = None, **kwargs
86
- ):
87
- """Initialize with allowed redirect URI patterns.
88
-
89
- Args:
90
- allowed_redirect_uri_patterns: List of allowed redirect URI patterns with wildcard support.
91
- If None, defaults to localhost-only patterns.
92
- If empty list, allows all redirect URIs.
93
- """
94
- super().__init__(*args, **kwargs)
95
- self._allowed_redirect_uri_patterns = allowed_redirect_uri_patterns
230
+ allowed_redirect_uri_patterns: list[str] | None = Field(default=None)
231
+ client_name: str | None = Field(default=None)
96
232
 
97
233
  def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
98
234
  """Validate redirect URI against allowed patterns.
@@ -104,7 +240,10 @@ class ProxyDCRClient(OAuthClientInformationFull):
104
240
  """
105
241
  if redirect_uri is not None:
106
242
  # Validate against allowed patterns
107
- if validate_redirect_uri(redirect_uri, self._allowed_redirect_uri_patterns):
243
+ if validate_redirect_uri(
244
+ redirect_uri=redirect_uri,
245
+ allowed_patterns=self.allowed_redirect_uri_patterns,
246
+ ):
108
247
  return redirect_uri
109
248
  # Fall back to normal validation if not in allowed patterns
110
249
  return super().validate_redirect_uri(redirect_uri)
@@ -112,12 +251,321 @@ class ProxyDCRClient(OAuthClientInformationFull):
112
251
  return super().validate_redirect_uri(redirect_uri)
113
252
 
114
253
 
115
- # Default token expiration times
116
- DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS: Final[int] = 60 * 60 # 1 hour
117
- DEFAULT_AUTH_CODE_EXPIRY_SECONDS: Final[int] = 5 * 60 # 5 minutes
254
+ # -------------------------------------------------------------------------
255
+ # Helper Functions
256
+ # -------------------------------------------------------------------------
257
+
258
+
259
def create_consent_html(
    client_id: str,
    redirect_uri: str,
    scopes: list[str],
    txn_id: str,
    csrf_token: str,
    client_name: str | None = None,
    title: str = "Application Access Request",
    server_name: str | None = None,
    server_icon_url: str | None = None,
    server_website_url: str | None = None,
    client_website_url: str | None = None,
    csp_policy: str | None = None,
) -> str:
    """Create a styled HTML consent page for OAuth authorization requests.

    Every value interpolated into the markup is HTML-escaped. That includes
    identifiers that are normally server-generated (``client_id``,
    ``txn_id``, ``csrf_token``), so the page stays safe even if one of them
    is ever derived from request input.

    Args:
        client_id: ID of the requesting MCP client. Displayed as the
            application name when ``client_name`` is not given.
        redirect_uri: Client callback address. It is shown to the user and
            its scheme is added to the default CSP ``form-action`` directive.
        scopes: Requested scopes, listed under "Advanced Details".
        txn_id: Transaction ID submitted back with the consent form.
        csrf_token: CSRF token submitted back with the consent form.
        client_name: Optional human-readable client name.
        title: Page title; also used as the page's main heading.
        server_name: Display name of this server (defaults to "FastMCP").
        server_icon_url: Optional logo URL passed to ``create_logo``.
        server_website_url: Optional URL. When set, the server name is
            rendered as a link to it.
        client_website_url: Optional client website, shown in the details
            ("N/A" when missing).
        csp_policy: Content Security Policy override.
            If None, uses the built-in CSP policy with appropriate directives.
            If empty string "", disables CSP entirely (no meta tag is rendered).
            If a non-empty string, uses that as the CSP policy value.

    Returns:
        The complete HTML page produced by ``create_page``.
    """
    import html as html_module

    client_display = html_module.escape(client_name or client_id)
    server_name_escaped = html_module.escape(server_name or "FastMCP")

    # Make server name a hyperlink if website URL is available
    if server_website_url:
        website_url_escaped = html_module.escape(server_website_url)
        server_display = f'<a href="{website_url_escaped}" target="_blank" rel="noopener noreferrer" class="server-name-link">{server_name_escaped}</a>'
    else:
        server_display = server_name_escaped

    # Build intro box with call-to-action
    intro_box = f"""
        <div class="info-box">
            <p>The application <strong>{client_display}</strong> wants to access the MCP server <strong>{server_display}</strong>. Please ensure you recognize the callback address below.</p>
        </div>
    """

    # Build redirect URI section (yellow box, centered)
    redirect_uri_escaped = html_module.escape(redirect_uri)
    redirect_section = f"""
        <div class="redirect-section">
            <span class="label">Credentials will be sent to:</span>
            <div class="value">{redirect_uri_escaped}</div>
        </div>
    """

    # Build advanced details with collapsible section.
    # Every value is escaped; previously "Application ID" was inserted raw
    # even though the same client_id was escaped for "Application Name".
    detail_rows = [
        ("Application Name", client_display),
        ("Application Website", html_module.escape(client_website_url or "N/A")),
        ("Application ID", html_module.escape(client_id)),
        ("Redirect URI", redirect_uri_escaped),
        (
            "Requested Scopes",
            ", ".join(html_module.escape(s) for s in scopes) if scopes else "None",
        ),
    ]

    detail_rows_html = "\n".join(
        [
            f"""
            <div class="detail-row">
                <div class="detail-label">{label}:</div>
                <div class="detail-value">{value}</div>
            </div>
            """
            for label, value in detail_rows
        ]
    )

    advanced_details = f"""
        <details>
            <summary>Advanced Details</summary>
            <div class="detail-box">
                {detail_rows_html}
            </div>
        </details>
    """

    # Build form with buttons.
    # Empty action submits to the current URL (/consent or /mcp/consent);
    # the POST handler is registered at the same path as GET.
    # Hidden values are escaped so they cannot break out of the attribute.
    txn_id_escaped = html_module.escape(txn_id)
    csrf_token_escaped = html_module.escape(csrf_token)
    form = f"""
        <form id="consentForm" method="POST" action="">
            <input type="hidden" name="txn_id" value="{txn_id_escaped}" />
            <input type="hidden" name="csrf_token" value="{csrf_token_escaped}" />
            <input type="hidden" name="submit" value="true" />
            <div class="button-group">
                <button type="submit" name="action" value="approve" class="btn-approve">Allow Access</button>
                <button type="submit" name="action" value="deny" class="btn-deny">Deny</button>
            </div>
        </form>
    """

    # Build help link with tooltip
    help_link = """
        <div class="help-link-container">
            <span class="help-link">
                Why am I seeing this?
                <span class="tooltip">
                    This FastMCP server requires your consent to allow a new client
                    to connect. This protects you from <a
                    href="https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#confused-deputy-problem"
                    target="_blank" class="tooltip-link">confused deputy
                    attacks</a>, where malicious clients could impersonate you
                    and steal access.<br><br>
                    <a
                    href="https://gofastmcp.com/servers/auth/oauth-proxy#confused-deputy-attacks"
                    target="_blank" class="tooltip-link">Learn more about
                    FastMCP security →</a>
                </span>
            </span>
        </div>
    """

    # Build the page content. The heading uses the (escaped) title parameter
    # so it matches the page title; the default title yields the same text
    # as before.
    content = f"""
        <div class="container">
            {create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")}
            <h1>{html_module.escape(title)}</h1>
            {intro_box}
            {redirect_section}
            {advanced_details}
            {form}
        </div>
        {help_link}
    """

    # Additional styles needed for this page
    additional_styles = (
        INFO_BOX_STYLES
        + REDIRECT_SECTION_STYLES
        + DETAILS_STYLES
        + DETAIL_BOX_STYLES
        + BUTTON_STYLES
        + TOOLTIP_STYLES
    )

    # Determine CSP policy to use:
    # - None: build the default CSP policy below
    # - "": CSP is disabled entirely by create_page
    # - any other string: used as-is
    if csp_policy is None:
        # Need to allow form-action for form submission.
        # Chrome requires explicit scheme declarations in CSP form-action when
        # redirect chains end in custom protocol schemes (e.g., cursor://), so
        # the redirect_uri's scheme is included. urlparse only recognises
        # schemes made of letters, digits, "+", "-" and ".", so the value
        # cannot inject extra CSP directives.
        parsed_redirect = urlparse(redirect_uri)
        redirect_scheme = parsed_redirect.scheme.lower()

        form_action_schemes = ["https:", "http:"]
        if redirect_scheme and redirect_scheme not in ("http", "https"):
            # Custom protocol scheme (e.g., cursor:, vscode:, etc.)
            form_action_schemes.append(f"{redirect_scheme}:")

        form_action_directive = " ".join(form_action_schemes)
        csp_policy = f"default-src 'none'; style-src 'unsafe-inline'; img-src https: data:; base-uri 'none'; form-action {form_action_directive}"

    return create_page(
        content=content,
        title=title,
        additional_styles=additional_styles,
        csp_policy=csp_policy,
    )
427
+
428
+
429
def create_error_html(
    error_title: str,
    error_message: str,
    error_details: dict[str, str] | None = None,
    server_name: str | None = None,
    server_icon_url: str | None = None,
) -> str:
    """Build a styled HTML page that reports an OAuth error to the user.

    Args:
        error_title: Used both as the page title and as the ``<h1>`` heading
            (e.g. "OAuth Error", "Authorization Failed").
        error_message: Main explanation shown inside the message box.
        error_details: Optional label/value pairs (e.g. {"Error Code": "invalid_client"}).
            They are rendered in a collapsible "Error Details" section, which is
            omitted when this is None or empty.
        server_name: Optional server name, used as the logo's alt text
            (falls back to "FastMCP").
        server_icon_url: Optional URL of the server icon/logo.

    Returns:
        The complete HTML document as a string.

    The title, message, and every detail label and value are HTML-escaped
    before being placed in the markup.
    """
    from html import escape

    message_block = f"""
        <div class="info-box error">
            <p>{escape(error_message)}</p>
        </div>
    """

    # The collapsible details section only exists when details were supplied.
    if not error_details:
        details_section = ""
    else:
        row_fragments = [
            f"""
            <div class="detail-row">
                <div class="detail-label">{escape(label)}:</div>
                <div class="detail-value">{escape(value)}</div>
            </div>
            """
            for label, value in error_details.items()
        ]
        rows_markup = "\n".join(row_fragments)
        details_section = f"""
        <details>
            <summary>Error Details</summary>
            <div class="detail-box">
                {rows_markup}
            </div>
        </details>
        """

    logo_markup = create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")
    heading = escape(error_title)
    content = f"""
        <div class="container">
            {logo_markup}
            <h1>{heading}</h1>
            {message_block}
            {details_section}
        </div>
    """

    # Shared styles, plus an override so error text keeps the normal
    # (non-red) text color.
    error_text_override = """
        .info-box.error {
            color: #111827;
        }
    """
    page_styles = INFO_BOX_STYLES + DETAILS_STYLES + DETAIL_BOX_STYLES + error_text_override

    # Error pages contain no forms, so no form-action directive is needed.
    return create_page(
        content=content,
        title=error_title,
        additional_styles=page_styles,
        csp_policy="default-src 'none'; style-src 'unsafe-inline'; img-src https: data:; base-uri 'none'",
    )
515
+
516
+
517
+ # -------------------------------------------------------------------------
518
+ # Handler Classes
519
+ # -------------------------------------------------------------------------
520
+
521
+
522
class TokenHandler(_SDKTokenHandler):
    """Token endpoint handler that reports client-authentication failures per OAuth 2.1.

    The MCP SDK's handler answers every client-authentication problem with an
    HTTP 400 ``unauthorized_client`` error. OAuth 2.1 Section 5.3 and the MCP
    specification call for HTTP 401 when the client cannot be authenticated.
    This subclass therefore rewrites exactly those errors:

    - the error code ``unauthorized_client`` becomes ``invalid_client``
    - the status code 400 becomes 401

    Per OAuth 2.1 Section 5.3: "The authorization server MAY return an HTTP 401
    (Unauthorized) status code to indicate which HTTP authentication schemes
    are supported."

    Per MCP spec: "Invalid or expired tokens MUST receive a HTTP 401 response."
    """

    def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
        """Build the token endpoint response, mapping client-auth failures to 401 ``invalid_client``.

        Any other response (successes and remaining errors) is delegated
        unchanged to the SDK implementation.
        """
        # "unauthorized_client" covers two situations:
        #   * the client could not be authenticated (unknown client_id or bad
        #     credentials) -> should be 401 invalid_client
        #   * the client may not use this grant type -> 400 unauthorized_client,
        #     which is already correct
        # Only the first carries "Invalid client_id" in its description.
        if not isinstance(obj, TokenErrorResponse) or obj.error != "unauthorized_client":
            return super().response(obj)
        if not (obj.error_description and "Invalid client_id" in obj.error_description):
            return super().response(obj)

        compliant_error = TokenErrorResponse(
            error="invalid_client",
            error_description=obj.error_description,
            error_uri=obj.error_uri,
        )
        no_cache_headers = {
            "Cache-Control": "no-store",
            "Pragma": "no-cache",
        }
        return PydanticJSONResponse(
            content=compliant_error,
            status_code=401,
            headers=no_cache_headers,
        )
121
569
 
122
570
 
123
571
  class OAuthProxy(OAuthProvider):
@@ -198,15 +646,18 @@ class OAuthProxy(OAuthProvider):
198
646
 
199
647
  State Management
200
648
  ---------------
201
- The proxy maintains minimal but crucial state:
202
- - _clients: DCR registrations (all use ProxyDCRClient for flexibility)
649
+ The proxy maintains minimal but crucial state via pluggable storage (client_storage):
203
650
  - _oauth_transactions: Active authorization flows with client context
204
651
  - _client_codes: Authorization codes with PKCE challenges and upstream tokens
205
- - _access_tokens, _refresh_tokens: Token storage for revocation
206
- - Token relationship mappings for cleanup and rotation
652
+ - _jti_mapping_store: Maps FastMCP token JTIs to upstream token IDs
653
+ - _refresh_token_store: Refresh token metadata (keyed by token hash)
654
+
655
+ All state is stored in the configured client_storage backend (Redis, disk, etc.)
656
+ enabling horizontal scaling across multiple instances.
207
657
 
208
658
  Security Considerations
209
659
  ----------------------
660
+ - Refresh tokens stored by hash only (defense in depth if storage compromised)
210
661
  - PKCE enforced end-to-end (client to proxy, proxy to upstream)
211
662
  - Authorization codes are single-use with short expiry
212
663
  - Transaction IDs are cryptographically random
@@ -240,7 +691,7 @@ class OAuthProxy(OAuthProvider):
240
691
  token_verifier: TokenVerifier,
241
692
  # FastMCP server configuration
242
693
  base_url: AnyHttpUrl | str,
243
- redirect_path: str = "/auth/callback",
694
+ redirect_path: str | None = None,
244
695
  issuer_url: AnyHttpUrl | str | None = None,
245
696
  service_documentation_url: AnyHttpUrl | str | None = None,
246
697
  # Client redirect URI validation
@@ -250,6 +701,17 @@ class OAuthProxy(OAuthProvider):
250
701
  forward_pkce: bool = True,
251
702
  # Token endpoint authentication
252
703
  token_endpoint_auth_method: str | None = None,
704
+ # Extra parameters to forward to authorization endpoint
705
+ extra_authorize_params: dict[str, str] | None = None,
706
+ # Extra parameters to forward to token endpoint
707
+ extra_token_params: dict[str, str] | None = None,
708
+ # Client storage
709
+ client_storage: AsyncKeyValue | None = None,
710
+ # JWT signing key
711
+ jwt_signing_key: str | bytes | None = None,
712
+ # Consent screen configuration
713
+ require_authorization_consent: bool = True,
714
+ consent_csp_policy: str | None = None,
253
715
  ):
254
716
  """Initialize the OAuth proxy provider.
255
717
 
@@ -273,12 +735,34 @@ class OAuthProxy(OAuthProvider):
273
735
  valid_scopes: List of all the possible valid scopes for a client.
274
736
  These are advertised to clients through the `/.well-known` endpoints. Defaults to `required_scopes` if not provided.
275
737
  forward_pkce: Whether to forward PKCE to upstream server (default True).
276
- Enable for providers that support/require PKCE (Google, Azure, etc.).
738
+ Enable for providers that support/require PKCE (Google, Azure, AWS, etc.).
277
739
  Disable only if upstream provider doesn't support PKCE.
278
740
  token_endpoint_auth_method: Token endpoint authentication method for upstream server.
279
741
  Common values: "client_secret_basic", "client_secret_post", "none".
280
742
  If None, authlib will use its default (typically "client_secret_basic").
743
+ extra_authorize_params: Additional parameters to forward to the upstream authorization endpoint.
744
+ Useful for provider-specific parameters like Auth0's "audience".
745
+ Example: {"audience": "https://api.example.com"}
746
+ extra_token_params: Additional parameters to forward to the upstream token endpoint.
747
+ Useful for provider-specific parameters during token exchange.
748
+ client_storage: Storage backend for OAuth state (client registrations, tokens).
749
+ If None, an encrypted DiskStore will be created in the data directory.
750
+ jwt_signing_key: Secret for signing FastMCP JWT tokens (any string or bytes).
751
+ If bytes are provided, they will be used as-is.
752
+ If a string is provided, it will be derived into a 32-byte key using PBKDF2 (1.2M iterations).
753
+ If not provided, it will be derived from the upstream client secret using HKDF.
754
+ require_authorization_consent: Whether to require user consent before authorizing clients (default True).
755
+ When True, users see a consent screen before being redirected to the upstream IdP.
756
+ When False, authorization proceeds directly without user confirmation.
757
+ SECURITY WARNING: Only disable for local development or testing environments.
758
+ consent_csp_policy: Content Security Policy for the consent page.
759
+ If None (default), uses the built-in CSP policy with appropriate directives.
760
+ If empty string "", disables CSP entirely (no meta tag is rendered).
761
+ If a non-empty string, uses that as the CSP policy value.
762
+ This allows organizations with their own CSP policies to override or disable
763
+ the built-in CSP directives.
281
764
  """
765
+
282
766
  # Always enable DCR since we implement it locally for MCP clients
283
767
  client_registration_options = ClientRegistrationOptions(
284
768
  enabled=True,
@@ -300,42 +784,157 @@ class OAuthProxy(OAuthProvider):
300
784
  )
301
785
 
302
786
  # Store upstream configuration
303
- self._upstream_authorization_endpoint = upstream_authorization_endpoint
304
- self._upstream_token_endpoint = upstream_token_endpoint
305
- self._upstream_client_id = upstream_client_id
306
- self._upstream_client_secret = SecretStr(upstream_client_secret)
307
- self._upstream_revocation_endpoint = upstream_revocation_endpoint
308
- self._default_scope_str = " ".join(self.required_scopes or [])
787
+ self._upstream_authorization_endpoint: str = upstream_authorization_endpoint
788
+ self._upstream_token_endpoint: str = upstream_token_endpoint
789
+ self._upstream_client_id: str = upstream_client_id
790
+ self._upstream_client_secret: SecretStr = SecretStr(
791
+ secret_value=upstream_client_secret
792
+ )
793
+ self._upstream_revocation_endpoint: str | None = upstream_revocation_endpoint
794
+ self._default_scope_str: str = " ".join(self.required_scopes or [])
309
795
 
310
796
  # Store redirect configuration
311
- self._redirect_path = (
312
- redirect_path if redirect_path.startswith("/") else f"/{redirect_path}"
797
+ if not redirect_path:
798
+ self._redirect_path = "/auth/callback"
799
+ else:
800
+ self._redirect_path = (
801
+ redirect_path if redirect_path.startswith("/") else f"/{redirect_path}"
802
+ )
803
+
804
+ if (
805
+ isinstance(allowed_client_redirect_uris, list)
806
+ and not allowed_client_redirect_uris
807
+ ):
808
+ logger.warning(
809
+ "allowed_client_redirect_uris is empty list; no redirect URIs will be accepted. "
810
+ + "This will block all OAuth clients."
811
+ )
812
+ self._allowed_client_redirect_uris: list[str] | None = (
813
+ allowed_client_redirect_uris
313
814
  )
314
- self._allowed_client_redirect_uris = allowed_client_redirect_uris
315
815
 
316
816
  # PKCE configuration
317
- self._forward_pkce = forward_pkce
817
+ self._forward_pkce: bool = forward_pkce
318
818
 
319
819
  # Token endpoint authentication
320
- self._token_endpoint_auth_method = token_endpoint_auth_method
820
+ self._token_endpoint_auth_method: str | None = token_endpoint_auth_method
821
+
822
+ # Consent screen configuration
823
+ self._require_authorization_consent: bool = require_authorization_consent
824
+ self._consent_csp_policy: str | None = consent_csp_policy
825
+ if not require_authorization_consent:
826
+ logger.warning(
827
+ "Authorization consent screen disabled - only use for local development or testing. "
828
+ + "In production, this screen protects against confused deputy attacks."
829
+ )
321
830
 
322
- # Local state for DCR and token bookkeeping
323
- self._clients: dict[str, OAuthClientInformationFull] = {}
324
- self._access_tokens: dict[str, AccessToken] = {}
325
- self._refresh_tokens: dict[str, RefreshToken] = {}
831
+ # Extra parameters for authorization and token endpoints
832
+ self._extra_authorize_params: dict[str, str] = extra_authorize_params or {}
833
+ self._extra_token_params: dict[str, str] = extra_token_params or {}
326
834
 
327
- # Token relation mappings for cleanup
328
- self._access_to_refresh: dict[str, str] = {}
329
- self._refresh_to_access: dict[str, str] = {}
835
+ if jwt_signing_key is None:
836
+ jwt_signing_key = derive_jwt_key(
837
+ high_entropy_material=upstream_client_secret,
838
+ salt="fastmcp-jwt-signing-key",
839
+ )
840
+
841
+ if isinstance(jwt_signing_key, str):
842
+ if len(jwt_signing_key) < 12:
843
+ logger.warning(
844
+ "jwt_signing_key is less than 12 characters; it is recommended to use a longer. "
845
+ + "string for the key derivation."
846
+ )
847
+ jwt_signing_key = derive_jwt_key(
848
+ low_entropy_material=jwt_signing_key,
849
+ salt="fastmcp-jwt-signing-key",
850
+ )
851
+
852
+ self._jwt_issuer: JWTIssuer = JWTIssuer(
853
+ issuer=str(self.base_url),
854
+ audience=f"{str(self.base_url).rstrip('/')}/mcp",
855
+ signing_key=jwt_signing_key,
856
+ )
857
+
858
+ # If the user does not provide a store, we will provide an encrypted disk store
859
+ if client_storage is None:
860
+ storage_encryption_key = derive_jwt_key(
861
+ high_entropy_material=jwt_signing_key.decode(),
862
+ salt="fastmcp-storage-encryption-key",
863
+ )
864
+ client_storage = FernetEncryptionWrapper(
865
+ key_value=DiskStore(directory=settings.home / "oauth-proxy"),
866
+ fernet=Fernet(key=storage_encryption_key),
867
+ )
868
+
869
+ self._client_storage: AsyncKeyValue = client_storage
870
+
871
+ # Cache HTTPS check to avoid repeated logging
872
+ self._is_https: bool = str(self.base_url).startswith("https://")
873
+ if not self._is_https:
874
+ logger.warning(
875
+ "Using non-secure cookies for development; deploy with HTTPS for production."
876
+ )
877
+
878
+ self._upstream_token_store: PydanticAdapter[UpstreamTokenSet] = PydanticAdapter[
879
+ UpstreamTokenSet
880
+ ](
881
+ key_value=self._client_storage,
882
+ pydantic_model=UpstreamTokenSet,
883
+ default_collection="mcp-upstream-tokens",
884
+ raise_on_validation_error=True,
885
+ )
886
+
887
+ self._client_store: PydanticAdapter[ProxyDCRClient] = PydanticAdapter[
888
+ ProxyDCRClient
889
+ ](
890
+ key_value=self._client_storage,
891
+ pydantic_model=ProxyDCRClient,
892
+ default_collection="mcp-oauth-proxy-clients",
893
+ raise_on_validation_error=True,
894
+ )
330
895
 
331
896
  # OAuth transaction storage for IdP callback forwarding
332
- self._oauth_transactions: dict[
333
- str, dict[str, Any]
334
- ] = {} # txn_id -> transaction_data
335
- self._client_codes: dict[str, dict[str, Any]] = {} # client_code -> code_data
897
+ # Reuse client_storage with different collections for state management
898
+ self._transaction_store: PydanticAdapter[OAuthTransaction] = PydanticAdapter[
899
+ OAuthTransaction
900
+ ](
901
+ key_value=self._client_storage,
902
+ pydantic_model=OAuthTransaction,
903
+ default_collection="mcp-oauth-transactions",
904
+ raise_on_validation_error=True,
905
+ )
906
+
907
+ self._code_store: PydanticAdapter[ClientCode] = PydanticAdapter[ClientCode](
908
+ key_value=self._client_storage,
909
+ pydantic_model=ClientCode,
910
+ default_collection="mcp-authorization-codes",
911
+ raise_on_validation_error=True,
912
+ )
913
+
914
+ # Storage for JTI mappings (FastMCP token -> upstream token)
915
+ self._jti_mapping_store: PydanticAdapter[JTIMapping] = PydanticAdapter[
916
+ JTIMapping
917
+ ](
918
+ key_value=self._client_storage,
919
+ pydantic_model=JTIMapping,
920
+ default_collection="mcp-jti-mappings",
921
+ raise_on_validation_error=True,
922
+ )
923
+
924
+ # Refresh token metadata storage, keyed by token hash for security.
925
+ # We only store metadata (not the token itself) - if storage is compromised,
926
+ # attackers get hashes they can't reverse into usable tokens.
927
+ self._refresh_token_store: PydanticAdapter[RefreshTokenMetadata] = (
928
+ PydanticAdapter[RefreshTokenMetadata](
929
+ key_value=self._client_storage,
930
+ pydantic_model=RefreshTokenMetadata,
931
+ default_collection="mcp-refresh-tokens",
932
+ raise_on_validation_error=True,
933
+ )
934
+ )
336
935
 
337
936
  # Use the provided token validator
338
- self._token_validator = token_verifier
937
+ self._token_validator: TokenVerifier = token_verifier
339
938
 
340
939
  logger.debug(
341
940
  "Initialized OAuth proxy provider with upstream server %s",
@@ -365,16 +964,23 @@ class OAuthProxy(OAuthProvider):
365
964
  # Client Registration (Local Implementation)
366
965
  # -------------------------------------------------------------------------
367
966
 
967
+ @override
368
968
  async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
369
969
  """Get client information by ID. This is generally the random ID
370
970
  provided to the DCR client during registration, not the upstream client ID.
371
971
 
372
972
  For unregistered clients, returns None (which will raise an error in the SDK).
373
973
  """
374
- client = self._clients.get(client_id)
974
+ # Load from storage
975
+ if not (client := await self._client_store.get(key=client_id)):
976
+ return None
977
+
978
+ if client.allowed_redirect_uri_patterns is None:
979
+ client.allowed_redirect_uri_patterns = self._allowed_client_redirect_uris
375
980
 
376
981
  return client
377
982
 
983
+ @override
378
984
  async def register_client(self, client_info: OAuthClientInformationFull) -> None:
379
985
  """Register a client locally
380
986
 
@@ -385,19 +991,24 @@ class OAuthProxy(OAuthProvider):
385
991
  """
386
992
 
387
993
  # Create a ProxyDCRClient with configured redirect URI validation
388
- proxy_client = ProxyDCRClient(
994
+ if client_info.client_id is None:
995
+ raise ValueError("client_id is required for client registration")
996
+ proxy_client: ProxyDCRClient = ProxyDCRClient(
389
997
  client_id=client_info.client_id,
390
998
  client_secret=client_info.client_secret,
391
999
  redirect_uris=client_info.redirect_uris or [AnyUrl("http://localhost")],
392
1000
  grant_types=client_info.grant_types
393
1001
  or ["authorization_code", "refresh_token"],
394
- scope=self._default_scope_str,
1002
+ scope=client_info.scope or self._default_scope_str,
395
1003
  token_endpoint_auth_method="none",
396
1004
  allowed_redirect_uri_patterns=self._allowed_client_redirect_uris,
1005
+ client_name=getattr(client_info, "client_name", None),
397
1006
  )
398
1007
 
399
- # Store the ProxyDCRClient
400
- self._clients[client_info.client_id] = proxy_client
1008
+ await self._client_store.put(
1009
+ key=client_info.client_id,
1010
+ value=proxy_client,
1011
+ )
401
1012
 
402
1013
  # Log redirect URIs to help users discover what patterns they might need
403
1014
  if client_info.redirect_uris:
@@ -411,25 +1022,28 @@ class OAuthProxy(OAuthProvider):
411
1022
  logger.debug(
412
1023
  "Registered client %s with %d redirect URIs",
413
1024
  client_info.client_id,
414
- len(proxy_client.redirect_uris),
1025
+ len(proxy_client.redirect_uris) if proxy_client.redirect_uris else 0,
415
1026
  )
416
1027
 
417
1028
  # -------------------------------------------------------------------------
418
1029
  # Authorization Flow (Proxy to Upstream)
419
1030
  # -------------------------------------------------------------------------
420
1031
 
1032
+ @override
421
1033
  async def authorize(
422
1034
  self,
423
1035
  client: OAuthClientInformationFull,
424
1036
  params: AuthorizationParams,
425
1037
  ) -> str:
426
- """Start OAuth transaction and redirect to upstream IdP.
1038
+ """Start OAuth transaction and route through consent interstitial.
427
1039
 
428
- This implements the DCR-compliant proxy pattern:
429
- 1. Store transaction with client details and PKCE challenge
430
- 2. Generate proxy's own PKCE parameters if forwarding is enabled
431
- 3. Use transaction ID as state for IdP
432
- 4. Redirect to IdP with our fixed callback URL and proxy's PKCE
1040
+ Flow:
1041
+ 1. Store transaction with client details and PKCE (if forwarding)
1042
+ 2. Return local /consent URL; browser visits consent first
1043
+ 3. Consent handler redirects to upstream IdP if approved/already approved
1044
+
1045
+ If consent is disabled (require_authorization_consent=False), skip the consent screen
1046
+ and redirect directly to the upstream IdP.
433
1047
  """
434
1048
  # Generate transaction ID for this authorization request
435
1049
  txn_id = secrets.token_urlsafe(32)
@@ -445,62 +1059,56 @@ class OAuthProxy(OAuthProvider):
445
1059
  )
446
1060
 
447
1061
  # Store transaction data for IdP callback processing
448
- transaction_data = {
449
- "client_id": client.client_id,
450
- "client_redirect_uri": str(params.redirect_uri),
451
- "client_state": params.state,
452
- "code_challenge": params.code_challenge,
453
- "code_challenge_method": getattr(params, "code_challenge_method", "S256"),
454
- "scopes": params.scopes or [],
455
- "created_at": time.time(),
456
- }
457
-
458
- # Store proxy's PKCE verifier if we're forwarding
459
- if proxy_code_verifier:
460
- transaction_data["proxy_code_verifier"] = proxy_code_verifier
461
-
462
- self._oauth_transactions[txn_id] = transaction_data
463
-
464
- # Build query parameters for upstream IdP authorization request
465
- # Use our fixed IdP callback and transaction ID as state
466
- query_params: dict[str, Any] = {
467
- "response_type": "code",
468
- "client_id": self._upstream_client_id,
469
- "redirect_uri": f"{str(self.base_url).rstrip('/')}{self._redirect_path}",
470
- "state": txn_id, # Use txn_id as IdP state
471
- }
472
-
473
- # Add scopes - use client scopes or fallback to required scopes
474
- scopes_to_use = params.scopes or self.required_scopes or []
475
-
476
- if scopes_to_use:
477
- query_params["scope"] = " ".join(scopes_to_use)
1062
+ if client.client_id is None:
1063
+ raise AuthorizeError(
1064
+ error="invalid_client", error_description="Client ID is required"
1065
+ )
1066
+ transaction = OAuthTransaction(
1067
+ txn_id=txn_id,
1068
+ client_id=client.client_id,
1069
+ client_redirect_uri=str(params.redirect_uri),
1070
+ client_state=params.state or "",
1071
+ code_challenge=params.code_challenge,
1072
+ code_challenge_method=getattr(params, "code_challenge_method", "S256"),
1073
+ scopes=params.scopes or [],
1074
+ created_at=time.time(),
1075
+ resource=getattr(params, "resource", None),
1076
+ proxy_code_verifier=proxy_code_verifier,
1077
+ )
1078
+ await self._transaction_store.put(
1079
+ key=txn_id,
1080
+ value=transaction,
1081
+ ttl=15 * 60, # Auto-expire after 15 minutes
1082
+ )
478
1083
 
479
- # Forward proxy's PKCE challenge to upstream if enabled
480
- if proxy_code_challenge:
481
- query_params["code_challenge"] = proxy_code_challenge
482
- query_params["code_challenge_method"] = "S256"
1084
+ # If consent is disabled, skip consent screen and go directly to upstream IdP
1085
+ if not self._require_authorization_consent:
1086
+ upstream_url = self._build_upstream_authorize_url(
1087
+ txn_id, transaction.model_dump()
1088
+ )
483
1089
  logger.debug(
484
- "Forwarding proxy PKCE challenge to upstream for transaction %s",
1090
+ "Starting OAuth transaction %s for client %s, redirecting directly to upstream IdP (consent disabled, PKCE forwarding: %s)",
485
1091
  txn_id,
1092
+ client.client_id,
1093
+ "enabled" if proxy_code_challenge else "disabled",
486
1094
  )
1095
+ return upstream_url
487
1096
 
488
- # Build the upstream authorization URL
489
- separator = "&" if "?" in self._upstream_authorization_endpoint else "?"
490
- upstream_url = f"{self._upstream_authorization_endpoint}{separator}{urlencode(query_params)}"
1097
+ consent_url = f"{str(self.base_url).rstrip('/')}/consent?txn_id={txn_id}"
491
1098
 
492
1099
  logger.debug(
493
- "Starting OAuth transaction %s for client %s, redirecting to IdP (PKCE forwarding: %s)",
1100
+ "Starting OAuth transaction %s for client %s, redirecting to consent page (PKCE forwarding: %s)",
494
1101
  txn_id,
495
1102
  client.client_id,
496
1103
  "enabled" if proxy_code_challenge else "disabled",
497
1104
  )
498
- return upstream_url
1105
+ return consent_url
499
1106
 
500
1107
  # -------------------------------------------------------------------------
501
1108
  # Authorization Code Handling
502
1109
  # -------------------------------------------------------------------------
503
1110
 
1111
+ @override
504
1112
  async def load_authorization_code(
505
1113
  self,
506
1114
  client: OAuthClientInformationFull,
@@ -512,111 +1120,248 @@ class OAuthProxy(OAuthProvider):
512
1120
  with PKCE challenge for validation.
513
1121
  """
514
1122
  # Look up client code data
515
- code_data = self._client_codes.get(authorization_code)
516
- if not code_data:
1123
+ code_model = await self._code_store.get(key=authorization_code)
1124
+ if not code_model:
517
1125
  logger.debug("Authorization code not found: %s", authorization_code)
518
1126
  return None
519
1127
 
520
1128
  # Check if code expired
521
- if time.time() > code_data["expires_at"]:
1129
+ if time.time() > code_model.expires_at:
522
1130
  logger.debug("Authorization code expired: %s", authorization_code)
523
- self._client_codes.pop(authorization_code, None)
1131
+ _ = await self._code_store.delete(key=authorization_code)
524
1132
  return None
525
1133
 
526
1134
  # Verify client ID matches
527
- if code_data["client_id"] != client.client_id:
1135
+ if code_model.client_id != client.client_id:
528
1136
  logger.debug(
529
1137
  "Authorization code client ID mismatch: %s vs %s",
530
- code_data["client_id"],
1138
+ code_model.client_id,
531
1139
  client.client_id,
532
1140
  )
533
1141
  return None
534
1142
 
535
1143
  # Create authorization code object with PKCE challenge
1144
+ if client.client_id is None:
1145
+ raise AuthorizeError(
1146
+ error="invalid_client", error_description="Client ID is required"
1147
+ )
536
1148
  return AuthorizationCode(
537
1149
  code=authorization_code,
538
1150
  client_id=client.client_id,
539
- redirect_uri=code_data["redirect_uri"],
1151
+ redirect_uri=AnyUrl(url=code_model.redirect_uri),
540
1152
  redirect_uri_provided_explicitly=True,
541
- scopes=code_data["scopes"],
542
- expires_at=code_data["expires_at"],
543
- code_challenge=code_data.get("code_challenge", ""),
1153
+ scopes=code_model.scopes,
1154
+ expires_at=code_model.expires_at,
1155
+ code_challenge=code_model.code_challenge or "",
544
1156
  )
545
1157
 
1158
+ @override
546
1159
  async def exchange_authorization_code(
547
1160
  self,
548
1161
  client: OAuthClientInformationFull,
549
1162
  authorization_code: AuthorizationCode,
550
1163
  ) -> OAuthToken:
551
- """Exchange authorization code for stored IdP tokens.
1164
+ """Exchange authorization code for FastMCP-issued tokens.
1165
+
1166
+ Implements the token factory pattern:
1167
+ 1. Retrieves upstream tokens from stored authorization code
1168
+ 2. Extracts user identity from upstream token
1169
+ 3. Encrypts and stores upstream tokens
1170
+ 4. Issues FastMCP-signed JWT tokens
1171
+ 5. Returns FastMCP tokens (NOT upstream tokens)
552
1172
 
553
- For the DCR-compliant proxy flow, we return the IdP tokens that were obtained
554
- during the IdP callback exchange. PKCE validation is handled by the MCP framework.
1173
+ PKCE validation is handled by the MCP framework before this method is called.
555
1174
  """
556
1175
  # Look up stored code data
557
- code_data = self._client_codes.get(authorization_code.code)
558
- if not code_data:
1176
+ code_model = await self._code_store.get(key=authorization_code.code)
1177
+ if not code_model:
559
1178
  logger.error(
560
1179
  "Authorization code not found in client codes: %s",
561
1180
  authorization_code.code,
562
1181
  )
563
1182
  raise TokenError("invalid_grant", "Authorization code not found")
564
1183
 
565
- # Get stored IdP tokens
566
- idp_tokens = code_data["idp_tokens"]
1184
+ # Get stored upstream tokens
1185
+ idp_tokens = code_model.idp_tokens
567
1186
 
568
1187
  # Clean up client code (one-time use)
569
- self._client_codes.pop(authorization_code.code, None)
1188
+ await self._code_store.delete(key=authorization_code.code)
570
1189
 
571
- # Extract token information for local tracking
572
- access_token_value = idp_tokens["access_token"]
573
- refresh_token_value = idp_tokens.get("refresh_token")
1190
+ # Generate IDs for token storage
1191
+ upstream_token_id = secrets.token_urlsafe(32)
1192
+ access_jti = secrets.token_urlsafe(32)
1193
+ refresh_jti = (
1194
+ secrets.token_urlsafe(32) if idp_tokens.get("refresh_token") else None
1195
+ )
1196
+
1197
+ # Calculate token expiry times
574
1198
  expires_in = int(
575
1199
  idp_tokens.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
576
1200
  )
577
- expires_at = int(time.time() + expires_in)
578
1201
 
579
- # Store access token locally for tracking
580
- access_token = AccessToken(
581
- token=access_token_value,
1202
+ # Calculate refresh token expiry if provided by upstream
1203
+ # Some providers include refresh_expires_in, some don't
1204
+ refresh_expires_in = None
1205
+ refresh_token_expires_at = None
1206
+ if idp_tokens.get("refresh_token"):
1207
+ if "refresh_expires_in" in idp_tokens:
1208
+ refresh_expires_in = int(idp_tokens["refresh_expires_in"])
1209
+ refresh_token_expires_at = time.time() + refresh_expires_in
1210
+ logger.debug(
1211
+ "Upstream refresh token expires in %d seconds", refresh_expires_in
1212
+ )
1213
+ else:
1214
+ # Default to 30 days if upstream doesn't specify
1215
+ # This is conservative - most providers use longer expiry
1216
+ refresh_expires_in = 60 * 60 * 24 * 30 # 30 days
1217
+ refresh_token_expires_at = time.time() + refresh_expires_in
1218
+ logger.debug(
1219
+ "Upstream refresh token expiry unknown, using 30-day default"
1220
+ )
1221
+
1222
+ # Encrypt and store upstream tokens
1223
+ upstream_token_set = UpstreamTokenSet(
1224
+ upstream_token_id=upstream_token_id,
1225
+ access_token=idp_tokens["access_token"],
1226
+ refresh_token=idp_tokens["refresh_token"]
1227
+ if idp_tokens.get("refresh_token")
1228
+ else None,
1229
+ refresh_token_expires_at=refresh_token_expires_at,
1230
+ expires_at=time.time() + expires_in,
1231
+ token_type=idp_tokens.get("token_type", "Bearer"),
1232
+ scope=" ".join(authorization_code.scopes),
1233
+ client_id=client.client_id or "",
1234
+ created_at=time.time(),
1235
+ raw_token_data=idp_tokens,
1236
+ )
1237
+ await self._upstream_token_store.put(
1238
+ key=upstream_token_id,
1239
+ value=upstream_token_set,
1240
+ ttl=refresh_expires_in
1241
+ or expires_in, # Auto-expire when refresh token, or access token expires
1242
+ )
1243
+ logger.debug("Stored encrypted upstream tokens (jti=%s)", access_jti[:8])
1244
+
1245
+ # Issue minimal FastMCP access token (just a reference via JTI)
1246
+ if client.client_id is None:
1247
+ raise TokenError("invalid_client", "Client ID is required")
1248
+ fastmcp_access_token = self._jwt_issuer.issue_access_token(
582
1249
  client_id=client.client_id,
583
1250
  scopes=authorization_code.scopes,
584
- expires_at=expires_at,
1251
+ jti=access_jti,
1252
+ expires_in=expires_in,
585
1253
  )
586
- self._access_tokens[access_token_value] = access_token
587
1254
 
588
- # Store refresh token if provided
589
- if refresh_token_value:
590
- refresh_token = RefreshToken(
591
- token=refresh_token_value,
1255
+ # Issue minimal FastMCP refresh token if upstream provided one
1256
+ # Use upstream refresh token expiry to align lifetimes
1257
+ fastmcp_refresh_token = None
1258
+ if refresh_jti and refresh_expires_in:
1259
+ fastmcp_refresh_token = self._jwt_issuer.issue_refresh_token(
592
1260
  client_id=client.client_id,
593
1261
  scopes=authorization_code.scopes,
594
- expires_at=None, # Refresh tokens typically don't expire
1262
+ jti=refresh_jti,
1263
+ expires_in=refresh_expires_in,
595
1264
  )
596
- self._refresh_tokens[refresh_token_value] = refresh_token
597
1265
 
598
- # Maintain token relationships for cleanup
599
- self._access_to_refresh[access_token_value] = refresh_token_value
600
- self._refresh_to_access[refresh_token_value] = access_token_value
1266
+ # Store JTI mappings
1267
+ await self._jti_mapping_store.put(
1268
+ key=access_jti,
1269
+ value=JTIMapping(
1270
+ jti=access_jti,
1271
+ upstream_token_id=upstream_token_id,
1272
+ created_at=time.time(),
1273
+ ),
1274
+ ttl=expires_in, # Auto-expire with access token
1275
+ )
1276
+ if refresh_jti:
1277
+ await self._jti_mapping_store.put(
1278
+ key=refresh_jti,
1279
+ value=JTIMapping(
1280
+ jti=refresh_jti,
1281
+ upstream_token_id=upstream_token_id,
1282
+ created_at=time.time(),
1283
+ ),
1284
+ ttl=60 * 60 * 24 * 30, # Auto-expire with refresh token (30 days)
1285
+ )
1286
+
1287
+ # Store refresh token metadata (keyed by hash for security)
1288
+ if fastmcp_refresh_token and refresh_expires_in:
1289
+ await self._refresh_token_store.put(
1290
+ key=_hash_token(fastmcp_refresh_token),
1291
+ value=RefreshTokenMetadata(
1292
+ client_id=client.client_id,
1293
+ scopes=authorization_code.scopes,
1294
+ expires_at=int(time.time()) + refresh_expires_in,
1295
+ created_at=time.time(),
1296
+ ),
1297
+ ttl=refresh_expires_in,
1298
+ )
601
1299
 
602
1300
  logger.debug(
603
- "Successfully exchanged client code for stored IdP tokens (client: %s)",
1301
+ "Issued FastMCP tokens for client=%s (access_jti=%s, refresh_jti=%s)",
604
1302
  client.client_id,
1303
+ access_jti[:8],
1304
+ refresh_jti[:8] if refresh_jti else "none",
605
1305
  )
606
1306
 
607
- return OAuthToken(**idp_tokens) # type: ignore[arg-type]
1307
+ # Return FastMCP-issued tokens (NOT upstream tokens!)
1308
+ return OAuthToken(
1309
+ access_token=fastmcp_access_token,
1310
+ token_type="Bearer",
1311
+ expires_in=expires_in,
1312
+ refresh_token=fastmcp_refresh_token,
1313
+ scope=" ".join(authorization_code.scopes),
1314
+ )
608
1315
 
609
1316
  # -------------------------------------------------------------------------
610
1317
  # Refresh Token Flow
611
1318
  # -------------------------------------------------------------------------
612
1319
 
1320
+ def _prepare_scopes_for_upstream_refresh(self, scopes: list[str]) -> list[str]:
1321
+ """Prepare scopes for upstream token refresh request.
1322
+
1323
+ Override this method to transform scopes before sending to upstream provider.
1324
+ For example, Azure needs to prefix scopes and add additional Graph scopes.
1325
+
1326
+ The scopes parameter represents what should be stored in the RefreshToken.
1327
+ This method returns what should be sent to the upstream provider.
1328
+
1329
+ Args:
1330
+ scopes: Base scopes that will be stored in RefreshToken
1331
+
1332
+ Returns:
1333
+ Scopes to send to upstream provider (may be transformed/augmented)
1334
+ """
1335
+ return scopes
1336
+
613
1337
  async def load_refresh_token(
614
1338
  self,
615
1339
  client: OAuthClientInformationFull,
616
1340
  refresh_token: str,
617
1341
  ) -> RefreshToken | None:
618
- """Load refresh token from local storage."""
619
- return self._refresh_tokens.get(refresh_token)
1342
+ """Load refresh token metadata from distributed storage.
1343
+
1344
+ Looks up by token hash and reconstructs the RefreshToken object.
1345
+ Validates that the token belongs to the requesting client.
1346
+ """
1347
+ token_hash = _hash_token(refresh_token)
1348
+ metadata = await self._refresh_token_store.get(key=token_hash)
1349
+ if not metadata:
1350
+ return None
1351
+ # Verify token belongs to this client (prevents cross-client token usage)
1352
+ if metadata.client_id != client.client_id:
1353
+ logger.warning(
1354
+ "Refresh token client_id mismatch: expected %s, got %s",
1355
+ client.client_id,
1356
+ metadata.client_id,
1357
+ )
1358
+ return None
1359
+ return RefreshToken(
1360
+ token=refresh_token,
1361
+ client_id=metadata.client_id,
1362
+ scopes=metadata.scopes,
1363
+ expires_at=metadata.expires_at,
1364
+ )
620
1365
 
621
1366
  async def exchange_refresh_token(
622
1367
  self,
@@ -624,9 +1369,45 @@ class OAuthProxy(OAuthProvider):
624
1369
  refresh_token: RefreshToken,
625
1370
  scopes: list[str],
626
1371
  ) -> OAuthToken:
627
- """Exchange refresh token for new access token using authlib."""
1372
+ """Exchange FastMCP refresh token for new FastMCP access token.
1373
+
1374
+ Implements two-tier refresh:
1375
+ 1. Verify FastMCP refresh token
1376
+ 2. Look up upstream token via JTI mapping
1377
+ 3. Refresh upstream token with upstream provider
1378
+ 4. Update stored upstream token
1379
+ 5. Issue new FastMCP access token
1380
+ 6. Keep same FastMCP refresh token (unless upstream rotates)
1381
+ """
1382
+ # Verify FastMCP refresh token
1383
+ try:
1384
+ refresh_payload = self._jwt_issuer.verify_token(refresh_token.token)
1385
+ refresh_jti = refresh_payload["jti"]
1386
+ except Exception as e:
1387
+ logger.debug("FastMCP refresh token validation failed: %s", e)
1388
+ raise TokenError("invalid_grant", "Invalid refresh token") from e
1389
+
1390
+ # Look up upstream token via JTI mapping
1391
+ jti_mapping = await self._jti_mapping_store.get(key=refresh_jti)
1392
+ if not jti_mapping:
1393
+ logger.error("JTI mapping not found for refresh token: %s", refresh_jti[:8])
1394
+ raise TokenError("invalid_grant", "Refresh token mapping not found")
628
1395
 
629
- # Use authlib's AsyncOAuth2Client for refresh token exchange
1396
+ upstream_token_set = await self._upstream_token_store.get(
1397
+ key=jti_mapping.upstream_token_id
1398
+ )
1399
+ if not upstream_token_set:
1400
+ logger.error(
1401
+ "Upstream token set not found: %s", jti_mapping.upstream_token_id[:8]
1402
+ )
1403
+ raise TokenError("invalid_grant", "Upstream token not found")
1404
+
1405
+ # Decrypt upstream refresh token
1406
+ if not upstream_token_set.refresh_token:
1407
+ logger.error("No upstream refresh token available")
1408
+ raise TokenError("invalid_grant", "Refresh not supported for this token")
1409
+
1410
+ # Refresh upstream token using authlib
630
1411
  oauth_client = AsyncOAuth2Client(
631
1412
  client_id=self._upstream_client_id,
632
1413
  client_secret=self._upstream_client_secret.get_secret_value(),
@@ -634,78 +1415,209 @@ class OAuthProxy(OAuthProvider):
634
1415
  timeout=HTTP_TIMEOUT_SECONDS,
635
1416
  )
636
1417
 
637
- try:
638
- logger.debug("Using authlib to refresh token from upstream")
1418
+ # Allow child classes to transform scopes before sending to upstream
1419
+ # This enables provider-specific scope formatting (e.g., Azure prefixing)
1420
+ # while keeping original scopes in storage
1421
+ upstream_scopes = self._prepare_scopes_for_upstream_refresh(scopes)
639
1422
 
640
- # Let authlib handle the refresh token exchange
1423
+ try:
1424
+ logger.debug("Refreshing upstream token (jti=%s)", refresh_jti[:8])
641
1425
  token_response: dict[str, Any] = await oauth_client.refresh_token( # type: ignore[misc]
642
1426
  url=self._upstream_token_endpoint,
643
- refresh_token=refresh_token.token,
644
- scope=" ".join(scopes) if scopes else None,
1427
+ refresh_token=upstream_token_set.refresh_token,
1428
+ scope=" ".join(upstream_scopes) if upstream_scopes else None,
1429
+ **self._extra_token_params,
645
1430
  )
646
-
647
- logger.debug(
648
- "Successfully refreshed access token via authlib (client: %s)",
649
- client.client_id,
650
- )
651
-
1431
+ logger.debug("Successfully refreshed upstream token")
652
1432
  except Exception as e:
653
- logger.error("Authlib refresh token exchange failed: %s", e)
654
- raise TokenError(
655
- "invalid_grant", f"Upstream refresh token exchange failed: {e}"
656
- ) from e
1433
+ logger.error("Upstream token refresh failed: %s", e)
1434
+ raise TokenError("invalid_grant", f"Upstream refresh failed: {e}") from e
657
1435
 
658
- # Update local token storage
659
- new_access_token = token_response["access_token"]
660
- expires_in = int(
1436
+ # Update stored upstream token
1437
+ new_expires_in = int(
661
1438
  token_response.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
662
1439
  )
1440
+ upstream_token_set.access_token = token_response["access_token"]
1441
+ upstream_token_set.expires_at = time.time() + new_expires_in
1442
+
1443
+ # Handle upstream refresh token rotation and expiry
1444
+ new_refresh_expires_in = None
1445
+ if new_upstream_refresh := token_response.get("refresh_token"):
1446
+ if new_upstream_refresh != upstream_token_set.refresh_token:
1447
+ upstream_token_set.refresh_token = new_upstream_refresh
1448
+ logger.debug("Upstream refresh token rotated")
1449
+
1450
+ # Update refresh token expiry if provided
1451
+ if "refresh_expires_in" in token_response:
1452
+ new_refresh_expires_in = int(token_response["refresh_expires_in"])
1453
+ upstream_token_set.refresh_token_expires_at = (
1454
+ time.time() + new_refresh_expires_in
1455
+ )
1456
+ logger.debug(
1457
+ "Upstream refresh token expires in %d seconds",
1458
+ new_refresh_expires_in,
1459
+ )
1460
+ elif upstream_token_set.refresh_token_expires_at:
1461
+ # Keep existing expiry if upstream doesn't provide new one
1462
+ new_refresh_expires_in = int(
1463
+ upstream_token_set.refresh_token_expires_at - time.time()
1464
+ )
1465
+ else:
1466
+ # Default to 30 days if unknown
1467
+ new_refresh_expires_in = 60 * 60 * 24 * 30
1468
+ upstream_token_set.refresh_token_expires_at = (
1469
+ time.time() + new_refresh_expires_in
1470
+ )
663
1471
 
664
- self._access_tokens[new_access_token] = AccessToken(
665
- token=new_access_token,
1472
+ upstream_token_set.raw_token_data = token_response
1473
+ await self._upstream_token_store.put(
1474
+ key=upstream_token_set.upstream_token_id,
1475
+ value=upstream_token_set,
1476
+ ttl=new_refresh_expires_in
1477
+ or (
1478
+ int(upstream_token_set.refresh_token_expires_at - time.time())
1479
+ if upstream_token_set.refresh_token_expires_at
1480
+ else 60 * 60 * 24 * 30 # Default to 30 days if unknown
1481
+ ), # Auto-expire when refresh token expires
1482
+ )
1483
+
1484
+ # Issue new minimal FastMCP access token (just a reference via JTI)
1485
+ if client.client_id is None:
1486
+ raise TokenError("invalid_client", "Client ID is required")
1487
+ new_access_jti = secrets.token_urlsafe(32)
1488
+ new_fastmcp_access = self._jwt_issuer.issue_access_token(
666
1489
  client_id=client.client_id,
667
1490
  scopes=scopes,
668
- expires_at=int(time.time() + expires_in),
669
- )
670
-
671
- # Handle refresh token rotation if new one provided
672
- if "refresh_token" in token_response:
673
- new_refresh_token = token_response["refresh_token"]
674
- if new_refresh_token != refresh_token.token:
675
- # Remove old refresh token
676
- self._refresh_tokens.pop(refresh_token.token, None)
677
- old_access = self._refresh_to_access.pop(refresh_token.token, None)
678
- if old_access:
679
- self._access_to_refresh.pop(old_access, None)
680
-
681
- # Store new refresh token
682
- self._refresh_tokens[new_refresh_token] = RefreshToken(
683
- token=new_refresh_token,
684
- client_id=client.client_id,
685
- scopes=scopes,
686
- expires_at=None,
687
- )
688
- self._access_to_refresh[new_access_token] = new_refresh_token
689
- self._refresh_to_access[new_refresh_token] = new_access_token
1491
+ jti=new_access_jti,
1492
+ expires_in=new_expires_in,
1493
+ )
1494
+
1495
+ # Store new access token JTI mapping
1496
+ await self._jti_mapping_store.put(
1497
+ key=new_access_jti,
1498
+ value=JTIMapping(
1499
+ jti=new_access_jti,
1500
+ upstream_token_id=upstream_token_set.upstream_token_id,
1501
+ created_at=time.time(),
1502
+ ),
1503
+ ttl=new_expires_in, # Auto-expire with refreshed access token
1504
+ )
690
1505
 
691
- return OAuthToken(**token_response) # type: ignore[arg-type]
1506
+ # Issue NEW minimal FastMCP refresh token (rotation for security)
1507
+ # Use upstream refresh token expiry to align lifetimes
1508
+ new_refresh_jti = secrets.token_urlsafe(32)
1509
+ new_fastmcp_refresh = self._jwt_issuer.issue_refresh_token(
1510
+ client_id=client.client_id,
1511
+ scopes=scopes,
1512
+ jti=new_refresh_jti,
1513
+ expires_in=new_refresh_expires_in
1514
+ or 60 * 60 * 24 * 30, # Fallback to 30 days
1515
+ )
1516
+
1517
+ # Store new refresh token JTI mapping with aligned expiry
1518
+ refresh_ttl = new_refresh_expires_in or 60 * 60 * 24 * 30
1519
+ await self._jti_mapping_store.put(
1520
+ key=new_refresh_jti,
1521
+ value=JTIMapping(
1522
+ jti=new_refresh_jti,
1523
+ upstream_token_id=upstream_token_set.upstream_token_id,
1524
+ created_at=time.time(),
1525
+ ),
1526
+ ttl=refresh_ttl, # Align with upstream refresh token expiry
1527
+ )
1528
+
1529
+ # Invalidate old refresh token (refresh token rotation - enforces one-time use)
1530
+ await self._jti_mapping_store.delete(key=refresh_jti)
1531
+ logger.debug(
1532
+ "Rotated refresh token (old JTI invalidated - one-time use enforced)"
1533
+ )
1534
+
1535
+ # Store new refresh token metadata (keyed by hash)
1536
+ await self._refresh_token_store.put(
1537
+ key=_hash_token(new_fastmcp_refresh),
1538
+ value=RefreshTokenMetadata(
1539
+ client_id=client.client_id,
1540
+ scopes=scopes,
1541
+ expires_at=int(time.time()) + refresh_ttl,
1542
+ created_at=time.time(),
1543
+ ),
1544
+ ttl=refresh_ttl,
1545
+ )
1546
+
1547
+ # Delete old refresh token (by hash)
1548
+ await self._refresh_token_store.delete(key=_hash_token(refresh_token.token))
1549
+
1550
+ logger.info(
1551
+ "Issued new FastMCP tokens (rotated refresh) for client=%s (access_jti=%s, refresh_jti=%s)",
1552
+ client.client_id,
1553
+ new_access_jti[:8],
1554
+ new_refresh_jti[:8],
1555
+ )
1556
+
1557
+ # Return new FastMCP tokens (both access AND refresh are new)
1558
+ return OAuthToken(
1559
+ access_token=new_fastmcp_access,
1560
+ token_type="Bearer",
1561
+ expires_in=new_expires_in,
1562
+ refresh_token=new_fastmcp_refresh, # NEW refresh token (rotated)
1563
+ scope=" ".join(scopes),
1564
+ )
692
1565
 
693
1566
  # -------------------------------------------------------------------------
694
1567
  # Token Validation
695
1568
  # -------------------------------------------------------------------------
696
1569
 
697
1570
  async def load_access_token(self, token: str) -> AccessToken | None:
698
- """Validate access token using upstream JWKS.
1571
+ """Validate FastMCP JWT by swapping for upstream token.
1572
+
1573
+ This implements the token swap pattern:
1574
+ 1. Verify FastMCP JWT signature (proves it's our token)
1575
+ 2. Look up upstream token via JTI mapping
1576
+ 3. Decrypt upstream token
1577
+ 4. Validate upstream token with provider (GitHub API, JWT validation, etc.)
1578
+ 5. Return upstream validation result
699
1579
 
700
- Delegates to the JWT verifier which handles signature validation,
701
- expiration checking, and claims validation using the upstream JWKS.
1580
+ The FastMCP JWT is a reference token - all authorization data comes
1581
+ from validating the upstream token via the TokenVerifier.
702
1582
  """
703
- result = await self._token_validator.verify_token(token)
704
- if result:
705
- logger.debug("Token validated successfully")
706
- else:
707
- logger.debug("Token validation failed")
708
- return result
1583
+ try:
1584
+ # 1. Verify FastMCP JWT signature and claims
1585
+ payload = self._jwt_issuer.verify_token(token)
1586
+ jti = payload["jti"]
1587
+
1588
+ # 2. Look up upstream token via JTI mapping
1589
+ jti_mapping = await self._jti_mapping_store.get(key=jti)
1590
+ if not jti_mapping:
1591
+ logger.debug("JTI mapping not found: %s", jti)
1592
+ return None
1593
+
1594
+ upstream_token_set = await self._upstream_token_store.get(
1595
+ key=jti_mapping.upstream_token_id
1596
+ )
1597
+ if not upstream_token_set:
1598
+ logger.debug(
1599
+ "Upstream token not found: %s", jti_mapping.upstream_token_id
1600
+ )
1601
+ return None
1602
+
1603
+ # 3. Validate with upstream provider (delegated to TokenVerifier)
1604
+ # This calls the real token validator (GitHub API, JWKS, etc.)
1605
+ validated = await self._token_validator.verify_token(
1606
+ upstream_token_set.access_token
1607
+ )
1608
+
1609
+ if not validated:
1610
+ logger.debug("Upstream token validation failed")
1611
+ return None
1612
+
1613
+ logger.debug(
1614
+ "Token swap successful for JTI=%s (upstream validated)", jti[:8]
1615
+ )
1616
+ return validated
1617
+
1618
+ except Exception as e:
1619
+ logger.debug("Token swap validation failed: %s", e)
1620
+ return None
709
1621
 
710
1622
  # -------------------------------------------------------------------------
711
1623
  # Token Revocation
@@ -714,24 +1626,13 @@ class OAuthProxy(OAuthProvider):
714
1626
  async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
715
1627
  """Revoke token locally and with upstream server if supported.
716
1628
 
717
- Removes tokens from local storage and attempts to revoke them with
718
- the upstream server if a revocation endpoint is configured.
1629
+ For refresh tokens, removes from local storage by hash.
1630
+ For all tokens, attempts upstream revocation if endpoint is configured.
1631
+ Access token JTI mappings expire via TTL.
719
1632
  """
720
- # Clean up local token storage
721
- if isinstance(token, AccessToken):
722
- self._access_tokens.pop(token.token, None)
723
- # Also remove associated refresh token
724
- paired_refresh = self._access_to_refresh.pop(token.token, None)
725
- if paired_refresh:
726
- self._refresh_tokens.pop(paired_refresh, None)
727
- self._refresh_to_access.pop(paired_refresh, None)
728
- else: # RefreshToken
729
- self._refresh_tokens.pop(token.token, None)
730
- # Also remove associated access token
731
- paired_access = self._refresh_to_access.pop(token.token, None)
732
- if paired_access:
733
- self._access_tokens.pop(paired_access, None)
734
- self._access_to_refresh.pop(paired_access, None)
1633
+ # For refresh tokens, delete from local storage by hash
1634
+ if isinstance(token, RefreshToken):
1635
+ await self._refresh_token_store.delete(key=_hash_token(token.token))
735
1636
 
736
1637
  # Attempt upstream revocation if endpoint is configured
737
1638
  if self._upstream_revocation_endpoint:
@@ -758,21 +1659,22 @@ class OAuthProxy(OAuthProvider):
758
1659
  def get_routes(
759
1660
  self,
760
1661
  mcp_path: str | None = None,
761
- mcp_endpoint: Any | None = None,
762
1662
  ) -> list[Route]:
763
- """Get OAuth routes with custom proxy token handler.
1663
+ """Get OAuth routes with custom handlers for better error UX.
764
1664
 
765
- This method creates standard OAuth routes and replaces the token endpoint
766
- with our proxy handler that forwards requests to the upstream OAuth server.
1665
+ This method creates standard OAuth routes and replaces:
1666
+ - /authorize endpoint: Enhanced error responses for unregistered clients
1667
+ - /token endpoint: OAuth 2.1 compliant error codes
767
1668
 
768
1669
  Args:
769
1670
  mcp_path: The path where the MCP endpoint is mounted (e.g., "/mcp")
770
- mcp_endpoint: The MCP endpoint handler to protect with auth
1671
+ This is used to advertise the resource URL in metadata.
771
1672
  """
772
1673
  # Get standard OAuth routes from parent class
773
- routes = super().get_routes(mcp_path, mcp_endpoint)
1674
+ routes = super().get_routes(mcp_path)
774
1675
  custom_routes = []
775
1676
  token_route_found = False
1677
+ authorize_route_found = False
776
1678
 
777
1679
  logger.debug(
778
1680
  f"get_routes called - configuring OAuth routes in {len(routes)} routes"
@@ -783,16 +1685,53 @@ class OAuthProxy(OAuthProvider):
783
1685
  f"Route {i}: {route} - path: {getattr(route, 'path', 'N/A')}, methods: {getattr(route, 'methods', 'N/A')}"
784
1686
  )
785
1687
 
786
- # Keep all standard OAuth routes unchanged - our DCR-compliant flow handles everything
787
- custom_routes.append(route)
788
-
1688
+ # Replace the authorize endpoint with our enhanced handler for better error UX
789
1689
  if (
1690
+ isinstance(route, Route)
1691
+ and route.path == "/authorize"
1692
+ and route.methods is not None
1693
+ and ("GET" in route.methods or "POST" in route.methods)
1694
+ ):
1695
+ authorize_route_found = True
1696
+ # Replace with our enhanced authorization handler
1697
+ # Note: self.base_url is guaranteed to be set in parent __init__
1698
+ authorize_handler = AuthorizationHandler(
1699
+ provider=self,
1700
+ base_url=self.base_url, # ty: ignore[invalid-argument-type]
1701
+ server_name=None, # Could be extended to pass server metadata
1702
+ server_icon_url=None,
1703
+ )
1704
+ custom_routes.append(
1705
+ Route(
1706
+ path="/authorize",
1707
+ endpoint=authorize_handler.handle,
1708
+ methods=["GET", "POST"],
1709
+ )
1710
+ )
1711
+ # Replace the token endpoint with our custom handler that returns proper OAuth 2.1 error codes
1712
+ elif (
790
1713
  isinstance(route, Route)
791
1714
  and route.path == "/token"
792
1715
  and route.methods is not None
793
1716
  and "POST" in route.methods
794
1717
  ):
795
1718
  token_route_found = True
1719
+ # Replace with our OAuth 2.1 compliant token handler
1720
+ token_handler = TokenHandler(
1721
+ provider=self, client_authenticator=ClientAuthenticator(self)
1722
+ )
1723
+ custom_routes.append(
1724
+ Route(
1725
+ path="/token",
1726
+ endpoint=cors_middleware(
1727
+ token_handler.handle, ["POST", "OPTIONS"]
1728
+ ),
1729
+ methods=["POST", "OPTIONS"],
1730
+ )
1731
+ )
1732
+ else:
1733
+ # Keep all other standard OAuth routes unchanged
1734
+ custom_routes.append(route)
796
1735
 
797
1736
  # Add OAuth callback endpoint for forwarding to client callbacks
798
1737
  custom_routes.append(
@@ -803,8 +1742,16 @@ class OAuthProxy(OAuthProvider):
803
1742
  )
804
1743
  )
805
1744
 
1745
+ # Add consent endpoints
1746
+ # Handle both GET (show page) and POST (submit) at /consent
1747
+ custom_routes.append(
1748
+ Route(
1749
+ path="/consent", endpoint=self._handle_consent, methods=["GET", "POST"]
1750
+ )
1751
+ )
1752
+
806
1753
  logger.debug(
807
- f"✅ OAuth routes configured: token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback)"
1754
+ f"✅ OAuth routes configured: authorize_endpoint={authorize_route_found}, token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback + consent)"
808
1755
  )
809
1756
  return custom_routes
810
1757
 
@@ -812,7 +1759,9 @@ class OAuthProxy(OAuthProvider):
812
1759
  # IdP Callback Forwarding
813
1760
  # -------------------------------------------------------------------------
814
1761
 
815
- async def _handle_idp_callback(self, request: Request) -> RedirectResponse:
1762
+ async def _handle_idp_callback(
1763
+ self, request: Request
1764
+ ) -> HTMLResponse | RedirectResponse:
816
1765
  """Handle callback from upstream IdP and forward to client.
817
1766
 
818
1767
  This implements the DCR-compliant callback forwarding:
@@ -827,32 +1776,38 @@ class OAuthProxy(OAuthProvider):
827
1776
  error = request.query_params.get("error")
828
1777
 
829
1778
  if error:
1779
+ error_description = request.query_params.get("error_description")
830
1780
  logger.error(
831
1781
  "IdP callback error: %s - %s",
832
1782
  error,
833
- request.query_params.get("error_description"),
1783
+ error_description,
834
1784
  )
835
- # TODO: Forward error to client callback
836
- return RedirectResponse(
837
- url=f"data:text/html,<h1>OAuth Error</h1><p>{error}: {request.query_params.get('error_description', 'Unknown error')}</p>",
838
- status_code=302,
1785
+ # Show error page to user
1786
+ html_content = create_error_html(
1787
+ error_title="OAuth Error",
1788
+ error_message=f"Authentication failed: {error_description or 'Unknown error'}",
1789
+ error_details={"Error Code": error} if error else None,
839
1790
  )
1791
+ return HTMLResponse(content=html_content, status_code=400)
840
1792
 
841
1793
  if not idp_code or not txn_id:
842
1794
  logger.error("IdP callback missing code or transaction ID")
843
- return RedirectResponse(
844
- url="data:text/html,<h1>OAuth Error</h1><p>Missing authorization code or transaction ID</p>",
845
- status_code=302,
1795
+ html_content = create_error_html(
1796
+ error_title="OAuth Error",
1797
+ error_message="Missing authorization code or transaction ID from the identity provider.",
846
1798
  )
1799
+ return HTMLResponse(content=html_content, status_code=400)
847
1800
 
848
1801
  # Look up transaction data
849
- transaction = self._oauth_transactions.get(txn_id)
850
- if not transaction:
1802
+ transaction_model = await self._transaction_store.get(key=txn_id)
1803
+ if not transaction_model:
851
1804
  logger.error("IdP callback with invalid transaction ID: %s", txn_id)
852
- return RedirectResponse(
853
- url="data:text/html,<h1>OAuth Error</h1><p>Invalid or expired transaction</p>",
854
- status_code=302,
1805
+ html_content = create_error_html(
1806
+ error_title="OAuth Error",
1807
+ error_message="Invalid or expired authorization transaction. Please try authenticating again.",
855
1808
  )
1809
+ return HTMLResponse(content=html_content, status_code=400)
1810
+ transaction = transaction_model.model_dump()
856
1811
 
857
1812
  # Exchange IdP code for tokens (server-side)
858
1813
  oauth_client = AsyncOAuth2Client(
@@ -870,56 +1825,70 @@ class OAuthProxy(OAuthProvider):
870
1825
  f"Exchanging IdP code for tokens with redirect_uri: {idp_redirect_uri}"
871
1826
  )
872
1827
 
1828
+ # Build token exchange parameters
1829
+ token_params = {
1830
+ "url": self._upstream_token_endpoint,
1831
+ "code": idp_code,
1832
+ "redirect_uri": idp_redirect_uri,
1833
+ }
1834
+
873
1835
  # Include proxy's code_verifier if we forwarded PKCE
874
1836
  proxy_code_verifier = transaction.get("proxy_code_verifier")
875
1837
  if proxy_code_verifier:
1838
+ token_params["code_verifier"] = proxy_code_verifier
876
1839
  logger.debug(
877
1840
  "Including proxy code_verifier in token exchange for transaction %s",
878
1841
  txn_id,
879
1842
  )
880
- idp_tokens: dict[str, Any] = await oauth_client.fetch_token( # type: ignore[misc]
881
- url=self._upstream_token_endpoint,
882
- code=idp_code,
883
- redirect_uri=idp_redirect_uri,
884
- code_verifier=proxy_code_verifier,
885
- )
886
- else:
887
- idp_tokens: dict[str, Any] = await oauth_client.fetch_token( # type: ignore[misc]
888
- url=self._upstream_token_endpoint,
889
- code=idp_code,
890
- redirect_uri=idp_redirect_uri,
1843
+
1844
+ # Add any extra token parameters configured for this proxy
1845
+ if self._extra_token_params:
1846
+ token_params.update(self._extra_token_params)
1847
+ logger.debug(
1848
+ "Adding extra token parameters for transaction %s: %s",
1849
+ txn_id,
1850
+ list(self._extra_token_params.keys()),
891
1851
  )
892
1852
 
1853
+ idp_tokens: dict[str, Any] = await oauth_client.fetch_token(
1854
+ **token_params
1855
+ ) # type: ignore[misc]
1856
+
893
1857
  logger.debug(
894
1858
  f"Successfully exchanged IdP code for tokens (transaction: {txn_id}, PKCE: {bool(proxy_code_verifier)})"
895
1859
  )
896
1860
 
897
1861
  except Exception as e:
898
1862
  logger.error("IdP token exchange failed: %s", e)
899
- # TODO: Forward error to client callback
900
- return RedirectResponse(
901
- url=f"data:text/html,<h1>OAuth Error</h1><p>Token exchange failed: {e}</p>",
902
- status_code=302,
1863
+ html_content = create_error_html(
1864
+ error_title="OAuth Error",
1865
+ error_message=f"Token exchange with identity provider failed: {e}",
903
1866
  )
1867
+ return HTMLResponse(content=html_content, status_code=500)
904
1868
 
905
1869
  # Generate our own authorization code for the client
906
1870
  client_code = secrets.token_urlsafe(32)
907
1871
  code_expires_at = int(time.time() + DEFAULT_AUTH_CODE_EXPIRY_SECONDS)
908
1872
 
909
1873
  # Store client code with PKCE challenge and IdP tokens
910
- self._client_codes[client_code] = {
911
- "client_id": transaction["client_id"],
912
- "redirect_uri": transaction["client_redirect_uri"],
913
- "code_challenge": transaction["code_challenge"],
914
- "code_challenge_method": transaction["code_challenge_method"],
915
- "scopes": transaction["scopes"],
916
- "idp_tokens": idp_tokens,
917
- "expires_at": code_expires_at,
918
- "created_at": time.time(),
919
- }
1874
+ await self._code_store.put(
1875
+ key=client_code,
1876
+ value=ClientCode(
1877
+ code=client_code,
1878
+ client_id=transaction["client_id"],
1879
+ redirect_uri=transaction["client_redirect_uri"],
1880
+ code_challenge=transaction["code_challenge"],
1881
+ code_challenge_method=transaction["code_challenge_method"],
1882
+ scopes=transaction["scopes"],
1883
+ idp_tokens=idp_tokens,
1884
+ expires_at=code_expires_at,
1885
+ created_at=time.time(),
1886
+ ),
1887
+ ttl=DEFAULT_AUTH_CODE_EXPIRY_SECONDS, # Auto-expire after 5 minutes
1888
+ )
920
1889
 
921
1890
  # Clean up transaction
922
- self._oauth_transactions.pop(txn_id, None)
1891
+ await self._transaction_store.delete(key=txn_id)
923
1892
 
924
1893
  # Build client callback URL with our code and original state
925
1894
  client_redirect_uri = transaction["client_redirect_uri"]
@@ -942,7 +1911,329 @@ class OAuthProxy(OAuthProvider):
942
1911
 
943
1912
  except Exception as e:
944
1913
  logger.error("Error in IdP callback handler: %s", e, exc_info=True)
1914
+ html_content = create_error_html(
1915
+ error_title="OAuth Error",
1916
+ error_message="Internal server error during OAuth callback processing. Please try again.",
1917
+ )
1918
+ return HTMLResponse(content=html_content, status_code=500)
1919
+
1920
+ # -------------------------------------------------------------------------
1921
+ # Consent Interstitial
1922
+ # -------------------------------------------------------------------------
1923
+
1924
+ def _normalize_uri(self, uri: str) -> str:
1925
+ """Normalize a URI to a canonical form for consent tracking."""
1926
+ parsed = urlparse(uri)
1927
+ path = parsed.path or ""
1928
+ normalized = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}{path}"
1929
+ if normalized.endswith("/") and len(path) > 1:
1930
+ normalized = normalized[:-1]
1931
+ return normalized
1932
+
1933
+ def _make_client_key(self, client_id: str, redirect_uri: str | AnyUrl) -> str:
1934
+ """Create a stable key for consent tracking from client_id and redirect_uri."""
1935
+ normalized = self._normalize_uri(str(redirect_uri))
1936
+ return f"{client_id}:{normalized}"
1937
+
1938
+ def _cookie_name(self, base_name: str) -> str:
1939
+ """Return secure cookie name for HTTPS, fallback for HTTP development."""
1940
+ if self._is_https:
1941
+ return f"__Host-{base_name}"
1942
+ return f"__{base_name}"
1943
+
1944
+ def _sign_cookie(self, payload: str) -> str:
1945
+ """Sign a cookie payload with HMAC-SHA256.
1946
+
1947
+ Returns: base64(payload).base64(signature)
1948
+ """
1949
+ # Use upstream client secret as signing key
1950
+ key = self._upstream_client_secret.get_secret_value().encode()
1951
+ signature = hmac.new(key, payload.encode(), hashlib.sha256).digest()
1952
+ signature_b64 = base64.b64encode(signature).decode()
1953
+ return f"{payload}.{signature_b64}"
1954
+
1955
+ def _verify_cookie(self, signed_value: str) -> str | None:
1956
+ """Verify and extract payload from signed cookie.
1957
+
1958
+ Returns: payload if signature valid, None otherwise
1959
+ """
1960
+ try:
1961
+ if "." not in signed_value:
1962
+ return None
1963
+ payload, signature_b64 = signed_value.rsplit(".", 1)
1964
+
1965
+ # Verify signature
1966
+ key = self._upstream_client_secret.get_secret_value().encode()
1967
+ expected_sig = hmac.new(key, payload.encode(), hashlib.sha256).digest()
1968
+ provided_sig = base64.b64decode(signature_b64.encode())
1969
+
1970
+ # Constant-time comparison
1971
+ if not hmac.compare_digest(expected_sig, provided_sig):
1972
+ return None
1973
+
1974
+ return payload
1975
+ except Exception:
1976
+ return None
1977
+
1978
+ def _decode_list_cookie(self, request: Request, base_name: str) -> list[str]:
1979
+ """Decode and verify a signed base64-encoded JSON list from cookie. Returns [] if missing/invalid."""
1980
+ # Prefer secure name, but also check non-secure variant for dev
1981
+ secure_name = self._cookie_name(base_name)
1982
+ raw = request.cookies.get(secure_name) or request.cookies.get(f"__{base_name}")
1983
+ if not raw:
1984
+ return []
1985
+ try:
1986
+ # Verify signature
1987
+ payload = self._verify_cookie(raw)
1988
+ if not payload:
1989
+ logger.debug("Cookie signature verification failed for %s", secure_name)
1990
+ return []
1991
+
1992
+ # Decode payload
1993
+ data = base64.b64decode(payload.encode())
1994
+ value = json.loads(data.decode())
1995
+ if isinstance(value, list):
1996
+ return [str(x) for x in value]
1997
+ except Exception:
1998
+ logger.debug("Failed to decode cookie %s; treating as empty", secure_name)
1999
+ return []
2000
+
2001
+ def _encode_list_cookie(self, values: list[str]) -> str:
2002
+ """Encode values to base64 and sign with HMAC.
2003
+
2004
+ Returns: signed cookie value (payload.signature)
2005
+ """
2006
+ payload = json.dumps(values, separators=(",", ":")).encode()
2007
+ payload_b64 = base64.b64encode(payload).decode()
2008
+ return self._sign_cookie(payload_b64)
2009
+
2010
+ def _set_list_cookie(
2011
+ self,
2012
+ response: HTMLResponse | RedirectResponse,
2013
+ base_name: str,
2014
+ value_b64: str,
2015
+ max_age: int,
2016
+ ) -> None:
2017
+ name = self._cookie_name(base_name)
2018
+ response.set_cookie(
2019
+ name,
2020
+ value_b64,
2021
+ max_age=max_age,
2022
+ secure=self._is_https,
2023
+ httponly=True,
2024
+ samesite="lax",
2025
+ path="/",
2026
+ )
2027
+
2028
+ def _build_upstream_authorize_url(
2029
+ self, txn_id: str, transaction: dict[str, Any]
2030
+ ) -> str:
2031
+ """Construct the upstream IdP authorization URL using stored transaction data."""
2032
+ query_params: dict[str, Any] = {
2033
+ "response_type": "code",
2034
+ "client_id": self._upstream_client_id,
2035
+ "redirect_uri": f"{str(self.base_url).rstrip('/')}{self._redirect_path}",
2036
+ "state": txn_id,
2037
+ }
2038
+
2039
+ scopes_to_use = transaction.get("scopes") or self.required_scopes or []
2040
+ if scopes_to_use:
2041
+ query_params["scope"] = " ".join(scopes_to_use)
2042
+
2043
+ # If PKCE forwarding was enabled, include the proxy challenge
2044
+ proxy_code_verifier = transaction.get("proxy_code_verifier")
2045
+ if proxy_code_verifier:
2046
+ challenge_bytes = hashlib.sha256(proxy_code_verifier.encode()).digest()
2047
+ proxy_code_challenge = (
2048
+ urlsafe_b64encode(challenge_bytes).decode().rstrip("=")
2049
+ )
2050
+ query_params["code_challenge"] = proxy_code_challenge
2051
+ query_params["code_challenge_method"] = "S256"
2052
+
2053
+ # Forward resource indicator if present in transaction
2054
+ if resource := transaction.get("resource"):
2055
+ query_params["resource"] = resource
2056
+
2057
+ # Extra configured parameters
2058
+ if self._extra_authorize_params:
2059
+ query_params.update(self._extra_authorize_params)
2060
+
2061
+ separator = "&" if "?" in self._upstream_authorization_endpoint else "?"
2062
+ return f"{self._upstream_authorization_endpoint}{separator}{urlencode(query_params)}"
2063
+
2064
+ async def _handle_consent(
2065
+ self, request: Request
2066
+ ) -> HTMLResponse | RedirectResponse:
2067
+ """Handle consent page - dispatch to GET or POST handler based on method."""
2068
+ if request.method == "POST":
2069
+ return await self._submit_consent(request)
2070
+ return await self._show_consent_page(request)
2071
+
2072
+ async def _show_consent_page(
2073
+ self, request: Request
2074
+ ) -> HTMLResponse | RedirectResponse:
2075
+ """Display consent page or auto-approve/deny based on cookies."""
2076
+ from fastmcp.server.server import FastMCP
2077
+
2078
+ txn_id = request.query_params.get("txn_id")
2079
+ if not txn_id:
2080
+ return create_secure_html_response(
2081
+ "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
2082
+ )
2083
+
2084
+ txn_model = await self._transaction_store.get(key=txn_id)
2085
+ if not txn_model:
2086
+ return create_secure_html_response(
2087
+ "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
2088
+ )
2089
+
2090
+ txn = txn_model.model_dump()
2091
+ client_key = self._make_client_key(txn["client_id"], txn["client_redirect_uri"])
2092
+
2093
+ approved = set(self._decode_list_cookie(request, "MCP_APPROVED_CLIENTS"))
2094
+ denied = set(self._decode_list_cookie(request, "MCP_DENIED_CLIENTS"))
2095
+
2096
+ if client_key in approved:
2097
+ upstream_url = self._build_upstream_authorize_url(txn_id, txn)
2098
+ return RedirectResponse(url=upstream_url, status_code=302)
2099
+
2100
+ if client_key in denied:
2101
+ callback_params = {
2102
+ "error": "access_denied",
2103
+ "state": txn.get("client_state") or "",
2104
+ }
2105
+ sep = "&" if "?" in txn["client_redirect_uri"] else "?"
945
2106
  return RedirectResponse(
946
- url="data:text/html,<h1>OAuth Error</h1><p>Internal server error during IdP callback</p>",
2107
+ url=f"{txn['client_redirect_uri']}{sep}{urlencode(callback_params)}",
947
2108
  status_code=302,
948
2109
  )
2110
+
2111
+ # Need consent: issue CSRF token and show HTML
2112
+ csrf_token = secrets.token_urlsafe(32)
2113
+ csrf_expires_at = time.time() + 15 * 60
2114
+
2115
+ # Update transaction with CSRF token
2116
+ txn_model.csrf_token = csrf_token
2117
+ txn_model.csrf_expires_at = csrf_expires_at
2118
+ await self._transaction_store.put(
2119
+ key=txn_id, value=txn_model, ttl=15 * 60
2120
+ ) # Auto-expire after 15 minutes
2121
+
2122
+ # Update dict for use in HTML generation
2123
+ txn["csrf_token"] = csrf_token
2124
+ txn["csrf_expires_at"] = csrf_expires_at
2125
+
2126
+ # Load client to get client_name if available
2127
+ client = await self.get_client(txn["client_id"])
2128
+ client_name = getattr(client, "client_name", None) if client else None
2129
+
2130
+ # Extract server metadata from app state
2131
+ fastmcp = getattr(request.app.state, "fastmcp_server", None)
2132
+
2133
+ if isinstance(fastmcp, FastMCP):
2134
+ server_name = fastmcp.name
2135
+ icons = fastmcp.icons
2136
+ server_icon_url = icons[0].src if icons else None
2137
+ server_website_url = fastmcp.website_url
2138
+ else:
2139
+ server_name = None
2140
+ server_icon_url = None
2141
+ server_website_url = None
2142
+
2143
+ html = create_consent_html(
2144
+ client_id=txn["client_id"],
2145
+ redirect_uri=txn["client_redirect_uri"],
2146
+ scopes=txn.get("scopes") or [],
2147
+ txn_id=txn_id,
2148
+ csrf_token=csrf_token,
2149
+ client_name=client_name,
2150
+ server_name=server_name,
2151
+ server_icon_url=server_icon_url,
2152
+ server_website_url=server_website_url,
2153
+ csp_policy=self._consent_csp_policy,
2154
+ )
2155
+ response = create_secure_html_response(html)
2156
+ # Store CSRF in cookie with short lifetime
2157
+ self._set_list_cookie(
2158
+ response,
2159
+ "MCP_CONSENT_STATE",
2160
+ self._encode_list_cookie([csrf_token]),
2161
+ max_age=15 * 60,
2162
+ )
2163
+ return response
2164
+
2165
async def _submit_consent(
    self, request: Request
) -> RedirectResponse | HTMLResponse:
    """Handle consent approval/denial, set cookies, and redirect appropriately.

    Reads ``txn_id``, ``action`` and ``csrf_token`` from the submitted form
    and validates them against the stored transaction. Then:

    * ``action == "approve"``: adds the client key to the
      ``MCP_APPROVED_CLIENTS`` cookie (one-year max-age) and redirects (302)
      to the upstream authorization URL built for this transaction.
    * ``action == "deny"``: adds the client key to the ``MCP_DENIED_CLIENTS``
      cookie (one-year max-age) and redirects (302) back to the client's
      redirect URI with ``error=access_denied`` and the client's ``state``
      (an empty string when none was recorded).

    On approve and deny, the ``MCP_CONSENT_STATE`` cookie is overwritten with
    an encoded empty list and a 60-second max-age.

    Args:
        request: The incoming consent form submission.

    Returns:
        A 302 ``RedirectResponse`` for approve/deny. Otherwise, a 400 HTML
        response when:

        * the transaction id is missing or unknown;
        * the CSRF token is missing, mismatched, or expired;
        * the action is unrecognized.
    """
    form = await request.form()
    txn_id = str(form.get("txn_id", ""))
    action = str(form.get("action", ""))
    csrf_token = str(form.get("csrf_token", ""))

    if not txn_id:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
        )

    txn_model = await self._transaction_store.get(key=txn_id)
    if not txn_model:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
        )

    txn = txn_model.model_dump()
    expected_csrf = txn.get("csrf_token")
    expires_at = float(txn.get("csrf_expires_at") or 0)

    # Compare the secret token in constant time so response timing does not
    # reveal how much of a guessed token matches. Compare bytes rather than
    # str: compare_digest raises TypeError for non-ASCII str input, and the
    # submitted value is attacker-controlled.
    # NOTE(review): the MCP_CONSENT_STATE cookie set when the consent page was
    # rendered is not read here; only the token stored on the transaction is
    # checked. The token is also not cleared from the transaction after use,
    # so it stays valid until csrf_expires_at. Confirm both are intended.
    if (
        not expected_csrf
        or not secrets.compare_digest(
            csrf_token.encode("utf-8"), str(expected_csrf).encode("utf-8")
        )
        or time.time() > expires_at
    ):
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired consent token</p>", status_code=400
        )

    client_key = self._make_client_key(txn["client_id"], txn["client_redirect_uri"])

    if action == "approve":
        approved = set(self._decode_list_cookie(request, "MCP_APPROVED_CLIENTS"))
        if client_key not in approved:
            approved.add(client_key)
        approved_b64 = self._encode_list_cookie(sorted(approved))

        upstream_url = self._build_upstream_authorize_url(txn_id, txn)
        response = RedirectResponse(url=upstream_url, status_code=302)
        self._set_list_cookie(
            response, "MCP_APPROVED_CLIENTS", approved_b64, max_age=365 * 24 * 3600
        )
        # Clear CSRF cookie by setting empty short-lived value
        self._set_list_cookie(
            response, "MCP_CONSENT_STATE", self._encode_list_cookie([]), max_age=60
        )
        return response

    elif action == "deny":
        denied = set(self._decode_list_cookie(request, "MCP_DENIED_CLIENTS"))
        if client_key not in denied:
            denied.add(client_key)
        denied_b64 = self._encode_list_cookie(sorted(denied))

        callback_params = {
            "error": "access_denied",
            "state": txn.get("client_state") or "",
        }
        # Append to an existing query string if the redirect URI already has one.
        sep = "&" if "?" in txn["client_redirect_uri"] else "?"
        client_callback_url = (
            f"{txn['client_redirect_uri']}{sep}{urlencode(callback_params)}"
        )
        response = RedirectResponse(url=client_callback_url, status_code=302)
        self._set_list_cookie(
            response, "MCP_DENIED_CLIENTS", denied_b64, max_age=365 * 24 * 3600
        )
        # Clear CSRF cookie by setting empty short-lived value
        self._set_list_cookie(
            response, "MCP_CONSENT_STATE", self._encode_list_cookie([]), max_age=60
        )
        return response

    else:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid action</p>", status_code=400
        )