fastmcp 2.12.5__py3-none-any.whl → 2.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (133)
  1. fastmcp/__init__.py +2 -23
  2. fastmcp/cli/__init__.py +0 -3
  3. fastmcp/cli/__main__.py +5 -0
  4. fastmcp/cli/cli.py +19 -33
  5. fastmcp/cli/install/claude_code.py +6 -6
  6. fastmcp/cli/install/claude_desktop.py +3 -3
  7. fastmcp/cli/install/cursor.py +18 -12
  8. fastmcp/cli/install/gemini_cli.py +3 -3
  9. fastmcp/cli/install/mcp_json.py +3 -3
  10. fastmcp/cli/install/shared.py +0 -15
  11. fastmcp/cli/run.py +13 -8
  12. fastmcp/cli/tasks.py +110 -0
  13. fastmcp/client/__init__.py +9 -9
  14. fastmcp/client/auth/oauth.py +123 -225
  15. fastmcp/client/client.py +697 -95
  16. fastmcp/client/elicitation.py +11 -5
  17. fastmcp/client/logging.py +18 -14
  18. fastmcp/client/messages.py +7 -5
  19. fastmcp/client/oauth_callback.py +85 -171
  20. fastmcp/client/roots.py +2 -1
  21. fastmcp/client/sampling.py +1 -1
  22. fastmcp/client/tasks.py +614 -0
  23. fastmcp/client/transports.py +117 -30
  24. fastmcp/contrib/component_manager/__init__.py +1 -1
  25. fastmcp/contrib/component_manager/component_manager.py +2 -2
  26. fastmcp/contrib/component_manager/component_service.py +10 -26
  27. fastmcp/contrib/mcp_mixin/README.md +32 -1
  28. fastmcp/contrib/mcp_mixin/__init__.py +2 -2
  29. fastmcp/contrib/mcp_mixin/mcp_mixin.py +14 -2
  30. fastmcp/dependencies.py +25 -0
  31. fastmcp/experimental/sampling/handlers/openai.py +3 -3
  32. fastmcp/experimental/server/openapi/__init__.py +20 -21
  33. fastmcp/experimental/utilities/openapi/__init__.py +16 -47
  34. fastmcp/mcp_config.py +3 -4
  35. fastmcp/prompts/__init__.py +1 -1
  36. fastmcp/prompts/prompt.py +54 -51
  37. fastmcp/prompts/prompt_manager.py +16 -101
  38. fastmcp/resources/__init__.py +5 -5
  39. fastmcp/resources/resource.py +43 -21
  40. fastmcp/resources/resource_manager.py +9 -168
  41. fastmcp/resources/template.py +161 -61
  42. fastmcp/resources/types.py +30 -24
  43. fastmcp/server/__init__.py +1 -1
  44. fastmcp/server/auth/__init__.py +9 -14
  45. fastmcp/server/auth/auth.py +197 -46
  46. fastmcp/server/auth/handlers/authorize.py +326 -0
  47. fastmcp/server/auth/jwt_issuer.py +236 -0
  48. fastmcp/server/auth/middleware.py +96 -0
  49. fastmcp/server/auth/oauth_proxy.py +1469 -298
  50. fastmcp/server/auth/oidc_proxy.py +91 -20
  51. fastmcp/server/auth/providers/auth0.py +40 -21
  52. fastmcp/server/auth/providers/aws.py +29 -3
  53. fastmcp/server/auth/providers/azure.py +312 -131
  54. fastmcp/server/auth/providers/debug.py +114 -0
  55. fastmcp/server/auth/providers/descope.py +86 -29
  56. fastmcp/server/auth/providers/discord.py +308 -0
  57. fastmcp/server/auth/providers/github.py +29 -8
  58. fastmcp/server/auth/providers/google.py +48 -9
  59. fastmcp/server/auth/providers/in_memory.py +29 -5
  60. fastmcp/server/auth/providers/introspection.py +281 -0
  61. fastmcp/server/auth/providers/jwt.py +48 -31
  62. fastmcp/server/auth/providers/oci.py +233 -0
  63. fastmcp/server/auth/providers/scalekit.py +238 -0
  64. fastmcp/server/auth/providers/supabase.py +188 -0
  65. fastmcp/server/auth/providers/workos.py +35 -17
  66. fastmcp/server/context.py +236 -116
  67. fastmcp/server/dependencies.py +503 -18
  68. fastmcp/server/elicitation.py +286 -48
  69. fastmcp/server/event_store.py +177 -0
  70. fastmcp/server/http.py +71 -20
  71. fastmcp/server/low_level.py +165 -2
  72. fastmcp/server/middleware/__init__.py +1 -1
  73. fastmcp/server/middleware/caching.py +476 -0
  74. fastmcp/server/middleware/error_handling.py +14 -10
  75. fastmcp/server/middleware/logging.py +50 -39
  76. fastmcp/server/middleware/middleware.py +29 -16
  77. fastmcp/server/middleware/rate_limiting.py +3 -3
  78. fastmcp/server/middleware/tool_injection.py +116 -0
  79. fastmcp/server/openapi/__init__.py +35 -0
  80. fastmcp/{experimental/server → server}/openapi/components.py +15 -10
  81. fastmcp/{experimental/server → server}/openapi/routing.py +3 -3
  82. fastmcp/{experimental/server → server}/openapi/server.py +6 -5
  83. fastmcp/server/proxy.py +72 -48
  84. fastmcp/server/server.py +1415 -733
  85. fastmcp/server/tasks/__init__.py +21 -0
  86. fastmcp/server/tasks/capabilities.py +22 -0
  87. fastmcp/server/tasks/config.py +89 -0
  88. fastmcp/server/tasks/converters.py +205 -0
  89. fastmcp/server/tasks/handlers.py +356 -0
  90. fastmcp/server/tasks/keys.py +93 -0
  91. fastmcp/server/tasks/protocol.py +355 -0
  92. fastmcp/server/tasks/subscriptions.py +205 -0
  93. fastmcp/settings.py +125 -113
  94. fastmcp/tools/__init__.py +1 -1
  95. fastmcp/tools/tool.py +138 -55
  96. fastmcp/tools/tool_manager.py +30 -112
  97. fastmcp/tools/tool_transform.py +12 -21
  98. fastmcp/utilities/cli.py +67 -28
  99. fastmcp/utilities/components.py +10 -5
  100. fastmcp/utilities/inspect.py +79 -23
  101. fastmcp/utilities/json_schema.py +4 -4
  102. fastmcp/utilities/json_schema_type.py +8 -8
  103. fastmcp/utilities/logging.py +118 -8
  104. fastmcp/utilities/mcp_config.py +1 -2
  105. fastmcp/utilities/mcp_server_config/__init__.py +3 -3
  106. fastmcp/utilities/mcp_server_config/v1/environments/base.py +1 -2
  107. fastmcp/utilities/mcp_server_config/v1/environments/uv.py +6 -6
  108. fastmcp/utilities/mcp_server_config/v1/mcp_server_config.py +5 -5
  109. fastmcp/utilities/mcp_server_config/v1/schema.json +3 -0
  110. fastmcp/utilities/mcp_server_config/v1/sources/base.py +0 -1
  111. fastmcp/{experimental/utilities → utilities}/openapi/README.md +7 -35
  112. fastmcp/utilities/openapi/__init__.py +63 -0
  113. fastmcp/{experimental/utilities → utilities}/openapi/director.py +14 -15
  114. fastmcp/{experimental/utilities → utilities}/openapi/formatters.py +5 -5
  115. fastmcp/{experimental/utilities → utilities}/openapi/json_schema_converter.py +7 -3
  116. fastmcp/{experimental/utilities → utilities}/openapi/parser.py +37 -16
  117. fastmcp/utilities/tests.py +92 -5
  118. fastmcp/utilities/types.py +86 -16
  119. fastmcp/utilities/ui.py +626 -0
  120. {fastmcp-2.12.5.dist-info → fastmcp-2.14.0.dist-info}/METADATA +24 -15
  121. fastmcp-2.14.0.dist-info/RECORD +156 -0
  122. {fastmcp-2.12.5.dist-info → fastmcp-2.14.0.dist-info}/WHEEL +1 -1
  123. fastmcp/cli/claude.py +0 -135
  124. fastmcp/server/auth/providers/bearer.py +0 -25
  125. fastmcp/server/openapi.py +0 -1083
  126. fastmcp/utilities/openapi.py +0 -1568
  127. fastmcp/utilities/storage.py +0 -204
  128. fastmcp-2.12.5.dist-info/RECORD +0 -134
  129. fastmcp/{experimental/server → server}/openapi/README.md +0 -0
  130. fastmcp/{experimental/utilities → utilities}/openapi/models.py +3 -3
  131. fastmcp/{experimental/utilities → utilities}/openapi/schemas.py +2 -2
  132. {fastmcp-2.12.5.dist-info → fastmcp-2.14.0.dist-info}/entry_points.txt +0 -0
  133. {fastmcp-2.12.5.dist-info → fastmcp-2.14.0.dist-info}/licenses/LICENSE +0 -0
@@ -18,20 +18,29 @@ production use with enterprise identity providers.
18
18
 
19
19
  from __future__ import annotations
20
20
 
21
+ import base64
21
22
  import hashlib
23
+ import hmac
24
+ import json
22
25
  import secrets
23
26
  import time
24
27
  from base64 import urlsafe_b64encode
25
28
  from typing import TYPE_CHECKING, Any, Final
26
- from urllib.parse import urlencode
29
+ from urllib.parse import urlencode, urlparse
27
30
 
28
31
  import httpx
29
32
  from authlib.common.security import generate_token
30
33
  from authlib.integrations.httpx_client import AsyncOAuth2Client
34
+ from cryptography.fernet import Fernet
35
+ from key_value.aio.adapters.pydantic import PydanticAdapter
36
+ from key_value.aio.protocols import AsyncKeyValue
37
+ from key_value.aio.stores.disk import DiskStore
38
+ from key_value.aio.wrappers.encryption import FernetEncryptionWrapper
31
39
  from mcp.server.auth.provider import (
32
40
  AccessToken,
33
41
  AuthorizationCode,
34
42
  AuthorizationParams,
43
+ AuthorizeError,
35
44
  RefreshToken,
36
45
  TokenError,
37
46
  )
@@ -40,16 +49,34 @@ from mcp.server.auth.settings import (
40
49
  RevocationOptions,
41
50
  )
42
51
  from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
43
- from pydantic import AnyHttpUrl, AnyUrl, SecretStr
52
+ from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, SecretStr
44
53
  from starlette.requests import Request
45
- from starlette.responses import RedirectResponse
54
+ from starlette.responses import HTMLResponse, RedirectResponse
46
55
  from starlette.routing import Route
56
+ from typing_extensions import override
47
57
 
48
- import fastmcp
58
+ from fastmcp import settings
49
59
  from fastmcp.server.auth.auth import OAuthProvider, TokenVerifier
50
- from fastmcp.server.auth.redirect_validation import validate_redirect_uri
60
+ from fastmcp.server.auth.handlers.authorize import AuthorizationHandler
61
+ from fastmcp.server.auth.jwt_issuer import (
62
+ JWTIssuer,
63
+ derive_jwt_key,
64
+ )
65
+ from fastmcp.server.auth.redirect_validation import (
66
+ validate_redirect_uri,
67
+ )
51
68
  from fastmcp.utilities.logging import get_logger
52
- from fastmcp.utilities.storage import JSONFileStorage, KVStorage
69
+ from fastmcp.utilities.ui import (
70
+ BUTTON_STYLES,
71
+ DETAIL_BOX_STYLES,
72
+ DETAILS_STYLES,
73
+ INFO_BOX_STYLES,
74
+ REDIRECT_SECTION_STYLES,
75
+ TOOLTIP_STYLES,
76
+ create_logo,
77
+ create_page,
78
+ create_secure_html_response,
79
+ )
53
80
 
54
81
  if TYPE_CHECKING:
55
82
  pass
@@ -57,6 +84,121 @@ if TYPE_CHECKING:
57
84
  logger = get_logger(__name__)
58
85
 
59
86
 
87
# -------------------------------------------------------------------------
# Constants
# -------------------------------------------------------------------------

# Fallback access-token lifetimes, used when the upstream token response has
# no `expires_in`. The short value applies when a refresh token is available
# (the token can be renewed); the long one when it is not.
DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS: Final[int] = 3600  # 1 hour
DEFAULT_ACCESS_TOKEN_EXPIRY_NO_REFRESH_SECONDS: Final[int] = 365 * 24 * 3600  # 1 year

# Lifetime of authorization codes handed out by the proxy.
DEFAULT_AUTH_CODE_EXPIRY_SECONDS: Final[int] = 300  # 5 minutes

# Timeout (seconds) for the HTTP client.
HTTP_TIMEOUT_SECONDS: Final[int] = 30
100
+
101
+
102
+ # -------------------------------------------------------------------------
103
+ # Pydantic Models
104
+ # -------------------------------------------------------------------------
105
+
106
+
107
class OAuthTransaction(BaseModel):
    """Server-side record of one in-flight authorization (consent) flow.

    A record is kept per active authorization request so the proxy remembers
    which client started the flow and how to hand control back to it. It also
    carries the CSRF token that protects the consent form, following MCP
    security best practices.
    """

    # Identifier of this transaction.
    txn_id: str
    # The requesting MCP client and where / with what state to return to it.
    client_id: str
    client_redirect_uri: str
    client_state: str
    # PKCE parameters supplied by the client.
    code_challenge: str | None
    code_challenge_method: str
    # Requested scopes.
    scopes: list[str]
    # Creation time (presumably a Unix timestamp - confirm against caller).
    created_at: float
    # Optional resource indicator.
    resource: str | None = Field(default=None)
    # Presumably the PKCE verifier the proxy itself uses upstream - confirm.
    proxy_code_verifier: str | None = Field(default=None)
    # CSRF protection for the consent form and its expiry time.
    csrf_token: str | None = Field(default=None)
    csrf_expires_at: float | None = Field(default=None)
126
+
127
+
128
class ClientCode(BaseModel):
    """Authorization code issued to an MCP client once the upstream IdP returns.

    The record lives on the server. It binds the client's PKCE challenge to the
    upstream tokens, so those tokens are released only through a valid code
    exchange.
    """

    # The authorization code and the client/redirect it was issued for.
    code: str
    client_id: str
    redirect_uri: str
    # Client PKCE parameters that the code exchange must satisfy.
    code_challenge: str | None
    code_challenge_method: str
    # Scopes granted with this code.
    scopes: list[str]
    # Tokens obtained from the upstream identity provider.
    idp_tokens: dict[str, Any]
    # Expiry and creation times.
    expires_at: float
    created_at: float
144
+
145
+
146
class UpstreamTokenSet(BaseModel):
    """Tokens obtained from the upstream identity provider (Google, GitHub, ...).

    Values are held as plain strings on this model. Encryption is applied
    transparently by the storage layer (FernetEncryptionWrapper). These tokens
    are never exposed to MCP clients.
    """

    # Unique ID for this token set.
    upstream_token_id: str
    # Upstream access token and, when issued, refresh token.
    access_token: str
    refresh_token: str | None
    # Unix timestamp at which the refresh token expires, if known.
    refresh_token_expires_at: float | None
    # Unix timestamp at which the access token expires.
    expires_at: float
    # Token type, usually "Bearer".
    token_type: str
    # Granted scopes, space-separated.
    scope: str
    # The MCP client this token set is bound to.
    client_id: str
    # Creation time as a Unix timestamp.
    created_at: float
    # Full upstream token response.
    raw_token_data: dict[str, Any] = Field(default_factory=dict)
166
+
167
+
168
class JTIMapping(BaseModel):
    """Link from a FastMCP-issued token's JTI to an upstream token set.

    FastMCP JWTs are validated statelessly. Tools that call upstream APIs still
    need the upstream credentials, and this record lets them be found by JTI.
    """

    # JWT ID of the FastMCP-issued token.
    jti: str
    # ID of the referenced UpstreamTokenSet.
    upstream_token_id: str
    # Creation time as a Unix timestamp.
    created_at: float
178
+
179
+
180
class RefreshTokenMetadata(BaseModel):
    """Metadata describing a refresh token, stored under the token's hash.

    The token itself is never persisted, only this metadata. A compromised
    store therefore yields hashes that cannot be reversed into usable tokens.
    """

    # MCP client that owns the refresh token.
    client_id: str
    # Scopes associated with the token.
    scopes: list[str]
    # Optional expiry (integer timestamp); None when not set.
    expires_at: int | None = Field(default=None)
    # Creation time as a Unix timestamp.
    created_at: float
191
+
192
+
193
def _hash_token(token: str) -> str:
    """Return the hex SHA-256 digest of *token*, used as a storage lookup key.

    The digest is one-way: the original token cannot be recovered from it.
    This provides defense in depth if the storage backend is compromised.
    """
    hasher = hashlib.sha256()
    hasher.update(token.encode("utf-8"))
    return hasher.hexdigest()
200
+
201
+
60
202
  class ProxyDCRClient(OAuthClientInformationFull):
61
203
  """Client for DCR proxy with configurable redirect URI validation.
62
204
 
@@ -83,18 +225,8 @@ class ProxyDCRClient(OAuthClientInformationFull):
83
225
  arise from accepting arbitrary redirect URIs.
84
226
  """
85
227
 
86
- def __init__(
87
- self, *args, allowed_redirect_uri_patterns: list[str] | None = None, **kwargs
88
- ):
89
- """Initialize with allowed redirect URI patterns.
90
-
91
- Args:
92
- allowed_redirect_uri_patterns: List of allowed redirect URI patterns with wildcard support.
93
- If None, defaults to localhost-only patterns.
94
- If empty list, allows all redirect URIs.
95
- """
96
- super().__init__(*args, **kwargs)
97
- self._allowed_redirect_uri_patterns = allowed_redirect_uri_patterns
228
+ allowed_redirect_uri_patterns: list[str] | None = Field(default=None)
229
+ client_name: str | None = Field(default=None)
98
230
 
99
231
  def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
100
232
  """Validate redirect URI against allowed patterns.
@@ -106,7 +238,10 @@ class ProxyDCRClient(OAuthClientInformationFull):
106
238
  """
107
239
  if redirect_uri is not None:
108
240
  # Validate against allowed patterns
109
- if validate_redirect_uri(redirect_uri, self._allowed_redirect_uri_patterns):
241
+ if validate_redirect_uri(
242
+ redirect_uri=redirect_uri,
243
+ allowed_patterns=self.allowed_redirect_uri_patterns,
244
+ ):
110
245
  return redirect_uri
111
246
  # Fall back to normal validation if not in allowed patterns
112
247
  return super().validate_redirect_uri(redirect_uri)
@@ -114,12 +249,267 @@ class ProxyDCRClient(OAuthClientInformationFull):
114
249
  return super().validate_redirect_uri(redirect_uri)
115
250
 
116
251
 
117
- # Default token expiration times
118
- DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS: Final[int] = 60 * 60 # 1 hour
119
- DEFAULT_AUTH_CODE_EXPIRY_SECONDS: Final[int] = 5 * 60 # 5 minutes
252
+ # -------------------------------------------------------------------------
253
+ # Helper Functions
254
+ # -------------------------------------------------------------------------
255
+
256
+
257
def create_consent_html(
    client_id: str,
    redirect_uri: str,
    scopes: list[str],
    txn_id: str,
    csrf_token: str,
    client_name: str | None = None,
    title: str = "Application Access Request",
    server_name: str | None = None,
    server_icon_url: str | None = None,
    server_website_url: str | None = None,
    client_website_url: str | None = None,
    csp_policy: str | None = None,
) -> str:
    """Create a styled HTML consent page for OAuth authorization requests.

    Every value interpolated into the markup is HTML-escaped, including
    ``client_id`` and the hidden form fields, so none of the inputs can inject
    markup into the page.

    Args:
        client_id: Identifier of the requesting client. It is shown as the
            "Application ID" and used as the display name when
            ``client_name`` is not given.
        redirect_uri: Callback address the credentials will be sent to. Its
            scheme is also added to the default CSP ``form-action`` directive
            when it is not http/https.
        scopes: Requested scopes, listed in the "Advanced Details" section
            ("None" when empty).
        txn_id: Transaction ID echoed back through a hidden form field.
        csrf_token: CSRF token echoed back through a hidden form field.
        client_name: Optional human-readable client name.
        title: Title passed to ``create_page``. The visible heading is always
            "Application Access Request".
        server_name: Optional server display name (defaults to "FastMCP").
        server_icon_url: Optional URL of the server logo.
        server_website_url: Optional URL. When set, the server name is
            rendered as a link to it.
        client_website_url: Optional client website shown in the details
            ("N/A" when missing).
        csp_policy: Content Security Policy override.
            If None, uses the built-in CSP policy with appropriate directives.
            If empty string "", disables CSP entirely (no meta tag is rendered).
            If a non-empty string, uses that as the CSP policy value.

    Returns:
        The complete HTML page as a string.
    """
    import html as html_module

    escape = html_module.escape

    # Escape every externally supplied value once, up front, so no raw text
    # reaches the markup below.
    client_display = escape(client_name or client_id)
    # client_id is rendered as its own detail row. Nothing here guarantees it
    # is free of HTML metacharacters, so it must be escaped like everything else.
    client_id_escaped = escape(client_id)
    server_name_escaped = escape(server_name or "FastMCP")
    redirect_uri_escaped = escape(redirect_uri)
    # Hidden inputs sit inside double-quoted attributes. escape() also escapes
    # quotes (quote=True by default), which prevents attribute breakout.
    txn_id_escaped = escape(txn_id)
    csrf_token_escaped = escape(csrf_token)

    # Make server name a hyperlink if website URL is available
    if server_website_url:
        website_url_escaped = escape(server_website_url)
        server_display = f'<a href="{website_url_escaped}" target="_blank" rel="noopener noreferrer" class="server-name-link">{server_name_escaped}</a>'
    else:
        server_display = server_name_escaped

    # Build intro box with call-to-action
    intro_box = f"""
        <div class="info-box">
            <p>The application <strong>{client_display}</strong> wants to access the MCP server <strong>{server_display}</strong>. Please ensure you recognize the callback address below.</p>
        </div>
    """

    # Build redirect URI section (yellow box, centered)
    redirect_section = f"""
        <div class="redirect-section">
            <span class="label">Credentials will be sent to:</span>
            <div class="value">{redirect_uri_escaped}</div>
        </div>
    """

    # Build advanced details with collapsible section (all values pre-escaped)
    detail_rows = [
        ("Application Name", client_display),
        ("Application Website", escape(client_website_url or "N/A")),
        ("Application ID", client_id_escaped),
        ("Redirect URI", redirect_uri_escaped),
        (
            "Requested Scopes",
            ", ".join(escape(s) for s in scopes) if scopes else "None",
        ),
    ]

    detail_rows_html = "\n".join(
        f"""
            <div class="detail-row">
                <div class="detail-label">{label}:</div>
                <div class="detail-value">{value}</div>
            </div>
        """
        for label, value in detail_rows
    )

    advanced_details = f"""
        <details>
            <summary>Advanced Details</summary>
            <div class="detail-box">
                {detail_rows_html}
            </div>
        </details>
    """

    # Build form with buttons
    # Use empty action to submit to current URL (/consent or /mcp/consent)
    # The POST handler is registered at the same path as GET
    form = f"""
        <form id="consentForm" method="POST" action="">
            <input type="hidden" name="txn_id" value="{txn_id_escaped}" />
            <input type="hidden" name="csrf_token" value="{csrf_token_escaped}" />
            <input type="hidden" name="submit" value="true" />
            <div class="button-group">
                <button type="submit" name="action" value="approve" class="btn-approve">Allow Access</button>
                <button type="submit" name="action" value="deny" class="btn-deny">Deny</button>
            </div>
        </form>
    """

    # Build help link with tooltip
    help_link = """
        <div class="help-link-container">
            <span class="help-link">
                Why am I seeing this?
                <span class="tooltip">
                    This FastMCP server requires your consent to allow a new client
                    to connect. This protects you from <a
                    href="https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#confused-deputy-problem"
                    target="_blank" class="tooltip-link">confused deputy
                    attacks</a>, where malicious clients could impersonate you
                    and steal access.<br><br>
                    <a
                    href="https://gofastmcp.com/servers/auth/oauth-proxy#confused-deputy-attacks"
                    target="_blank" class="tooltip-link">Learn more about
                    FastMCP security →</a>
                </span>
            </span>
        </div>
    """

    # Build the page content
    # NOTE(review): alt_text is passed unescaped; this assumes create_logo
    # escapes it - confirm.
    content = f"""
        <div class="container">
            {create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")}
            <h1>Application Access Request</h1>
            {intro_box}
            {redirect_section}
            {advanced_details}
            {form}
        </div>
        {help_link}
    """

    # Additional styles needed for this page
    additional_styles = (
        INFO_BOX_STYLES
        + REDIRECT_SECTION_STYLES
        + DETAILS_STYLES
        + DETAIL_BOX_STYLES
        + BUTTON_STYLES
        + TOOLTIP_STYLES
    )

    # Determine CSP policy to use
    # If csp_policy is None, build the default CSP policy
    # If csp_policy is empty string, CSP will be disabled entirely in create_page
    # If csp_policy is a non-empty string, use it as-is
    if csp_policy is None:
        # Need to allow form-action for form submission
        # Chrome requires explicit scheme declarations in CSP form-action when redirect chains
        # end in custom protocol schemes (e.g., cursor://). Parse redirect_uri to include its scheme.
        parsed_redirect = urlparse(redirect_uri)
        redirect_scheme = parsed_redirect.scheme.lower()

        # Build form-action directive with standard schemes plus custom protocol if present
        form_action_schemes = ["https:", "http:"]
        if redirect_scheme and redirect_scheme not in ("http", "https"):
            # Custom protocol scheme (e.g., cursor:, vscode:, etc.)
            form_action_schemes.append(f"{redirect_scheme}:")

        form_action_directive = " ".join(form_action_schemes)
        csp_policy = f"default-src 'none'; style-src 'unsafe-inline'; img-src https: data:; base-uri 'none'; form-action {form_action_directive}"

    return create_page(
        content=content,
        title=title,
        additional_styles=additional_styles,
        csp_policy=csp_policy,
    )
425
+
426
+
427
def create_error_html(
    error_title: str,
    error_message: str,
    error_details: dict[str, str] | None = None,
    server_name: str | None = None,
    server_icon_url: str | None = None,
) -> str:
    """Build a complete, styled HTML page describing an OAuth failure.

    The title, the message and every detail label/value are HTML-escaped
    before being inserted into the page.

    Args:
        error_title: Heading and page title, e.g. "OAuth Error" or
            "Authorization Failed".
        error_message: Main explanatory text shown in the message box.
        error_details: Optional label -> value pairs rendered in a collapsible
            "Error Details" section, e.g. ``{"Error Code": "invalid_client"}``.
            The section is omitted when this is None or empty.
        server_name: Optional server name, used as the logo's alt text
            (defaults to "FastMCP").
        server_icon_url: Optional URL of the server icon/logo.

    Returns:
        The full HTML document as a string.
    """
    from html import escape

    def _detail_row(label: str, value: str) -> str:
        # One escaped label/value line inside the details box.
        return f"""
            <div class="detail-row">
                <div class="detail-label">{escape(label)}:</div>
                <div class="detail-value">{escape(value)}</div>
            </div>
            """

    # Main message box
    error_box = f"""
        <div class="info-box error">
            <p>{escape(error_message)}</p>
        </div>
    """

    # Optional collapsible details section
    details_section = ""
    if error_details:
        rows_html = "\n".join(
            _detail_row(label, value) for label, value in error_details.items()
        )
        details_section = f"""
        <details>
            <summary>Error Details</summary>
            <div class="detail-box">
                {rows_html}
            </div>
        </details>
        """

    # Page body
    content = f"""
        <div class="container">
            {create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")}
            <h1>{escape(error_title)}</h1>
            {error_box}
            {details_section}
        </div>
    """

    # Extra CSS; .info-box.error is overridden to use normal text color
    # instead of red.
    error_color_override = """
        .info-box.error {
            color: #111827;
        }
    """
    additional_styles = (
        INFO_BOX_STYLES + DETAILS_STYLES + DETAIL_BOX_STYLES + error_color_override
    )

    # Error pages contain no forms, so no form-action directive is needed.
    csp_policy = "default-src 'none'; style-src 'unsafe-inline'; img-src https: data:; base-uri 'none'"

    return create_page(
        content=content,
        title=error_title,
        additional_styles=additional_styles,
        csp_policy=csp_policy,
    )
123
513
 
124
514
 
125
515
  class OAuthProxy(OAuthProvider):
@@ -200,15 +590,18 @@ class OAuthProxy(OAuthProvider):
200
590
 
201
591
  State Management
202
592
  ---------------
203
- The proxy maintains minimal but crucial state:
204
- - _clients: DCR registrations (all use ProxyDCRClient for flexibility)
593
+ The proxy maintains minimal but crucial state via pluggable storage (client_storage):
205
594
  - _oauth_transactions: Active authorization flows with client context
206
595
  - _client_codes: Authorization codes with PKCE challenges and upstream tokens
207
- - _access_tokens, _refresh_tokens: Token storage for revocation
208
- - Token relationship mappings for cleanup and rotation
596
+ - _jti_mapping_store: Maps FastMCP token JTIs to upstream token IDs
597
+ - _refresh_token_store: Refresh token metadata (keyed by token hash)
598
+
599
+ All state is stored in the configured client_storage backend (Redis, disk, etc.)
600
+ enabling horizontal scaling across multiple instances.
209
601
 
210
602
  Security Considerations
211
603
  ----------------------
604
+ - Refresh tokens stored by hash only (defense in depth if storage compromised)
212
605
  - PKCE enforced end-to-end (client to proxy, proxy to upstream)
213
606
  - Authorization codes are single-use with short expiry
214
607
  - Transaction IDs are cryptographically random
@@ -257,7 +650,14 @@ class OAuthProxy(OAuthProvider):
257
650
  # Extra parameters to forward to token endpoint
258
651
  extra_token_params: dict[str, str] | None = None,
259
652
  # Client storage
260
- client_storage: KVStorage | None = None,
653
+ client_storage: AsyncKeyValue | None = None,
654
+ # JWT signing key
655
+ jwt_signing_key: str | bytes | None = None,
656
+ # Consent screen configuration
657
+ require_authorization_consent: bool = True,
658
+ consent_csp_policy: str | None = None,
659
+ # Token expiry fallback
660
+ fallback_access_token_expiry_seconds: int | None = None,
261
661
  ):
262
662
  """Initialize the OAuth proxy provider.
263
663
 
@@ -291,10 +691,29 @@ class OAuthProxy(OAuthProvider):
291
691
  Example: {"audience": "https://api.example.com"}
292
692
  extra_token_params: Additional parameters to forward to the upstream token endpoint.
293
693
  Useful for provider-specific parameters during token exchange.
294
- client_storage: Storage implementation for OAuth client registrations.
295
- Defaults to file-based storage in ~/.fastmcp/oauth-proxy-clients/ if not specified.
296
- Pass any KVStorage implementation for custom storage backends.
694
+ client_storage: Storage backend for OAuth state (client registrations, tokens).
695
+ If None, an encrypted DiskStore will be created in the data directory.
696
+ jwt_signing_key: Secret for signing FastMCP JWT tokens (any string or bytes).
697
+ If bytes are provided, they will be used as-is.
698
+ If a string is provided, it will be derived into a 32-byte key using PBKDF2 (1.2M iterations).
699
+ If not provided, it will be derived from the upstream client secret using HKDF.
700
+ require_authorization_consent: Whether to require user consent before authorizing clients (default True).
701
+ When True, users see a consent screen before being redirected to the upstream IdP.
702
+ When False, authorization proceeds directly without user confirmation.
703
+ SECURITY WARNING: Only disable for local development or testing environments.
704
+ consent_csp_policy: Content Security Policy for the consent page.
705
+ If None (default), uses the built-in CSP policy with appropriate directives.
706
+ If empty string "", disables CSP entirely (no meta tag is rendered).
707
+ If a non-empty string, uses that as the CSP policy value.
708
+ This allows organizations with their own CSP policies to override or disable
709
+ the built-in CSP directives.
710
+ fallback_access_token_expiry_seconds: Expiry time to use when upstream provider
711
+ doesn't return `expires_in` in the token response. If not set, uses smart
712
+ defaults: 1 hour if a refresh token is available (since we can refresh),
713
+ or 1 year if no refresh token (for API-key-style tokens like GitHub OAuth Apps).
714
+ Set explicitly to override these defaults.
297
715
  """
716
+
298
717
  # Always enable DCR since we implement it locally for MCP clients
299
718
  client_registration_options = ClientRegistrationOptions(
300
719
  enabled=True,
@@ -316,12 +735,14 @@ class OAuthProxy(OAuthProvider):
316
735
  )
317
736
 
318
737
  # Store upstream configuration
319
- self._upstream_authorization_endpoint = upstream_authorization_endpoint
320
- self._upstream_token_endpoint = upstream_token_endpoint
321
- self._upstream_client_id = upstream_client_id
322
- self._upstream_client_secret = SecretStr(upstream_client_secret)
323
- self._upstream_revocation_endpoint = upstream_revocation_endpoint
324
- self._default_scope_str = " ".join(self.required_scopes or [])
738
+ self._upstream_authorization_endpoint: str = upstream_authorization_endpoint
739
+ self._upstream_token_endpoint: str = upstream_token_endpoint
740
+ self._upstream_client_id: str = upstream_client_id
741
+ self._upstream_client_secret: SecretStr = SecretStr(
742
+ secret_value=upstream_client_secret
743
+ )
744
+ self._upstream_revocation_endpoint: str | None = upstream_revocation_endpoint
745
+ self._default_scope_str: str = " ".join(self.required_scopes or [])
325
746
 
326
747
  # Store redirect configuration
327
748
  if not redirect_path:
@@ -330,40 +751,146 @@ class OAuthProxy(OAuthProvider):
330
751
  self._redirect_path = (
331
752
  redirect_path if redirect_path.startswith("/") else f"/{redirect_path}"
332
753
  )
333
- self._allowed_client_redirect_uris = allowed_client_redirect_uris
754
+
755
+ if (
756
+ isinstance(allowed_client_redirect_uris, list)
757
+ and not allowed_client_redirect_uris
758
+ ):
759
+ logger.warning(
760
+ "allowed_client_redirect_uris is empty list; no redirect URIs will be accepted. "
761
+ + "This will block all OAuth clients."
762
+ )
763
+ self._allowed_client_redirect_uris: list[str] | None = (
764
+ allowed_client_redirect_uris
765
+ )
334
766
 
335
767
  # PKCE configuration
336
- self._forward_pkce = forward_pkce
768
+ self._forward_pkce: bool = forward_pkce
337
769
 
338
770
  # Token endpoint authentication
339
- self._token_endpoint_auth_method = token_endpoint_auth_method
771
+ self._token_endpoint_auth_method: str | None = token_endpoint_auth_method
772
+
773
+ # Consent screen configuration
774
+ self._require_authorization_consent: bool = require_authorization_consent
775
+ self._consent_csp_policy: str | None = consent_csp_policy
776
+ if not require_authorization_consent:
777
+ logger.warning(
778
+ "Authorization consent screen disabled - only use for local development or testing. "
779
+ + "In production, this screen protects against confused deputy attacks."
780
+ )
340
781
 
341
782
  # Extra parameters for authorization and token endpoints
342
- self._extra_authorize_params = extra_authorize_params or {}
343
- self._extra_token_params = extra_token_params or {}
783
+ self._extra_authorize_params: dict[str, str] = extra_authorize_params or {}
784
+ self._extra_token_params: dict[str, str] = extra_token_params or {}
785
+
786
+ # Token expiry fallback (None means use smart default based on refresh token)
787
+ self._fallback_access_token_expiry_seconds: int | None = (
788
+ fallback_access_token_expiry_seconds
789
+ )
790
+
791
+ if jwt_signing_key is None:
792
+ jwt_signing_key = derive_jwt_key(
793
+ high_entropy_material=upstream_client_secret,
794
+ salt="fastmcp-jwt-signing-key",
795
+ )
796
+
797
+ if isinstance(jwt_signing_key, str):
798
+ if len(jwt_signing_key) < 12:
799
+ logger.warning(
800
+ "jwt_signing_key is less than 12 characters; it is recommended to use a longer. "
801
+ + "string for the key derivation."
802
+ )
803
+ jwt_signing_key = derive_jwt_key(
804
+ low_entropy_material=jwt_signing_key,
805
+ salt="fastmcp-jwt-signing-key",
806
+ )
344
807
 
345
- # Initialize client storage (default to file-based if not provided)
808
+ self._jwt_issuer: JWTIssuer = JWTIssuer(
809
+ issuer=str(self.base_url),
810
+ audience=f"{str(self.base_url).rstrip('/')}/mcp",
811
+ signing_key=jwt_signing_key,
812
+ )
813
+
814
+ # If the user does not provide a store, we will provide an encrypted disk store
346
815
  if client_storage is None:
347
- cache_dir = fastmcp.settings.home / "oauth-proxy-clients"
348
- client_storage = JSONFileStorage(cache_dir)
349
- self._client_storage = client_storage
816
+ storage_encryption_key = derive_jwt_key(
817
+ high_entropy_material=jwt_signing_key.decode(),
818
+ salt="fastmcp-storage-encryption-key",
819
+ )
820
+ client_storage = FernetEncryptionWrapper(
821
+ key_value=DiskStore(directory=settings.home / "oauth-proxy"),
822
+ fernet=Fernet(key=storage_encryption_key),
823
+ )
824
+
825
+ self._client_storage: AsyncKeyValue = client_storage
350
826
 
351
- # Local state for token bookkeeping only (no client caching)
352
- self._access_tokens: dict[str, AccessToken] = {}
353
- self._refresh_tokens: dict[str, RefreshToken] = {}
827
+ # Cache HTTPS check to avoid repeated logging
828
+ self._is_https: bool = str(self.base_url).startswith("https://")
829
+ if not self._is_https:
830
+ logger.warning(
831
+ "Using non-secure cookies for development; deploy with HTTPS for production."
832
+ )
354
833
 
355
- # Token relation mappings for cleanup
356
- self._access_to_refresh: dict[str, str] = {}
357
- self._refresh_to_access: dict[str, str] = {}
834
+ self._upstream_token_store: PydanticAdapter[UpstreamTokenSet] = PydanticAdapter[
835
+ UpstreamTokenSet
836
+ ](
837
+ key_value=self._client_storage,
838
+ pydantic_model=UpstreamTokenSet,
839
+ default_collection="mcp-upstream-tokens",
840
+ raise_on_validation_error=True,
841
+ )
842
+
843
+ self._client_store: PydanticAdapter[ProxyDCRClient] = PydanticAdapter[
844
+ ProxyDCRClient
845
+ ](
846
+ key_value=self._client_storage,
847
+ pydantic_model=ProxyDCRClient,
848
+ default_collection="mcp-oauth-proxy-clients",
849
+ raise_on_validation_error=True,
850
+ )
358
851
 
359
852
  # OAuth transaction storage for IdP callback forwarding
360
- self._oauth_transactions: dict[
361
- str, dict[str, Any]
362
- ] = {} # txn_id -> transaction_data
363
- self._client_codes: dict[str, dict[str, Any]] = {} # client_code -> code_data
853
+ # Reuse client_storage with different collections for state management
854
+ self._transaction_store: PydanticAdapter[OAuthTransaction] = PydanticAdapter[
855
+ OAuthTransaction
856
+ ](
857
+ key_value=self._client_storage,
858
+ pydantic_model=OAuthTransaction,
859
+ default_collection="mcp-oauth-transactions",
860
+ raise_on_validation_error=True,
861
+ )
862
+
863
+ self._code_store: PydanticAdapter[ClientCode] = PydanticAdapter[ClientCode](
864
+ key_value=self._client_storage,
865
+ pydantic_model=ClientCode,
866
+ default_collection="mcp-authorization-codes",
867
+ raise_on_validation_error=True,
868
+ )
869
+
870
+ # Storage for JTI mappings (FastMCP token -> upstream token)
871
+ self._jti_mapping_store: PydanticAdapter[JTIMapping] = PydanticAdapter[
872
+ JTIMapping
873
+ ](
874
+ key_value=self._client_storage,
875
+ pydantic_model=JTIMapping,
876
+ default_collection="mcp-jti-mappings",
877
+ raise_on_validation_error=True,
878
+ )
879
+
880
+ # Refresh token metadata storage, keyed by token hash for security.
881
+ # We only store metadata (not the token itself) - if storage is compromised,
882
+ # attackers get hashes they can't reverse into usable tokens.
883
+ self._refresh_token_store: PydanticAdapter[RefreshTokenMetadata] = (
884
+ PydanticAdapter[RefreshTokenMetadata](
885
+ key_value=self._client_storage,
886
+ pydantic_model=RefreshTokenMetadata,
887
+ default_collection="mcp-refresh-tokens",
888
+ raise_on_validation_error=True,
889
+ )
890
+ )
364
891
 
365
892
  # Use the provided token validator
366
- self._token_validator = token_verifier
893
+ self._token_validator: TokenVerifier = token_verifier
367
894
 
368
895
  logger.debug(
369
896
  "Initialized OAuth proxy provider with upstream server %s",
@@ -393,6 +920,7 @@ class OAuthProxy(OAuthProvider):
393
920
  # Client Registration (Local Implementation)
394
921
  # -------------------------------------------------------------------------
395
922
 
923
+ @override
396
924
  async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
397
925
  """Get client information by ID. This is generally the random ID
398
926
  provided to the DCR client during registration, not the upstream client ID.
@@ -400,20 +928,15 @@ class OAuthProxy(OAuthProvider):
400
928
  For unregistered clients, returns None (which will raise an error in the SDK).
401
929
  """
402
930
  # Load from storage
403
- data = await self._client_storage.get(client_id)
404
- if not data:
931
+ if not (client := await self._client_store.get(key=client_id)):
405
932
  return None
406
933
 
407
- if client_data := data.get("client", None):
408
- return ProxyDCRClient(
409
- allowed_redirect_uri_patterns=data.get(
410
- "allowed_redirect_uri_patterns", self._allowed_client_redirect_uris
411
- ),
412
- **client_data,
413
- )
934
+ if client.allowed_redirect_uri_patterns is None:
935
+ client.allowed_redirect_uri_patterns = self._allowed_client_redirect_uris
414
936
 
415
- return None
937
+ return client
416
938
 
939
+ @override
417
940
  async def register_client(self, client_info: OAuthClientInformationFull) -> None:
418
941
  """Register a client locally
419
942
 
@@ -424,23 +947,28 @@ class OAuthProxy(OAuthProvider):
424
947
  """
425
948
 
426
949
  # Create a ProxyDCRClient with configured redirect URI validation
427
- proxy_client = ProxyDCRClient(
950
+ if client_info.client_id is None:
951
+ raise ValueError("client_id is required for client registration")
952
+ # We use token_endpoint_auth_method="none" because the proxy handles
953
+ # all upstream authentication. The client_secret must also be None
954
+ # because the SDK requires secrets to be provided if they're set,
955
+ # regardless of auth method.
956
+ proxy_client: ProxyDCRClient = ProxyDCRClient(
428
957
  client_id=client_info.client_id,
429
- client_secret=client_info.client_secret,
958
+ client_secret=None,
430
959
  redirect_uris=client_info.redirect_uris or [AnyUrl("http://localhost")],
431
960
  grant_types=client_info.grant_types
432
961
  or ["authorization_code", "refresh_token"],
433
962
  scope=client_info.scope or self._default_scope_str,
434
963
  token_endpoint_auth_method="none",
435
964
  allowed_redirect_uri_patterns=self._allowed_client_redirect_uris,
965
+ client_name=getattr(client_info, "client_name", None),
436
966
  )
437
967
 
438
- # Store as structured dict with all needed metadata
439
- storage_data = {
440
- "client": proxy_client.model_dump(mode="json"),
441
- "allowed_redirect_uri_patterns": self._allowed_client_redirect_uris,
442
- }
443
- await self._client_storage.set(client_info.client_id, storage_data)
968
+ await self._client_store.put(
969
+ key=client_info.client_id,
970
+ value=proxy_client,
971
+ )
444
972
 
445
973
  # Log redirect URIs to help users discover what patterns they might need
446
974
  if client_info.redirect_uris:
@@ -454,25 +982,28 @@ class OAuthProxy(OAuthProvider):
454
982
  logger.debug(
455
983
  "Registered client %s with %d redirect URIs",
456
984
  client_info.client_id,
457
- len(proxy_client.redirect_uris),
985
+ len(proxy_client.redirect_uris) if proxy_client.redirect_uris else 0,
458
986
  )
459
987
 
460
988
  # -------------------------------------------------------------------------
461
989
  # Authorization Flow (Proxy to Upstream)
462
990
  # -------------------------------------------------------------------------
463
991
 
992
+ @override
464
993
  async def authorize(
465
994
  self,
466
995
  client: OAuthClientInformationFull,
467
996
  params: AuthorizationParams,
468
997
  ) -> str:
469
- """Start OAuth transaction and redirect to upstream IdP.
998
+ """Start OAuth transaction and route through consent interstitial.
999
+
1000
+ Flow:
1001
+ 1. Store transaction with client details and PKCE (if forwarding)
1002
+ 2. Return local /consent URL; browser visits consent first
1003
+ 3. Consent handler redirects to upstream IdP if approved/already approved
470
1004
 
471
- This implements the DCR-compliant proxy pattern:
472
- 1. Store transaction with client details and PKCE challenge
473
- 2. Generate proxy's own PKCE parameters if forwarding is enabled
474
- 3. Use transaction ID as state for IdP
475
- 4. Redirect to IdP with our fixed callback URL and proxy's PKCE
1005
+ If consent is disabled (require_authorization_consent=False), skip the consent screen
1006
+ and redirect directly to the upstream IdP.
476
1007
  """
477
1008
  # Generate transaction ID for this authorization request
478
1009
  txn_id = secrets.token_urlsafe(32)
@@ -488,80 +1019,57 @@ class OAuthProxy(OAuthProvider):
488
1019
  )
489
1020
 
490
1021
  # Store transaction data for IdP callback processing
491
- transaction_data = {
492
- "client_id": client.client_id,
493
- "client_redirect_uri": str(params.redirect_uri),
494
- "client_state": params.state,
495
- "code_challenge": params.code_challenge,
496
- "code_challenge_method": getattr(params, "code_challenge_method", "S256"),
497
- "scopes": params.scopes or [],
498
- "created_at": time.time(),
499
- }
500
-
501
- # Store proxy's PKCE verifier if we're forwarding
502
- if proxy_code_verifier:
503
- transaction_data["proxy_code_verifier"] = proxy_code_verifier
504
-
505
- self._oauth_transactions[txn_id] = transaction_data
506
-
507
- # Build query parameters for upstream IdP authorization request
508
- # Use our fixed IdP callback and transaction ID as state
509
- query_params: dict[str, Any] = {
510
- "response_type": "code",
511
- "client_id": self._upstream_client_id,
512
- "redirect_uri": f"{str(self.base_url).rstrip('/')}{self._redirect_path}",
513
- "state": txn_id, # Use txn_id as IdP state
514
- }
515
-
516
- # Add scopes - use client scopes or fallback to required scopes
517
- scopes_to_use = params.scopes or self.required_scopes or []
518
-
519
- if scopes_to_use:
520
- query_params["scope"] = " ".join(scopes_to_use)
521
-
522
- # Forward proxy's PKCE challenge to upstream if enabled
523
- if proxy_code_challenge:
524
- query_params["code_challenge"] = proxy_code_challenge
525
- query_params["code_challenge_method"] = "S256"
526
- logger.debug(
527
- "Forwarding proxy PKCE challenge to upstream for transaction %s",
528
- txn_id,
1022
+ if client.client_id is None:
1023
+ raise AuthorizeError(
1024
+ error="invalid_client", # type: ignore[arg-type]
1025
+ error_description="Client ID is required",
529
1026
  )
1027
+ transaction = OAuthTransaction(
1028
+ txn_id=txn_id,
1029
+ client_id=client.client_id,
1030
+ client_redirect_uri=str(params.redirect_uri),
1031
+ client_state=params.state or "",
1032
+ code_challenge=params.code_challenge,
1033
+ code_challenge_method=getattr(params, "code_challenge_method", "S256"),
1034
+ scopes=params.scopes or [],
1035
+ created_at=time.time(),
1036
+ resource=getattr(params, "resource", None),
1037
+ proxy_code_verifier=proxy_code_verifier,
1038
+ )
1039
+ await self._transaction_store.put(
1040
+ key=txn_id,
1041
+ value=transaction,
1042
+ ttl=15 * 60, # Auto-expire after 15 minutes
1043
+ )
530
1044
 
531
- # Forward resource parameter if provided (RFC 8707)
532
- if params.resource:
533
- query_params["resource"] = params.resource
534
- logger.debug(
535
- "Forwarding resource indicator '%s' to upstream for transaction %s",
536
- params.resource,
537
- txn_id,
1045
+ # If consent is disabled, skip consent screen and go directly to upstream IdP
1046
+ if not self._require_authorization_consent:
1047
+ upstream_url = self._build_upstream_authorize_url(
1048
+ txn_id, transaction.model_dump()
538
1049
  )
539
-
540
- # Add any extra authorization parameters configured for this proxy
541
- if self._extra_authorize_params:
542
- query_params.update(self._extra_authorize_params)
543
1050
  logger.debug(
544
- "Adding extra authorization parameters for transaction %s: %s",
1051
+ "Starting OAuth transaction %s for client %s, redirecting directly to upstream IdP (consent disabled, PKCE forwarding: %s)",
545
1052
  txn_id,
546
- list(self._extra_authorize_params.keys()),
1053
+ client.client_id,
1054
+ "enabled" if proxy_code_challenge else "disabled",
547
1055
  )
1056
+ return upstream_url
548
1057
 
549
- # Build the upstream authorization URL
550
- separator = "&" if "?" in self._upstream_authorization_endpoint else "?"
551
- upstream_url = f"{self._upstream_authorization_endpoint}{separator}{urlencode(query_params)}"
1058
+ consent_url = f"{str(self.base_url).rstrip('/')}/consent?txn_id={txn_id}"
552
1059
 
553
1060
  logger.debug(
554
- "Starting OAuth transaction %s for client %s, redirecting to IdP (PKCE forwarding: %s)",
1061
+ "Starting OAuth transaction %s for client %s, redirecting to consent page (PKCE forwarding: %s)",
555
1062
  txn_id,
556
1063
  client.client_id,
557
1064
  "enabled" if proxy_code_challenge else "disabled",
558
1065
  )
559
- return upstream_url
1066
+ return consent_url
560
1067
 
561
1068
  # -------------------------------------------------------------------------
562
1069
  # Authorization Code Handling
563
1070
  # -------------------------------------------------------------------------
564
1071
 
1072
+ @override
565
1073
  async def load_authorization_code(
566
1074
  self,
567
1075
  client: OAuthClientInformationFull,
@@ -573,111 +1081,258 @@ class OAuthProxy(OAuthProvider):
573
1081
  with PKCE challenge for validation.
574
1082
  """
575
1083
  # Look up client code data
576
- code_data = self._client_codes.get(authorization_code)
577
- if not code_data:
1084
+ code_model = await self._code_store.get(key=authorization_code)
1085
+ if not code_model:
578
1086
  logger.debug("Authorization code not found: %s", authorization_code)
579
1087
  return None
580
1088
 
581
1089
  # Check if code expired
582
- if time.time() > code_data["expires_at"]:
1090
+ if time.time() > code_model.expires_at:
583
1091
  logger.debug("Authorization code expired: %s", authorization_code)
584
- self._client_codes.pop(authorization_code, None)
1092
+ _ = await self._code_store.delete(key=authorization_code)
585
1093
  return None
586
1094
 
587
1095
  # Verify client ID matches
588
- if code_data["client_id"] != client.client_id:
1096
+ if code_model.client_id != client.client_id:
589
1097
  logger.debug(
590
1098
  "Authorization code client ID mismatch: %s vs %s",
591
- code_data["client_id"],
1099
+ code_model.client_id,
592
1100
  client.client_id,
593
1101
  )
594
1102
  return None
595
1103
 
596
1104
  # Create authorization code object with PKCE challenge
1105
+ if client.client_id is None:
1106
+ raise AuthorizeError(
1107
+ error="invalid_client", # type: ignore[arg-type]
1108
+ error_description="Client ID is required",
1109
+ )
597
1110
  return AuthorizationCode(
598
1111
  code=authorization_code,
599
1112
  client_id=client.client_id,
600
- redirect_uri=code_data["redirect_uri"],
1113
+ redirect_uri=AnyUrl(url=code_model.redirect_uri),
601
1114
  redirect_uri_provided_explicitly=True,
602
- scopes=code_data["scopes"],
603
- expires_at=code_data["expires_at"],
604
- code_challenge=code_data.get("code_challenge", ""),
1115
+ scopes=code_model.scopes,
1116
+ expires_at=code_model.expires_at,
1117
+ code_challenge=code_model.code_challenge or "",
605
1118
  )
606
1119
 
1120
+ @override
607
1121
  async def exchange_authorization_code(
608
1122
  self,
609
1123
  client: OAuthClientInformationFull,
610
1124
  authorization_code: AuthorizationCode,
611
1125
  ) -> OAuthToken:
612
- """Exchange authorization code for stored IdP tokens.
1126
+ """Exchange authorization code for FastMCP-issued tokens.
1127
+
1128
+ Implements the token factory pattern:
1129
+ 1. Retrieves upstream tokens from stored authorization code
1130
+ 2. Extracts user identity from upstream token
1131
+ 3. Encrypts and stores upstream tokens
1132
+ 4. Issues FastMCP-signed JWT tokens
1133
+ 5. Returns FastMCP tokens (NOT upstream tokens)
613
1134
 
614
- For the DCR-compliant proxy flow, we return the IdP tokens that were obtained
615
- during the IdP callback exchange. PKCE validation is handled by the MCP framework.
1135
+ PKCE validation is handled by the MCP framework before this method is called.
616
1136
  """
617
1137
  # Look up stored code data
618
- code_data = self._client_codes.get(authorization_code.code)
619
- if not code_data:
1138
+ code_model = await self._code_store.get(key=authorization_code.code)
1139
+ if not code_model:
620
1140
  logger.error(
621
1141
  "Authorization code not found in client codes: %s",
622
1142
  authorization_code.code,
623
1143
  )
624
1144
  raise TokenError("invalid_grant", "Authorization code not found")
625
1145
 
626
- # Get stored IdP tokens
627
- idp_tokens = code_data["idp_tokens"]
1146
+ # Get stored upstream tokens
1147
+ idp_tokens = code_model.idp_tokens
628
1148
 
629
1149
  # Clean up client code (one-time use)
630
- self._client_codes.pop(authorization_code.code, None)
1150
+ await self._code_store.delete(key=authorization_code.code)
1151
+
1152
+ # Generate IDs for token storage
1153
+ upstream_token_id = secrets.token_urlsafe(32)
1154
+ access_jti = secrets.token_urlsafe(32)
1155
+ refresh_jti = (
1156
+ secrets.token_urlsafe(32) if idp_tokens.get("refresh_token") else None
1157
+ )
1158
+
1159
+ # Calculate token expiry times
1160
+ # If upstream provides expires_in, use it. Otherwise use fallback based on:
1161
+ # - User-provided fallback if set
1162
+ # - 1 hour if refresh token available (can refresh when expired)
1163
+ # - 1 year if no refresh token (likely API-key-style token like GitHub OAuth Apps)
1164
+ if "expires_in" in idp_tokens:
1165
+ expires_in = int(idp_tokens["expires_in"])
1166
+ elif self._fallback_access_token_expiry_seconds is not None:
1167
+ expires_in = self._fallback_access_token_expiry_seconds
1168
+ elif idp_tokens.get("refresh_token"):
1169
+ expires_in = DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS
1170
+ else:
1171
+ expires_in = DEFAULT_ACCESS_TOKEN_EXPIRY_NO_REFRESH_SECONDS
1172
+
1173
+ # Calculate refresh token expiry if provided by upstream
1174
+ # Some providers include refresh_expires_in, some don't
1175
+ refresh_expires_in = None
1176
+ refresh_token_expires_at = None
1177
+ if idp_tokens.get("refresh_token"):
1178
+ if "refresh_expires_in" in idp_tokens:
1179
+ refresh_expires_in = int(idp_tokens["refresh_expires_in"])
1180
+ refresh_token_expires_at = time.time() + refresh_expires_in
1181
+ logger.debug(
1182
+ "Upstream refresh token expires in %d seconds", refresh_expires_in
1183
+ )
1184
+ else:
1185
+ # Default to 30 days if upstream doesn't specify
1186
+ # This is conservative - most providers use longer expiry
1187
+ refresh_expires_in = 60 * 60 * 24 * 30 # 30 days
1188
+ refresh_token_expires_at = time.time() + refresh_expires_in
1189
+ logger.debug(
1190
+ "Upstream refresh token expiry unknown, using 30-day default"
1191
+ )
631
1192
 
632
- # Extract token information for local tracking
633
- access_token_value = idp_tokens["access_token"]
634
- refresh_token_value = idp_tokens.get("refresh_token")
635
- expires_in = int(
636
- idp_tokens.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
1193
+ # Encrypt and store upstream tokens
1194
+ upstream_token_set = UpstreamTokenSet(
1195
+ upstream_token_id=upstream_token_id,
1196
+ access_token=idp_tokens["access_token"],
1197
+ refresh_token=idp_tokens["refresh_token"]
1198
+ if idp_tokens.get("refresh_token")
1199
+ else None,
1200
+ refresh_token_expires_at=refresh_token_expires_at,
1201
+ expires_at=time.time() + expires_in,
1202
+ token_type=idp_tokens.get("token_type", "Bearer"),
1203
+ scope=" ".join(authorization_code.scopes),
1204
+ client_id=client.client_id or "",
1205
+ created_at=time.time(),
1206
+ raw_token_data=idp_tokens,
1207
+ )
1208
+ await self._upstream_token_store.put(
1209
+ key=upstream_token_id,
1210
+ value=upstream_token_set,
1211
+ ttl=refresh_expires_in
1212
+ or expires_in, # Auto-expire when refresh token, or access token expires
637
1213
  )
638
- expires_at = int(time.time() + expires_in)
1214
+ logger.debug("Stored encrypted upstream tokens (jti=%s)", access_jti[:8])
639
1215
 
640
- # Store access token locally for tracking
641
- access_token = AccessToken(
642
- token=access_token_value,
1216
+ # Issue minimal FastMCP access token (just a reference via JTI)
1217
+ if client.client_id is None:
1218
+ raise TokenError("invalid_client", "Client ID is required")
1219
+ fastmcp_access_token = self._jwt_issuer.issue_access_token(
643
1220
  client_id=client.client_id,
644
1221
  scopes=authorization_code.scopes,
645
- expires_at=expires_at,
1222
+ jti=access_jti,
1223
+ expires_in=expires_in,
646
1224
  )
647
- self._access_tokens[access_token_value] = access_token
648
1225
 
649
- # Store refresh token if provided
650
- if refresh_token_value:
651
- refresh_token = RefreshToken(
652
- token=refresh_token_value,
1226
+ # Issue minimal FastMCP refresh token if upstream provided one
1227
+ # Use upstream refresh token expiry to align lifetimes
1228
+ fastmcp_refresh_token = None
1229
+ if refresh_jti and refresh_expires_in:
1230
+ fastmcp_refresh_token = self._jwt_issuer.issue_refresh_token(
653
1231
  client_id=client.client_id,
654
1232
  scopes=authorization_code.scopes,
655
- expires_at=None, # Refresh tokens typically don't expire
1233
+ jti=refresh_jti,
1234
+ expires_in=refresh_expires_in,
1235
+ )
1236
+
1237
+ # Store JTI mappings
1238
+ await self._jti_mapping_store.put(
1239
+ key=access_jti,
1240
+ value=JTIMapping(
1241
+ jti=access_jti,
1242
+ upstream_token_id=upstream_token_id,
1243
+ created_at=time.time(),
1244
+ ),
1245
+ ttl=expires_in, # Auto-expire with access token
1246
+ )
1247
+ if refresh_jti:
1248
+ await self._jti_mapping_store.put(
1249
+ key=refresh_jti,
1250
+ value=JTIMapping(
1251
+ jti=refresh_jti,
1252
+ upstream_token_id=upstream_token_id,
1253
+ created_at=time.time(),
1254
+ ),
1255
+ ttl=60 * 60 * 24 * 30, # Auto-expire with refresh token (30 days)
656
1256
  )
657
- self._refresh_tokens[refresh_token_value] = refresh_token
658
1257
 
659
- # Maintain token relationships for cleanup
660
- self._access_to_refresh[access_token_value] = refresh_token_value
661
- self._refresh_to_access[refresh_token_value] = access_token_value
1258
+ # Store refresh token metadata (keyed by hash for security)
1259
+ if fastmcp_refresh_token and refresh_expires_in:
1260
+ await self._refresh_token_store.put(
1261
+ key=_hash_token(fastmcp_refresh_token),
1262
+ value=RefreshTokenMetadata(
1263
+ client_id=client.client_id,
1264
+ scopes=authorization_code.scopes,
1265
+ expires_at=int(time.time()) + refresh_expires_in,
1266
+ created_at=time.time(),
1267
+ ),
1268
+ ttl=refresh_expires_in,
1269
+ )
662
1270
 
663
1271
  logger.debug(
664
- "Successfully exchanged client code for stored IdP tokens (client: %s)",
1272
+ "Issued FastMCP tokens for client=%s (access_jti=%s, refresh_jti=%s)",
665
1273
  client.client_id,
1274
+ access_jti[:8],
1275
+ refresh_jti[:8] if refresh_jti else "none",
666
1276
  )
667
1277
 
668
- return OAuthToken(**idp_tokens) # type: ignore[arg-type]
1278
+ # Return FastMCP-issued tokens (NOT upstream tokens!)
1279
+ return OAuthToken(
1280
+ access_token=fastmcp_access_token,
1281
+ token_type="Bearer",
1282
+ expires_in=expires_in,
1283
+ refresh_token=fastmcp_refresh_token,
1284
+ scope=" ".join(authorization_code.scopes),
1285
+ )
669
1286
 
670
1287
  # -------------------------------------------------------------------------
671
1288
  # Refresh Token Flow
672
1289
  # -------------------------------------------------------------------------
673
1290
 
1291
+ def _prepare_scopes_for_upstream_refresh(self, scopes: list[str]) -> list[str]:
1292
+ """Prepare scopes for upstream token refresh request.
1293
+
1294
+ Override this method to transform scopes before sending to upstream provider.
1295
+ For example, Azure needs to prefix scopes and add additional Graph scopes.
1296
+
1297
+ The scopes parameter represents what should be stored in the RefreshToken.
1298
+ This method returns what should be sent to the upstream provider.
1299
+
1300
+ Args:
1301
+ scopes: Base scopes that will be stored in RefreshToken
1302
+
1303
+ Returns:
1304
+ Scopes to send to upstream provider (may be transformed/augmented)
1305
+ """
1306
+ return scopes
1307
+
674
1308
  async def load_refresh_token(
675
1309
  self,
676
1310
  client: OAuthClientInformationFull,
677
1311
  refresh_token: str,
678
1312
  ) -> RefreshToken | None:
679
- """Load refresh token from local storage."""
680
- return self._refresh_tokens.get(refresh_token)
1313
+ """Load refresh token metadata from distributed storage.
1314
+
1315
+ Looks up by token hash and reconstructs the RefreshToken object.
1316
+ Validates that the token belongs to the requesting client.
1317
+ """
1318
+ token_hash = _hash_token(refresh_token)
1319
+ metadata = await self._refresh_token_store.get(key=token_hash)
1320
+ if not metadata:
1321
+ return None
1322
+ # Verify token belongs to this client (prevents cross-client token usage)
1323
+ if metadata.client_id != client.client_id:
1324
+ logger.warning(
1325
+ "Refresh token client_id mismatch: expected %s, got %s",
1326
+ client.client_id,
1327
+ metadata.client_id,
1328
+ )
1329
+ return None
1330
+ return RefreshToken(
1331
+ token=refresh_token,
1332
+ client_id=metadata.client_id,
1333
+ scopes=metadata.scopes,
1334
+ expires_at=metadata.expires_at,
1335
+ )
681
1336
 
682
1337
  async def exchange_refresh_token(
683
1338
  self,
@@ -685,9 +1340,45 @@ class OAuthProxy(OAuthProvider):
685
1340
  refresh_token: RefreshToken,
686
1341
  scopes: list[str],
687
1342
  ) -> OAuthToken:
688
- """Exchange refresh token for new access token using authlib."""
1343
+ """Exchange FastMCP refresh token for new FastMCP access token.
1344
+
1345
+ Implements two-tier refresh:
1346
+ 1. Verify FastMCP refresh token
1347
+ 2. Look up upstream token via JTI mapping
1348
+ 3. Refresh upstream token with upstream provider
1349
+ 4. Update stored upstream token
1350
+ 5. Issue new FastMCP access token
1351
+ 6. Keep same FastMCP refresh token (unless upstream rotates)
1352
+ """
1353
+ # Verify FastMCP refresh token
1354
+ try:
1355
+ refresh_payload = self._jwt_issuer.verify_token(refresh_token.token)
1356
+ refresh_jti = refresh_payload["jti"]
1357
+ except Exception as e:
1358
+ logger.debug("FastMCP refresh token validation failed: %s", e)
1359
+ raise TokenError("invalid_grant", "Invalid refresh token") from e
1360
+
1361
+ # Look up upstream token via JTI mapping
1362
+ jti_mapping = await self._jti_mapping_store.get(key=refresh_jti)
1363
+ if not jti_mapping:
1364
+ logger.error("JTI mapping not found for refresh token: %s", refresh_jti[:8])
1365
+ raise TokenError("invalid_grant", "Refresh token mapping not found")
1366
+
1367
+ upstream_token_set = await self._upstream_token_store.get(
1368
+ key=jti_mapping.upstream_token_id
1369
+ )
1370
+ if not upstream_token_set:
1371
+ logger.error(
1372
+ "Upstream token set not found: %s", jti_mapping.upstream_token_id[:8]
1373
+ )
1374
+ raise TokenError("invalid_grant", "Upstream token not found")
689
1375
 
690
- # Use authlib's AsyncOAuth2Client for refresh token exchange
1376
+ # Decrypt upstream refresh token
1377
+ if not upstream_token_set.refresh_token:
1378
+ logger.error("No upstream refresh token available")
1379
+ raise TokenError("invalid_grant", "Refresh not supported for this token")
1380
+
1381
+ # Refresh upstream token using authlib
691
1382
  oauth_client = AsyncOAuth2Client(
692
1383
  client_id=self._upstream_client_id,
693
1384
  client_secret=self._upstream_client_secret.get_secret_value(),
@@ -695,78 +1386,214 @@ class OAuthProxy(OAuthProvider):
695
1386
  timeout=HTTP_TIMEOUT_SECONDS,
696
1387
  )
697
1388
 
698
- try:
699
- logger.debug("Using authlib to refresh token from upstream")
1389
+ # Allow child classes to transform scopes before sending to upstream
1390
+ # This enables provider-specific scope formatting (e.g., Azure prefixing)
1391
+ # while keeping original scopes in storage
1392
+ upstream_scopes = self._prepare_scopes_for_upstream_refresh(scopes)
700
1393
 
701
- # Let authlib handle the refresh token exchange
1394
+ try:
1395
+ logger.debug("Refreshing upstream token (jti=%s)", refresh_jti[:8])
702
1396
  token_response: dict[str, Any] = await oauth_client.refresh_token( # type: ignore[misc]
703
1397
  url=self._upstream_token_endpoint,
704
- refresh_token=refresh_token.token,
705
- scope=" ".join(scopes) if scopes else None,
1398
+ refresh_token=upstream_token_set.refresh_token,
1399
+ scope=" ".join(upstream_scopes) if upstream_scopes else None,
1400
+ **self._extra_token_params,
706
1401
  )
1402
+ logger.debug("Successfully refreshed upstream token")
1403
+ except Exception as e:
1404
+ logger.error("Upstream token refresh failed: %s", e)
1405
+ raise TokenError("invalid_grant", f"Upstream refresh failed: {e}") from e
1406
+
1407
+ # Update stored upstream token
1408
+ # In refresh flow, we know there's a refresh token, so default to 1 hour
1409
+ # (user override still applies if set)
1410
+ if "expires_in" in token_response:
1411
+ new_expires_in = int(token_response["expires_in"])
1412
+ elif self._fallback_access_token_expiry_seconds is not None:
1413
+ new_expires_in = self._fallback_access_token_expiry_seconds
1414
+ else:
1415
+ new_expires_in = DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS
1416
+ upstream_token_set.access_token = token_response["access_token"]
1417
+ upstream_token_set.expires_at = time.time() + new_expires_in
1418
+
1419
+ # Handle upstream refresh token rotation and expiry
1420
+ new_refresh_expires_in = None
1421
+ if new_upstream_refresh := token_response.get("refresh_token"):
1422
+ if new_upstream_refresh != upstream_token_set.refresh_token:
1423
+ upstream_token_set.refresh_token = new_upstream_refresh
1424
+ logger.debug("Upstream refresh token rotated")
1425
+
1426
+ # Update refresh token expiry if provided
1427
+ if "refresh_expires_in" in token_response:
1428
+ new_refresh_expires_in = int(token_response["refresh_expires_in"])
1429
+ upstream_token_set.refresh_token_expires_at = (
1430
+ time.time() + new_refresh_expires_in
1431
+ )
1432
+ logger.debug(
1433
+ "Upstream refresh token expires in %d seconds",
1434
+ new_refresh_expires_in,
1435
+ )
1436
+ elif upstream_token_set.refresh_token_expires_at:
1437
+ # Keep existing expiry if upstream doesn't provide new one
1438
+ new_refresh_expires_in = int(
1439
+ upstream_token_set.refresh_token_expires_at - time.time()
1440
+ )
1441
+ else:
1442
+ # Default to 30 days if unknown
1443
+ new_refresh_expires_in = 60 * 60 * 24 * 30
1444
+ upstream_token_set.refresh_token_expires_at = (
1445
+ time.time() + new_refresh_expires_in
1446
+ )
707
1447
 
708
- logger.debug(
709
- "Successfully refreshed access token via authlib (client: %s)",
710
- client.client_id,
711
- )
1448
+ upstream_token_set.raw_token_data = token_response
1449
+ await self._upstream_token_store.put(
1450
+ key=upstream_token_set.upstream_token_id,
1451
+ value=upstream_token_set,
1452
+ ttl=new_refresh_expires_in
1453
+ or (
1454
+ int(upstream_token_set.refresh_token_expires_at - time.time())
1455
+ if upstream_token_set.refresh_token_expires_at
1456
+ else 60 * 60 * 24 * 30 # Default to 30 days if unknown
1457
+ ), # Auto-expire when refresh token expires
1458
+ )
712
1459
 
713
- except Exception as e:
714
- logger.error("Authlib refresh token exchange failed: %s", e)
715
- raise TokenError(
716
- "invalid_grant", f"Upstream refresh token exchange failed: {e}"
717
- ) from e
1460
+ # Issue new minimal FastMCP access token (just a reference via JTI)
1461
+ if client.client_id is None:
1462
+ raise TokenError("invalid_client", "Client ID is required")
1463
+ new_access_jti = secrets.token_urlsafe(32)
1464
+ new_fastmcp_access = self._jwt_issuer.issue_access_token(
1465
+ client_id=client.client_id,
1466
+ scopes=scopes,
1467
+ jti=new_access_jti,
1468
+ expires_in=new_expires_in,
1469
+ )
718
1470
 
719
- # Update local token storage
720
- new_access_token = token_response["access_token"]
721
- expires_in = int(
722
- token_response.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
1471
+ # Store new access token JTI mapping
1472
+ await self._jti_mapping_store.put(
1473
+ key=new_access_jti,
1474
+ value=JTIMapping(
1475
+ jti=new_access_jti,
1476
+ upstream_token_id=upstream_token_set.upstream_token_id,
1477
+ created_at=time.time(),
1478
+ ),
1479
+ ttl=new_expires_in, # Auto-expire with refreshed access token
723
1480
  )
724
1481
 
725
- self._access_tokens[new_access_token] = AccessToken(
726
- token=new_access_token,
1482
+ # Issue NEW minimal FastMCP refresh token (rotation for security)
1483
+ # Use upstream refresh token expiry to align lifetimes
1484
+ new_refresh_jti = secrets.token_urlsafe(32)
1485
+ new_fastmcp_refresh = self._jwt_issuer.issue_refresh_token(
727
1486
  client_id=client.client_id,
728
1487
  scopes=scopes,
729
- expires_at=int(time.time() + expires_in),
730
- )
731
-
732
- # Handle refresh token rotation if new one provided
733
- if "refresh_token" in token_response:
734
- new_refresh_token = token_response["refresh_token"]
735
- if new_refresh_token != refresh_token.token:
736
- # Remove old refresh token
737
- self._refresh_tokens.pop(refresh_token.token, None)
738
- old_access = self._refresh_to_access.pop(refresh_token.token, None)
739
- if old_access:
740
- self._access_to_refresh.pop(old_access, None)
741
-
742
- # Store new refresh token
743
- self._refresh_tokens[new_refresh_token] = RefreshToken(
744
- token=new_refresh_token,
745
- client_id=client.client_id,
746
- scopes=scopes,
747
- expires_at=None,
748
- )
749
- self._access_to_refresh[new_access_token] = new_refresh_token
750
- self._refresh_to_access[new_refresh_token] = new_access_token
1488
+ jti=new_refresh_jti,
1489
+ expires_in=new_refresh_expires_in
1490
+ or 60 * 60 * 24 * 30, # Fallback to 30 days
1491
+ )
1492
+
1493
+ # Store new refresh token JTI mapping with aligned expiry
1494
+ refresh_ttl = new_refresh_expires_in or 60 * 60 * 24 * 30
1495
+ await self._jti_mapping_store.put(
1496
+ key=new_refresh_jti,
1497
+ value=JTIMapping(
1498
+ jti=new_refresh_jti,
1499
+ upstream_token_id=upstream_token_set.upstream_token_id,
1500
+ created_at=time.time(),
1501
+ ),
1502
+ ttl=refresh_ttl, # Align with upstream refresh token expiry
1503
+ )
1504
+
1505
+ # Invalidate old refresh token (refresh token rotation - enforces one-time use)
1506
+ await self._jti_mapping_store.delete(key=refresh_jti)
1507
+ logger.debug(
1508
+ "Rotated refresh token (old JTI invalidated - one-time use enforced)"
1509
+ )
1510
+
1511
+ # Store new refresh token metadata (keyed by hash)
1512
+ await self._refresh_token_store.put(
1513
+ key=_hash_token(new_fastmcp_refresh),
1514
+ value=RefreshTokenMetadata(
1515
+ client_id=client.client_id,
1516
+ scopes=scopes,
1517
+ expires_at=int(time.time()) + refresh_ttl,
1518
+ created_at=time.time(),
1519
+ ),
1520
+ ttl=refresh_ttl,
1521
+ )
1522
+
1523
+ # Delete old refresh token (by hash)
1524
+ await self._refresh_token_store.delete(key=_hash_token(refresh_token.token))
1525
+
1526
+ logger.info(
1527
+ "Issued new FastMCP tokens (rotated refresh) for client=%s (access_jti=%s, refresh_jti=%s)",
1528
+ client.client_id,
1529
+ new_access_jti[:8],
1530
+ new_refresh_jti[:8],
1531
+ )
751
1532
 
752
- return OAuthToken(**token_response) # type: ignore[arg-type]
1533
+ # Return new FastMCP tokens (both access AND refresh are new)
1534
+ return OAuthToken(
1535
+ access_token=new_fastmcp_access,
1536
+ token_type="Bearer",
1537
+ expires_in=new_expires_in,
1538
+ refresh_token=new_fastmcp_refresh, # NEW refresh token (rotated)
1539
+ scope=" ".join(scopes),
1540
+ )
753
1541
 
754
1542
  # -------------------------------------------------------------------------
755
1543
  # Token Validation
756
1544
  # -------------------------------------------------------------------------
757
1545
 
758
- async def load_access_token(self, token: str) -> AccessToken | None:
759
- """Validate access token using upstream JWKS.
1546
+ async def load_access_token(self, token: str) -> AccessToken | None: # type: ignore[override]
1547
+ """Validate FastMCP JWT by swapping for upstream token.
1548
+
1549
+ This implements the token swap pattern:
1550
+ 1. Verify FastMCP JWT signature (proves it's our token)
1551
+ 2. Look up upstream token via JTI mapping
1552
+ 3. Decrypt upstream token
1553
+ 4. Validate upstream token with provider (GitHub API, JWT validation, etc.)
1554
+ 5. Return upstream validation result
760
1555
 
761
- Delegates to the JWT verifier which handles signature validation,
762
- expiration checking, and claims validation using the upstream JWKS.
1556
+ The FastMCP JWT is a reference token - all authorization data comes
1557
+ from validating the upstream token via the TokenVerifier.
763
1558
  """
764
- result = await self._token_validator.verify_token(token)
765
- if result:
766
- logger.debug("Token validated successfully")
767
- else:
768
- logger.debug("Token validation failed")
769
- return result
1559
+ try:
1560
+ # 1. Verify FastMCP JWT signature and claims
1561
+ payload = self._jwt_issuer.verify_token(token)
1562
+ jti = payload["jti"]
1563
+
1564
+ # 2. Look up upstream token via JTI mapping
1565
+ jti_mapping = await self._jti_mapping_store.get(key=jti)
1566
+ if not jti_mapping:
1567
+ logger.debug("JTI mapping not found: %s", jti)
1568
+ return None
1569
+
1570
+ upstream_token_set = await self._upstream_token_store.get(
1571
+ key=jti_mapping.upstream_token_id
1572
+ )
1573
+ if not upstream_token_set:
1574
+ logger.debug(
1575
+ "Upstream token not found: %s", jti_mapping.upstream_token_id
1576
+ )
1577
+ return None
1578
+
1579
+ # 3. Validate with upstream provider (delegated to TokenVerifier)
1580
+ # This calls the real token validator (GitHub API, JWKS, etc.)
1581
+ validated = await self._token_validator.verify_token(
1582
+ upstream_token_set.access_token
1583
+ )
1584
+
1585
+ if not validated:
1586
+ logger.debug("Upstream token validation failed")
1587
+ return None
1588
+
1589
+ logger.debug(
1590
+ "Token swap successful for JTI=%s (upstream validated)", jti[:8]
1591
+ )
1592
+ return validated
1593
+
1594
+ except Exception as e:
1595
+ logger.debug("Token swap validation failed: %s", e)
1596
+ return None
770
1597
 
771
1598
  # -------------------------------------------------------------------------
772
1599
  # Token Revocation
@@ -775,24 +1602,13 @@ class OAuthProxy(OAuthProvider):
775
1602
  async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
776
1603
  """Revoke token locally and with upstream server if supported.
777
1604
 
778
- Removes tokens from local storage and attempts to revoke them with
779
- the upstream server if a revocation endpoint is configured.
1605
+ For refresh tokens, removes from local storage by hash.
1606
+ For all tokens, attempts upstream revocation if endpoint is configured.
1607
+ Access token JTI mappings expire via TTL.
780
1608
  """
781
- # Clean up local token storage
782
- if isinstance(token, AccessToken):
783
- self._access_tokens.pop(token.token, None)
784
- # Also remove associated refresh token
785
- paired_refresh = self._access_to_refresh.pop(token.token, None)
786
- if paired_refresh:
787
- self._refresh_tokens.pop(paired_refresh, None)
788
- self._refresh_to_access.pop(paired_refresh, None)
789
- else: # RefreshToken
790
- self._refresh_tokens.pop(token.token, None)
791
- # Also remove associated access token
792
- paired_access = self._refresh_to_access.pop(token.token, None)
793
- if paired_access:
794
- self._access_tokens.pop(paired_access, None)
795
- self._access_to_refresh.pop(paired_access, None)
1609
+ # For refresh tokens, delete from local storage by hash
1610
+ if isinstance(token, RefreshToken):
1611
+ await self._refresh_token_store.delete(key=_hash_token(token.token))
796
1612
 
797
1613
  # Attempt upstream revocation if endpoint is configured
798
1614
  if self._upstream_revocation_endpoint:
@@ -819,21 +1635,21 @@ class OAuthProxy(OAuthProvider):
819
1635
  def get_routes(
820
1636
  self,
821
1637
  mcp_path: str | None = None,
822
- mcp_endpoint: Any | None = None,
823
1638
  ) -> list[Route]:
824
- """Get OAuth routes with custom proxy token handler.
1639
+ """Get OAuth routes with custom handlers for better error UX.
825
1640
 
826
- This method creates standard OAuth routes and replaces the token endpoint
827
- with our proxy handler that forwards requests to the upstream OAuth server.
1641
+ This method creates standard OAuth routes and replaces:
1642
+ - /authorize endpoint: Enhanced error responses for unregistered clients
1643
+ - /token endpoint: OAuth 2.1 compliant error codes
828
1644
 
829
1645
  Args:
830
1646
  mcp_path: The path where the MCP endpoint is mounted (e.g., "/mcp")
831
- mcp_endpoint: The MCP endpoint handler to protect with auth
1647
+ This is used to advertise the resource URL in metadata.
832
1648
  """
833
1649
  # Get standard OAuth routes from parent class
834
- routes = super().get_routes(mcp_path, mcp_endpoint)
1650
+ # Note: parent already replaces /token with TokenHandler for proper error codes
1651
+ routes = super().get_routes(mcp_path)
835
1652
  custom_routes = []
836
- token_route_found = False
837
1653
 
838
1654
  logger.debug(
839
1655
  f"get_routes called - configuring OAuth routes in {len(routes)} routes"
@@ -844,16 +1660,31 @@ class OAuthProxy(OAuthProvider):
844
1660
  f"Route {i}: {route} - path: {getattr(route, 'path', 'N/A')}, methods: {getattr(route, 'methods', 'N/A')}"
845
1661
  )
846
1662
 
847
- # Keep all standard OAuth routes unchanged - our DCR-compliant flow handles everything
848
- custom_routes.append(route)
849
-
1663
+ # Replace the authorize endpoint with our enhanced handler for better error UX
850
1664
  if (
851
1665
  isinstance(route, Route)
852
- and route.path == "/token"
1666
+ and route.path == "/authorize"
853
1667
  and route.methods is not None
854
- and "POST" in route.methods
1668
+ and ("GET" in route.methods or "POST" in route.methods)
855
1669
  ):
856
- token_route_found = True
1670
+ # Replace with our enhanced authorization handler
1671
+ # Note: self.base_url is guaranteed to be set in parent __init__
1672
+ authorize_handler = AuthorizationHandler(
1673
+ provider=self,
1674
+ base_url=self.base_url, # ty: ignore[invalid-argument-type]
1675
+ server_name=None, # Could be extended to pass server metadata
1676
+ server_icon_url=None,
1677
+ )
1678
+ custom_routes.append(
1679
+ Route(
1680
+ path="/authorize",
1681
+ endpoint=authorize_handler.handle,
1682
+ methods=["GET", "POST"],
1683
+ )
1684
+ )
1685
+ else:
1686
+ # Keep all other standard OAuth routes unchanged
1687
+ custom_routes.append(route)
857
1688
 
858
1689
  # Add OAuth callback endpoint for forwarding to client callbacks
859
1690
  custom_routes.append(
@@ -864,16 +1695,23 @@ class OAuthProxy(OAuthProvider):
864
1695
  )
865
1696
  )
866
1697
 
867
- logger.debug(
868
- f"✅ OAuth routes configured: token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback)"
1698
+ # Add consent endpoints
1699
+ # Handle both GET (show page) and POST (submit) at /consent
1700
+ custom_routes.append(
1701
+ Route(
1702
+ path="/consent", endpoint=self._handle_consent, methods=["GET", "POST"]
1703
+ )
869
1704
  )
1705
+
870
1706
  return custom_routes
871
1707
 
872
1708
  # -------------------------------------------------------------------------
873
1709
  # IdP Callback Forwarding
874
1710
  # -------------------------------------------------------------------------
875
1711
 
876
- async def _handle_idp_callback(self, request: Request) -> RedirectResponse:
1712
+ async def _handle_idp_callback(
1713
+ self, request: Request
1714
+ ) -> HTMLResponse | RedirectResponse:
877
1715
  """Handle callback from upstream IdP and forward to client.
878
1716
 
879
1717
  This implements the DCR-compliant callback forwarding:
@@ -888,32 +1726,38 @@ class OAuthProxy(OAuthProvider):
888
1726
  error = request.query_params.get("error")
889
1727
 
890
1728
  if error:
1729
+ error_description = request.query_params.get("error_description")
891
1730
  logger.error(
892
1731
  "IdP callback error: %s - %s",
893
1732
  error,
894
- request.query_params.get("error_description"),
1733
+ error_description,
895
1734
  )
896
- # TODO: Forward error to client callback
897
- return RedirectResponse(
898
- url=f"data:text/html,<h1>OAuth Error</h1><p>{error}: {request.query_params.get('error_description', 'Unknown error')}</p>",
899
- status_code=302,
1735
+ # Show error page to user
1736
+ html_content = create_error_html(
1737
+ error_title="OAuth Error",
1738
+ error_message=f"Authentication failed: {error_description or 'Unknown error'}",
1739
+ error_details={"Error Code": error} if error else None,
900
1740
  )
1741
+ return HTMLResponse(content=html_content, status_code=400)
901
1742
 
902
1743
  if not idp_code or not txn_id:
903
1744
  logger.error("IdP callback missing code or transaction ID")
904
- return RedirectResponse(
905
- url="data:text/html,<h1>OAuth Error</h1><p>Missing authorization code or transaction ID</p>",
906
- status_code=302,
1745
+ html_content = create_error_html(
1746
+ error_title="OAuth Error",
1747
+ error_message="Missing authorization code or transaction ID from the identity provider.",
907
1748
  )
1749
+ return HTMLResponse(content=html_content, status_code=400)
908
1750
 
909
1751
  # Look up transaction data
910
- transaction = self._oauth_transactions.get(txn_id)
911
- if not transaction:
1752
+ transaction_model = await self._transaction_store.get(key=txn_id)
1753
+ if not transaction_model:
912
1754
  logger.error("IdP callback with invalid transaction ID: %s", txn_id)
913
- return RedirectResponse(
914
- url="data:text/html,<h1>OAuth Error</h1><p>Invalid or expired transaction</p>",
915
- status_code=302,
1755
+ html_content = create_error_html(
1756
+ error_title="OAuth Error",
1757
+ error_message="Invalid or expired authorization transaction. Please try authenticating again.",
916
1758
  )
1759
+ return HTMLResponse(content=html_content, status_code=400)
1760
+ transaction = transaction_model.model_dump()
917
1761
 
918
1762
  # Exchange IdP code for tokens (server-side)
919
1763
  oauth_client = AsyncOAuth2Client(
@@ -966,30 +1810,35 @@ class OAuthProxy(OAuthProvider):
966
1810
 
967
1811
  except Exception as e:
968
1812
  logger.error("IdP token exchange failed: %s", e)
969
- # TODO: Forward error to client callback
970
- return RedirectResponse(
971
- url=f"data:text/html,<h1>OAuth Error</h1><p>Token exchange failed: {e}</p>",
972
- status_code=302,
1813
+ html_content = create_error_html(
1814
+ error_title="OAuth Error",
1815
+ error_message=f"Token exchange with identity provider failed: {e}",
973
1816
  )
1817
+ return HTMLResponse(content=html_content, status_code=500)
974
1818
 
975
1819
  # Generate our own authorization code for the client
976
1820
  client_code = secrets.token_urlsafe(32)
977
1821
  code_expires_at = int(time.time() + DEFAULT_AUTH_CODE_EXPIRY_SECONDS)
978
1822
 
979
1823
  # Store client code with PKCE challenge and IdP tokens
980
- self._client_codes[client_code] = {
981
- "client_id": transaction["client_id"],
982
- "redirect_uri": transaction["client_redirect_uri"],
983
- "code_challenge": transaction["code_challenge"],
984
- "code_challenge_method": transaction["code_challenge_method"],
985
- "scopes": transaction["scopes"],
986
- "idp_tokens": idp_tokens,
987
- "expires_at": code_expires_at,
988
- "created_at": time.time(),
989
- }
1824
+ await self._code_store.put(
1825
+ key=client_code,
1826
+ value=ClientCode(
1827
+ code=client_code,
1828
+ client_id=transaction["client_id"],
1829
+ redirect_uri=transaction["client_redirect_uri"],
1830
+ code_challenge=transaction["code_challenge"],
1831
+ code_challenge_method=transaction["code_challenge_method"],
1832
+ scopes=transaction["scopes"],
1833
+ idp_tokens=idp_tokens,
1834
+ expires_at=code_expires_at,
1835
+ created_at=time.time(),
1836
+ ),
1837
+ ttl=DEFAULT_AUTH_CODE_EXPIRY_SECONDS, # Auto-expire after 5 minutes
1838
+ )
990
1839
 
991
1840
  # Clean up transaction
992
- self._oauth_transactions.pop(txn_id, None)
1841
+ await self._transaction_store.delete(key=txn_id)
993
1842
 
994
1843
  # Build client callback URL with our code and original state
995
1844
  client_redirect_uri = transaction["client_redirect_uri"]
@@ -1012,7 +1861,329 @@ class OAuthProxy(OAuthProvider):
1012
1861
 
1013
1862
  except Exception as e:
1014
1863
  logger.error("Error in IdP callback handler: %s", e, exc_info=True)
1864
+ html_content = create_error_html(
1865
+ error_title="OAuth Error",
1866
+ error_message="Internal server error during OAuth callback processing. Please try again.",
1867
+ )
1868
+ return HTMLResponse(content=html_content, status_code=500)
1869
+
1870
+ # -------------------------------------------------------------------------
1871
+ # Consent Interstitial
1872
+ # -------------------------------------------------------------------------
1873
+
1874
+ def _normalize_uri(self, uri: str) -> str:
1875
+ """Normalize a URI to a canonical form for consent tracking."""
1876
+ parsed = urlparse(uri)
1877
+ path = parsed.path or ""
1878
+ normalized = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}{path}"
1879
+ if normalized.endswith("/") and len(path) > 1:
1880
+ normalized = normalized[:-1]
1881
+ return normalized
1882
+
1883
+ def _make_client_key(self, client_id: str, redirect_uri: str | AnyUrl) -> str:
1884
+ """Create a stable key for consent tracking from client_id and redirect_uri."""
1885
+ normalized = self._normalize_uri(str(redirect_uri))
1886
+ return f"{client_id}:{normalized}"
1887
+
1888
+ def _cookie_name(self, base_name: str) -> str:
1889
+ """Return secure cookie name for HTTPS, fallback for HTTP development."""
1890
+ if self._is_https:
1891
+ return f"__Host-{base_name}"
1892
+ return f"__{base_name}"
1893
+
1894
+ def _sign_cookie(self, payload: str) -> str:
1895
+ """Sign a cookie payload with HMAC-SHA256.
1896
+
1897
+ Returns: base64(payload).base64(signature)
1898
+ """
1899
+ # Use upstream client secret as signing key
1900
+ key = self._upstream_client_secret.get_secret_value().encode()
1901
+ signature = hmac.new(key, payload.encode(), hashlib.sha256).digest()
1902
+ signature_b64 = base64.b64encode(signature).decode()
1903
+ return f"{payload}.{signature_b64}"
1904
+
1905
+ def _verify_cookie(self, signed_value: str) -> str | None:
1906
+ """Verify and extract payload from signed cookie.
1907
+
1908
+ Returns: payload if signature valid, None otherwise
1909
+ """
1910
+ try:
1911
+ if "." not in signed_value:
1912
+ return None
1913
+ payload, signature_b64 = signed_value.rsplit(".", 1)
1914
+
1915
+ # Verify signature
1916
+ key = self._upstream_client_secret.get_secret_value().encode()
1917
+ expected_sig = hmac.new(key, payload.encode(), hashlib.sha256).digest()
1918
+ provided_sig = base64.b64decode(signature_b64.encode())
1919
+
1920
+ # Constant-time comparison
1921
+ if not hmac.compare_digest(expected_sig, provided_sig):
1922
+ return None
1923
+
1924
+ return payload
1925
+ except Exception:
1926
+ return None
1927
+
1928
+ def _decode_list_cookie(self, request: Request, base_name: str) -> list[str]:
1929
+ """Decode and verify a signed base64-encoded JSON list from cookie. Returns [] if missing/invalid."""
1930
+ # Prefer secure name, but also check non-secure variant for dev
1931
+ secure_name = self._cookie_name(base_name)
1932
+ raw = request.cookies.get(secure_name) or request.cookies.get(f"__{base_name}")
1933
+ if not raw:
1934
+ return []
1935
+ try:
1936
+ # Verify signature
1937
+ payload = self._verify_cookie(raw)
1938
+ if not payload:
1939
+ logger.debug("Cookie signature verification failed for %s", secure_name)
1940
+ return []
1941
+
1942
+ # Decode payload
1943
+ data = base64.b64decode(payload.encode())
1944
+ value = json.loads(data.decode())
1945
+ if isinstance(value, list):
1946
+ return [str(x) for x in value]
1947
+ except Exception:
1948
+ logger.debug("Failed to decode cookie %s; treating as empty", secure_name)
1949
+ return []
1950
+
1951
+ def _encode_list_cookie(self, values: list[str]) -> str:
1952
+ """Encode values to base64 and sign with HMAC.
1953
+
1954
+ Returns: signed cookie value (payload.signature)
1955
+ """
1956
+ payload = json.dumps(values, separators=(",", ":")).encode()
1957
+ payload_b64 = base64.b64encode(payload).decode()
1958
+ return self._sign_cookie(payload_b64)
1959
+
1960
+ def _set_list_cookie(
1961
+ self,
1962
+ response: HTMLResponse | RedirectResponse,
1963
+ base_name: str,
1964
+ value_b64: str,
1965
+ max_age: int,
1966
+ ) -> None:
1967
+ name = self._cookie_name(base_name)
1968
+ response.set_cookie(
1969
+ name,
1970
+ value_b64,
1971
+ max_age=max_age,
1972
+ secure=self._is_https,
1973
+ httponly=True,
1974
+ samesite="lax",
1975
+ path="/",
1976
+ )
1977
+
1978
+ def _build_upstream_authorize_url(
1979
+ self, txn_id: str, transaction: dict[str, Any]
1980
+ ) -> str:
1981
+ """Construct the upstream IdP authorization URL using stored transaction data."""
1982
+ query_params: dict[str, Any] = {
1983
+ "response_type": "code",
1984
+ "client_id": self._upstream_client_id,
1985
+ "redirect_uri": f"{str(self.base_url).rstrip('/')}{self._redirect_path}",
1986
+ "state": txn_id,
1987
+ }
1988
+
1989
+ scopes_to_use = transaction.get("scopes") or self.required_scopes or []
1990
+ if scopes_to_use:
1991
+ query_params["scope"] = " ".join(scopes_to_use)
1992
+
1993
+ # If PKCE forwarding was enabled, include the proxy challenge
1994
+ proxy_code_verifier = transaction.get("proxy_code_verifier")
1995
+ if proxy_code_verifier:
1996
+ challenge_bytes = hashlib.sha256(proxy_code_verifier.encode()).digest()
1997
+ proxy_code_challenge = (
1998
+ urlsafe_b64encode(challenge_bytes).decode().rstrip("=")
1999
+ )
2000
+ query_params["code_challenge"] = proxy_code_challenge
2001
+ query_params["code_challenge_method"] = "S256"
2002
+
2003
+ # Forward resource indicator if present in transaction
2004
+ if resource := transaction.get("resource"):
2005
+ query_params["resource"] = resource
2006
+
2007
+ # Extra configured parameters
2008
+ if self._extra_authorize_params:
2009
+ query_params.update(self._extra_authorize_params)
2010
+
2011
+ separator = "&" if "?" in self._upstream_authorization_endpoint else "?"
2012
+ return f"{self._upstream_authorization_endpoint}{separator}{urlencode(query_params)}"
2013
+
2014
+ async def _handle_consent(
2015
+ self, request: Request
2016
+ ) -> HTMLResponse | RedirectResponse:
2017
+ """Handle consent page - dispatch to GET or POST handler based on method."""
2018
+ if request.method == "POST":
2019
+ return await self._submit_consent(request)
2020
+ return await self._show_consent_page(request)
2021
+
2022
+ async def _show_consent_page(
2023
+ self, request: Request
2024
+ ) -> HTMLResponse | RedirectResponse:
2025
+ """Display consent page or auto-approve/deny based on cookies."""
2026
+ from fastmcp.server.server import FastMCP
2027
+
2028
+ txn_id = request.query_params.get("txn_id")
2029
+ if not txn_id:
2030
+ return create_secure_html_response(
2031
+ "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
2032
+ )
2033
+
2034
+ txn_model = await self._transaction_store.get(key=txn_id)
2035
+ if not txn_model:
2036
+ return create_secure_html_response(
2037
+ "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
2038
+ )
2039
+
2040
+ txn = txn_model.model_dump()
2041
+ client_key = self._make_client_key(txn["client_id"], txn["client_redirect_uri"])
2042
+
2043
+ approved = set(self._decode_list_cookie(request, "MCP_APPROVED_CLIENTS"))
2044
+ denied = set(self._decode_list_cookie(request, "MCP_DENIED_CLIENTS"))
2045
+
2046
+ if client_key in approved:
2047
+ upstream_url = self._build_upstream_authorize_url(txn_id, txn)
2048
+ return RedirectResponse(url=upstream_url, status_code=302)
2049
+
2050
+ if client_key in denied:
2051
+ callback_params = {
2052
+ "error": "access_denied",
2053
+ "state": txn.get("client_state") or "",
2054
+ }
2055
+ sep = "&" if "?" in txn["client_redirect_uri"] else "?"
1015
2056
  return RedirectResponse(
1016
- url="data:text/html,<h1>OAuth Error</h1><p>Internal server error during IdP callback</p>",
2057
+ url=f"{txn['client_redirect_uri']}{sep}{urlencode(callback_params)}",
1017
2058
  status_code=302,
1018
2059
  )
2060
+
2061
+ # Need consent: issue CSRF token and show HTML
2062
+ csrf_token = secrets.token_urlsafe(32)
2063
+ csrf_expires_at = time.time() + 15 * 60
2064
+
2065
+ # Update transaction with CSRF token
2066
+ txn_model.csrf_token = csrf_token
2067
+ txn_model.csrf_expires_at = csrf_expires_at
2068
+ await self._transaction_store.put(
2069
+ key=txn_id, value=txn_model, ttl=15 * 60
2070
+ ) # Auto-expire after 15 minutes
2071
+
2072
+ # Update dict for use in HTML generation
2073
+ txn["csrf_token"] = csrf_token
2074
+ txn["csrf_expires_at"] = csrf_expires_at
2075
+
2076
+ # Load client to get client_name if available
2077
+ client = await self.get_client(txn["client_id"])
2078
+ client_name = getattr(client, "client_name", None) if client else None
2079
+
2080
+ # Extract server metadata from app state
2081
+ fastmcp = getattr(request.app.state, "fastmcp_server", None)
2082
+
2083
+ if isinstance(fastmcp, FastMCP):
2084
+ server_name = fastmcp.name
2085
+ icons = fastmcp.icons
2086
+ server_icon_url = icons[0].src if icons else None
2087
+ server_website_url = fastmcp.website_url
2088
+ else:
2089
+ server_name = None
2090
+ server_icon_url = None
2091
+ server_website_url = None
2092
+
2093
+ html = create_consent_html(
2094
+ client_id=txn["client_id"],
2095
+ redirect_uri=txn["client_redirect_uri"],
2096
+ scopes=txn.get("scopes") or [],
2097
+ txn_id=txn_id,
2098
+ csrf_token=csrf_token,
2099
+ client_name=client_name,
2100
+ server_name=server_name,
2101
+ server_icon_url=server_icon_url,
2102
+ server_website_url=server_website_url,
2103
+ csp_policy=self._consent_csp_policy,
2104
+ )
2105
+ response = create_secure_html_response(html)
2106
+ # Store CSRF in cookie with short lifetime
2107
+ self._set_list_cookie(
2108
+ response,
2109
+ "MCP_CONSENT_STATE",
2110
+ self._encode_list_cookie([csrf_token]),
2111
+ max_age=15 * 60,
2112
+ )
2113
+ return response
2114
+
2115
+ async def _submit_consent(
2116
+ self, request: Request
2117
+ ) -> RedirectResponse | HTMLResponse:
2118
+ """Handle consent approval/denial, set cookies, and redirect appropriately."""
2119
+ form = await request.form()
2120
+ txn_id = str(form.get("txn_id", ""))
2121
+ action = str(form.get("action", ""))
2122
+ csrf_token = str(form.get("csrf_token", ""))
2123
+
2124
+ if not txn_id:
2125
+ return create_secure_html_response(
2126
+ "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
2127
+ )
2128
+
2129
+ txn_model = await self._transaction_store.get(key=txn_id)
2130
+ if not txn_model:
2131
+ return create_secure_html_response(
2132
+ "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
2133
+ )
2134
+
2135
+ txn = txn_model.model_dump()
2136
+ expected_csrf = txn.get("csrf_token")
2137
+ expires_at = float(txn.get("csrf_expires_at") or 0)
2138
+
2139
+ if not expected_csrf or csrf_token != expected_csrf or time.time() > expires_at:
2140
+ return create_secure_html_response(
2141
+ "<h1>Error</h1><p>Invalid or expired consent token</p>", status_code=400
2142
+ )
2143
+
2144
+ client_key = self._make_client_key(txn["client_id"], txn["client_redirect_uri"])
2145
+
2146
+ if action == "approve":
2147
+ approved = set(self._decode_list_cookie(request, "MCP_APPROVED_CLIENTS"))
2148
+ if client_key not in approved:
2149
+ approved.add(client_key)
2150
+ approved_b64 = self._encode_list_cookie(sorted(approved))
2151
+
2152
+ upstream_url = self._build_upstream_authorize_url(txn_id, txn)
2153
+ response = RedirectResponse(url=upstream_url, status_code=302)
2154
+ self._set_list_cookie(
2155
+ response, "MCP_APPROVED_CLIENTS", approved_b64, max_age=365 * 24 * 3600
2156
+ )
2157
+ # Clear CSRF cookie by setting empty short-lived value
2158
+ self._set_list_cookie(
2159
+ response, "MCP_CONSENT_STATE", self._encode_list_cookie([]), max_age=60
2160
+ )
2161
+ return response
2162
+
2163
+ elif action == "deny":
2164
+ denied = set(self._decode_list_cookie(request, "MCP_DENIED_CLIENTS"))
2165
+ if client_key not in denied:
2166
+ denied.add(client_key)
2167
+ denied_b64 = self._encode_list_cookie(sorted(denied))
2168
+
2169
+ callback_params = {
2170
+ "error": "access_denied",
2171
+ "state": txn.get("client_state") or "",
2172
+ }
2173
+ sep = "&" if "?" in txn["client_redirect_uri"] else "?"
2174
+ client_callback_url = (
2175
+ f"{txn['client_redirect_uri']}{sep}{urlencode(callback_params)}"
2176
+ )
2177
+ response = RedirectResponse(url=client_callback_url, status_code=302)
2178
+ self._set_list_cookie(
2179
+ response, "MCP_DENIED_CLIENTS", denied_b64, max_age=365 * 24 * 3600
2180
+ )
2181
+ self._set_list_cookie(
2182
+ response, "MCP_CONSENT_STATE", self._encode_list_cookie([]), max_age=60
2183
+ )
2184
+ return response
2185
+
2186
+ else:
2187
+ return create_secure_html_response(
2188
+ "<h1>Error</h1><p>Invalid action</p>", status_code=400
2189
+ )