fastmcp 2.12.5__py3-none-any.whl → 2.13.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastmcp/cli/cli.py +6 -6
- fastmcp/cli/install/claude_code.py +3 -3
- fastmcp/cli/install/claude_desktop.py +3 -3
- fastmcp/cli/install/cursor.py +7 -7
- fastmcp/cli/install/gemini_cli.py +3 -3
- fastmcp/cli/install/mcp_json.py +3 -3
- fastmcp/cli/run.py +13 -8
- fastmcp/client/auth/oauth.py +100 -208
- fastmcp/client/client.py +11 -11
- fastmcp/client/logging.py +18 -14
- fastmcp/client/oauth_callback.py +81 -171
- fastmcp/client/transports.py +76 -22
- fastmcp/contrib/component_manager/component_service.py +6 -6
- fastmcp/contrib/mcp_mixin/README.md +32 -1
- fastmcp/contrib/mcp_mixin/mcp_mixin.py +14 -2
- fastmcp/experimental/utilities/openapi/json_schema_converter.py +4 -0
- fastmcp/experimental/utilities/openapi/parser.py +23 -3
- fastmcp/prompts/prompt.py +13 -6
- fastmcp/prompts/prompt_manager.py +16 -101
- fastmcp/resources/resource.py +13 -6
- fastmcp/resources/resource_manager.py +5 -164
- fastmcp/resources/template.py +107 -17
- fastmcp/server/auth/auth.py +40 -32
- fastmcp/server/auth/jwt_issuer.py +289 -0
- fastmcp/server/auth/oauth_proxy.py +1228 -233
- fastmcp/server/auth/oidc_proxy.py +8 -6
- fastmcp/server/auth/providers/auth0.py +13 -7
- fastmcp/server/auth/providers/aws.py +14 -3
- fastmcp/server/auth/providers/azure.py +137 -124
- fastmcp/server/auth/providers/descope.py +4 -6
- fastmcp/server/auth/providers/github.py +14 -8
- fastmcp/server/auth/providers/google.py +15 -9
- fastmcp/server/auth/providers/introspection.py +281 -0
- fastmcp/server/auth/providers/jwt.py +8 -2
- fastmcp/server/auth/providers/scalekit.py +179 -0
- fastmcp/server/auth/providers/supabase.py +172 -0
- fastmcp/server/auth/providers/workos.py +17 -14
- fastmcp/server/context.py +89 -34
- fastmcp/server/http.py +57 -17
- fastmcp/server/low_level.py +121 -2
- fastmcp/server/middleware/caching.py +469 -0
- fastmcp/server/middleware/error_handling.py +6 -2
- fastmcp/server/middleware/logging.py +48 -37
- fastmcp/server/middleware/middleware.py +28 -15
- fastmcp/server/middleware/rate_limiting.py +3 -3
- fastmcp/server/proxy.py +6 -6
- fastmcp/server/server.py +638 -183
- fastmcp/settings.py +22 -9
- fastmcp/tools/tool.py +7 -3
- fastmcp/tools/tool_manager.py +22 -108
- fastmcp/tools/tool_transform.py +3 -3
- fastmcp/utilities/cli.py +32 -22
- fastmcp/utilities/components.py +5 -0
- fastmcp/utilities/inspect.py +77 -21
- fastmcp/utilities/logging.py +118 -8
- fastmcp/utilities/mcp_server_config/v1/environments/uv.py +6 -6
- fastmcp/utilities/mcp_server_config/v1/mcp_server_config.py +3 -3
- fastmcp/utilities/mcp_server_config/v1/schema.json +3 -0
- fastmcp/utilities/tests.py +87 -4
- fastmcp/utilities/types.py +1 -1
- fastmcp/utilities/ui.py +497 -0
- {fastmcp-2.12.5.dist-info → fastmcp-2.13.0rc2.dist-info}/METADATA +8 -4
- {fastmcp-2.12.5.dist-info → fastmcp-2.13.0rc2.dist-info}/RECORD +66 -62
- fastmcp/cli/claude.py +0 -135
- fastmcp/utilities/storage.py +0 -204
- {fastmcp-2.12.5.dist-info → fastmcp-2.13.0rc2.dist-info}/WHEEL +0 -0
- {fastmcp-2.12.5.dist-info → fastmcp-2.13.0rc2.dist-info}/entry_points.txt +0 -0
- {fastmcp-2.12.5.dist-info → fastmcp-2.13.0rc2.dist-info}/licenses/LICENSE +0 -0
|
@@ -18,16 +18,26 @@ production use with enterprise identity providers.
|
|
|
18
18
|
|
|
19
19
|
from __future__ import annotations
|
|
20
20
|
|
|
21
|
+
import base64
|
|
21
22
|
import hashlib
|
|
23
|
+
import hmac
|
|
24
|
+
import json
|
|
22
25
|
import secrets
|
|
23
26
|
import time
|
|
24
27
|
from base64 import urlsafe_b64encode
|
|
25
28
|
from typing import TYPE_CHECKING, Any, Final
|
|
26
|
-
from urllib.parse import urlencode
|
|
29
|
+
from urllib.parse import urlencode, urlparse
|
|
27
30
|
|
|
28
31
|
import httpx
|
|
29
32
|
from authlib.common.security import generate_token
|
|
30
33
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
34
|
+
from key_value.aio.adapters.pydantic import PydanticAdapter
|
|
35
|
+
from key_value.aio.protocols import AsyncKeyValue
|
|
36
|
+
from key_value.aio.stores.memory import MemoryStore
|
|
37
|
+
from mcp.server.auth.handlers.token import TokenErrorResponse, TokenSuccessResponse
|
|
38
|
+
from mcp.server.auth.handlers.token import TokenHandler as _SDKTokenHandler
|
|
39
|
+
from mcp.server.auth.json_response import PydanticJSONResponse
|
|
40
|
+
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
|
|
31
41
|
from mcp.server.auth.provider import (
|
|
32
42
|
AccessToken,
|
|
33
43
|
AuthorizationCode,
|
|
@@ -35,21 +45,36 @@ from mcp.server.auth.provider import (
|
|
|
35
45
|
RefreshToken,
|
|
36
46
|
TokenError,
|
|
37
47
|
)
|
|
48
|
+
from mcp.server.auth.routes import cors_middleware
|
|
38
49
|
from mcp.server.auth.settings import (
|
|
39
50
|
ClientRegistrationOptions,
|
|
40
51
|
RevocationOptions,
|
|
41
52
|
)
|
|
42
53
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
|
43
|
-
from pydantic import AnyHttpUrl, AnyUrl, SecretStr
|
|
54
|
+
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, SecretStr
|
|
44
55
|
from starlette.requests import Request
|
|
45
|
-
from starlette.responses import RedirectResponse
|
|
56
|
+
from starlette.responses import HTMLResponse, RedirectResponse
|
|
46
57
|
from starlette.routing import Route
|
|
47
58
|
|
|
48
|
-
import fastmcp
|
|
49
59
|
from fastmcp.server.auth.auth import OAuthProvider, TokenVerifier
|
|
50
|
-
from fastmcp.server.auth.
|
|
60
|
+
from fastmcp.server.auth.jwt_issuer import (
|
|
61
|
+
JWTIssuer,
|
|
62
|
+
TokenEncryption,
|
|
63
|
+
)
|
|
64
|
+
from fastmcp.server.auth.redirect_validation import (
|
|
65
|
+
validate_redirect_uri,
|
|
66
|
+
)
|
|
51
67
|
from fastmcp.utilities.logging import get_logger
|
|
52
|
-
from fastmcp.utilities.
|
|
68
|
+
from fastmcp.utilities.ui import (
|
|
69
|
+
BUTTON_STYLES,
|
|
70
|
+
DETAIL_BOX_STYLES,
|
|
71
|
+
INFO_BOX_STYLES,
|
|
72
|
+
TOOLTIP_STYLES,
|
|
73
|
+
create_detail_box,
|
|
74
|
+
create_logo,
|
|
75
|
+
create_page,
|
|
76
|
+
create_secure_html_response,
|
|
77
|
+
)
|
|
53
78
|
|
|
54
79
|
if TYPE_CHECKING:
|
|
55
80
|
pass
|
|
@@ -57,6 +82,95 @@ if TYPE_CHECKING:
|
|
|
57
82
|
logger = get_logger(__name__)
|
|
58
83
|
|
|
59
84
|
|
|
85
|
+
# -------------------------------------------------------------------------
# Constants
# -------------------------------------------------------------------------

# Default token expiration times, in seconds.
# NOTE(review): the sites that apply these defaults are outside this excerpt.
DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS: Final[int] = 60 * 60  # 1 hour
DEFAULT_AUTH_CODE_EXPIRY_SECONDS: Final[int] = 5 * 60  # 5 minutes

# HTTP client timeout, in seconds (presumably for calls to the upstream IdP).
HTTP_TIMEOUT_SECONDS: Final[int] = 30
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# -------------------------------------------------------------------------
|
|
98
|
+
# Pydantic Models
|
|
99
|
+
# -------------------------------------------------------------------------
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class OAuthTransaction(BaseModel):
    """OAuth transaction state for consent flow.

    Stored server-side to track active authorization flows with client context.
    Includes CSRF tokens for consent protection per MCP security best practices.

    The proxy persists instances through a PydanticAdapter in the
    "mcp-oauth-transactions" collection of its client storage.
    """

    # Identifier of this authorization transaction; authorize() generates it
    # with secrets.token_urlsafe(32), and the consent form posts it back.
    txn_id: str
    # MCP client that started the authorization flow.
    client_id: str
    # Redirect URI supplied by the MCP client.
    client_redirect_uri: str
    # State value from the MCP client's request (presumably echoed back to the
    # client on completion — confirm against the callback handler).
    client_state: str
    # MCP client's PKCE challenge; presumably None when the client sent none.
    code_challenge: str | None
    # PKCE method that applies to code_challenge (e.g. "S256").
    code_challenge_method: str
    # Scopes requested for this authorization.
    scopes: list[str]
    # Creation time (presumably a Unix timestamp, like the other models here).
    created_at: float
    # NOTE(review): presumably the RFC 8707 resource indicator — confirm.
    resource: str | None = None
    # Proxy's own PKCE verifier for the upstream IdP leg, set only when the
    # proxy forwards PKCE upstream (presumably).
    proxy_code_verifier: str | None = None
    # CSRF token embedded in the consent form as a hidden field.
    csrf_token: str | None = None
    # Expiry of csrf_token (presumably a Unix timestamp).
    csrf_expires_at: float | None = None
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class ClientCode(BaseModel):
    """Client authorization code with PKCE and upstream tokens.

    Stored server-side after upstream IdP callback. Contains the upstream
    tokens bound to the client's PKCE challenge for secure token exchange.

    The proxy persists instances through a PydanticAdapter in the
    "mcp-authorization-codes" collection of its client storage.
    """

    # Authorization code value handed to the MCP client.
    code: str
    # MCP client the code was issued to.
    client_id: str
    # Client redirect URI the code was issued for.
    redirect_uri: str
    # Client's PKCE challenge; presumably None when the client sent none.
    code_challenge: str | None
    # PKCE method that applies to code_challenge.
    code_challenge_method: str
    # Scopes associated with this code.
    scopes: list[str]
    # Token response obtained from the upstream IdP.
    idp_tokens: dict[str, Any]
    # Expiry time (presumably a Unix timestamp; DEFAULT_AUTH_CODE_EXPIRY_SECONDS
    # looks like the intended lifetime — confirm).
    expires_at: float
    # Creation time (presumably a Unix timestamp).
    created_at: float
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class UpstreamTokenSet(BaseModel):
    """Stored upstream OAuth tokens from identity provider.

    These tokens are obtained from the upstream provider (Google, GitHub, etc.)
    and are stored encrypted at rest. They are never exposed to MCP clients.

    The proxy persists instances through a PydanticAdapter in the
    "mcp-upstream-tokens" collection of its client storage.
    """

    upstream_token_id: str  # Unique ID for this token set
    access_token: bytes  # Encrypted upstream access token (ciphertext, not the raw token)
    refresh_token: bytes | None  # Encrypted upstream refresh token, if the IdP issued one
    refresh_token_expires_at: (
        float | None
    )  # Unix timestamp when refresh token expires (if known)
    expires_at: float  # Unix timestamp when access token expires
    token_type: str  # Usually "Bearer"
    scope: str  # Space-separated scopes
    client_id: str  # MCP client this is bound to
    created_at: float  # Unix timestamp
    raw_token_data: dict[str, Any] = Field(default_factory=dict)  # Full token response
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class JTIMapping(BaseModel):
    """Maps FastMCP token JTI to upstream token ID.

    This allows stateless JWT validation while still being able to look up
    the corresponding upstream token when tools need to access upstream APIs.

    The proxy persists instances through a PydanticAdapter in the
    "mcp-jti-mappings" collection of its client storage.
    """

    jti: str  # JWT ID from FastMCP-issued token
    upstream_token_id: str  # References UpstreamTokenSet.upstream_token_id
    created_at: float  # Unix timestamp
|
|
172
|
+
|
|
173
|
+
|
|
60
174
|
class ProxyDCRClient(OAuthClientInformationFull):
|
|
61
175
|
"""Client for DCR proxy with configurable redirect URI validation.
|
|
62
176
|
|
|
@@ -83,18 +197,8 @@ class ProxyDCRClient(OAuthClientInformationFull):
|
|
|
83
197
|
arise from accepting arbitrary redirect URIs.
|
|
84
198
|
"""
|
|
85
199
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
):
|
|
89
|
-
"""Initialize with allowed redirect URI patterns.
|
|
90
|
-
|
|
91
|
-
Args:
|
|
92
|
-
allowed_redirect_uri_patterns: List of allowed redirect URI patterns with wildcard support.
|
|
93
|
-
If None, defaults to localhost-only patterns.
|
|
94
|
-
If empty list, allows all redirect URIs.
|
|
95
|
-
"""
|
|
96
|
-
super().__init__(*args, **kwargs)
|
|
97
|
-
self._allowed_redirect_uri_patterns = allowed_redirect_uri_patterns
|
|
200
|
+
allowed_redirect_uri_patterns: list[str] | None = Field(default=None)
|
|
201
|
+
client_name: str | None = Field(default=None)
|
|
98
202
|
|
|
99
203
|
def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
|
|
100
204
|
"""Validate redirect URI against allowed patterns.
|
|
@@ -106,7 +210,10 @@ class ProxyDCRClient(OAuthClientInformationFull):
|
|
|
106
210
|
"""
|
|
107
211
|
if redirect_uri is not None:
|
|
108
212
|
# Validate against allowed patterns
|
|
109
|
-
if validate_redirect_uri(
|
|
213
|
+
if validate_redirect_uri(
|
|
214
|
+
redirect_uri=redirect_uri,
|
|
215
|
+
allowed_patterns=self.allowed_redirect_uri_patterns,
|
|
216
|
+
):
|
|
110
217
|
return redirect_uri
|
|
111
218
|
# Fall back to normal validation if not in allowed patterns
|
|
112
219
|
return super().validate_redirect_uri(redirect_uri)
|
|
@@ -114,12 +221,173 @@ class ProxyDCRClient(OAuthClientInformationFull):
|
|
|
114
221
|
return super().validate_redirect_uri(redirect_uri)
|
|
115
222
|
|
|
116
223
|
|
|
117
|
-
#
|
|
118
|
-
|
|
119
|
-
|
|
224
|
+
# -------------------------------------------------------------------------
|
|
225
|
+
# Helper Functions
|
|
226
|
+
# -------------------------------------------------------------------------
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def create_consent_html(
    client_id: str,
    redirect_uri: str,
    scopes: list[str],
    txn_id: str,
    csrf_token: str,
    client_name: str | None = None,
    title: str = "Authorization Consent",
    server_name: str | None = None,
    server_icon_url: str | None = None,
    server_website_url: str | None = None,
) -> str:
    """Create a styled HTML consent page for OAuth authorization requests.

    The page states which client is requesting access to which server and
    lists the client's details (name, ID, redirect URI, requested scopes).
    It also renders an Approve/Deny form that POSTs to ``/consent/submit``,
    carrying the transaction ID and CSRF token as hidden fields.

    Args:
        client_id: ID of the MCP client requesting authorization.
        redirect_uri: Redirect URI the client will be sent back to.
        scopes: Requested scopes. They are shown comma-separated, or as
            "None" when the list is empty.
        txn_id: Server-side transaction identifier, echoed back on submit.
        csrf_token: CSRF token, echoed back on submit.
        client_name: Optional human-readable client name. When present it is
            used in the headline instead of ``client_id``.
        title: Document title passed to ``create_page``. The visible heading
            is always "Authorization Consent".
        server_name: Display name of this server (defaults to "FastMCP").
        server_icon_url: Optional logo URL passed to ``create_logo``.
        server_website_url: Optional URL. When set, the server name is
            rendered as a link to it.

    Returns:
        The complete HTML document as a string.
    """
    import html as html_module

    # Format scopes for display
    scopes_display = ", ".join(scopes) if scopes else "None"

    # Everything interpolated directly into the markup below is HTML-escaped.
    # client_name comes from the client's dynamic registration request, so it
    # is untrusted input.
    client_display = html_module.escape(client_name or client_id)
    server_name_escaped = html_module.escape(server_name or "FastMCP")
    # Hidden form fields: escape them too, so that a quote or angle bracket in
    # either value cannot break out of the value="..." attribute.
    txn_id_escaped = html_module.escape(txn_id)
    csrf_token_escaped = html_module.escape(csrf_token)

    # Make server name a hyperlink if website URL is available
    if server_website_url:
        website_url_escaped = html_module.escape(server_website_url)
        server_display = (
            f'<a href="{website_url_escaped}" target="_blank" '
            f'rel="noopener noreferrer">{server_name_escaped}</a>'
        )
    else:
        server_display = server_name_escaped

    # Headline box naming the requesting client and the target server
    warning_box = f"""
    <div class="warning-box">
        <p><strong>{client_display}</strong> is requesting access to <strong>{server_display}</strong>.</p>
        <p>Review the details below before approving.</p>
    </div>
    """

    # Build detail box with client information.
    # NOTE(review): these values (client_name in particular) are passed
    # unescaped. This assumes create_detail_box escapes them — confirm.
    detail_rows = []
    if client_name:
        detail_rows.append(("Client Name", client_name))
    detail_rows.extend(
        [
            ("Client ID", client_id),
            ("Redirect URI", redirect_uri),
            ("Requested Scopes", scopes_display),
        ]
    )
    detail_box = create_detail_box(detail_rows)

    # Build form with buttons
    form = f"""
    <form id="consentForm" method="POST" action="/consent/submit">
        <input type="hidden" name="txn_id" value="{txn_id_escaped}" />
        <input type="hidden" name="csrf_token" value="{csrf_token_escaped}" />
        <div class="button-group">
            <button type="submit" name="action" value="approve" class="btn-approve">Approve</button>
            <button type="submit" name="action" value="deny" class="btn-deny">Deny</button>
        </div>
    </form>
    """

    # Build help link with tooltip
    help_link = """
    <div class="help-link-container">
        <span class="help-link">
            Why am I seeing this?
            <span class="tooltip">
                This FastMCP server requires your consent to allow a new client
                to connect. This protects you from <a
                href="https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#confused-deputy-problem"
                target="_blank" class="tooltip-link">confused deputy
                attacks</a>, where malicious clients could impersonate you
                and steal access.<br><br>
                <a
                href="https://gofastmcp.com/servers/auth/oauth-proxy#confused-deputy-attacks"
                target="_blank" class="tooltip-link">Learn more about
                FastMCP security →</a>
            </span>
        </span>
    </div>
    """

    # Build the page content
    logo_html = create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")
    content = f"""
    <div class="container">
        {logo_html}
        <h1>Authorization Consent</h1>
        {warning_box}
        {detail_box}
        {form}
    </div>
    {help_link}
    """

    # Additional styles needed for this page
    additional_styles = (
        INFO_BOX_STYLES + DETAIL_BOX_STYLES + BUTTON_STYLES + TOOLTIP_STYLES
    )

    # The form needs an explicit form-action directive. It is left open,
    # presumably because the submit handler redirects off-origin (to the
    # upstream IdP or the client's redirect URI), and some browsers apply
    # form-action to those redirects.
    csp_policy = "default-src 'none'; style-src 'unsafe-inline'; img-src https:; base-uri 'none'; form-action *"

    return create_page(
        content=content,
        title=title,
        additional_styles=additional_styles,
        csp_policy=csp_policy,
    )
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
# -------------------------------------------------------------------------
|
|
340
|
+
# Handler Classes
|
|
341
|
+
# -------------------------------------------------------------------------
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
class TokenHandler(_SDKTokenHandler):
    """Token endpoint handler that reports failed client authentication as 401.

    The MCP SDK's token handler answers client-authentication problems with
    HTTP 400 and the ``unauthorized_client`` error code. This subclass
    intercepts one specific case: an ``unauthorized_client`` error whose
    description contains ``"Invalid client_id"``, meaning the client could not
    be authenticated. That error is re-emitted as ``invalid_client`` with
    HTTP 401. ``invalid_client`` is the code RFC 6749 Section 5.2 defines for
    failed client authentication, and that section allows a 401 status for it.

    Every other response is delegated unchanged to the SDK implementation.
    That includes successes and ``unauthorized_client`` errors meaning "client
    not allowed to use this grant type", which correctly stay 400.
    """

    def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
        """Serialize a token endpoint result, remapping invalid-client errors."""
        # Guard: only unauthorized_client errors are candidates for remapping.
        if not isinstance(obj, TokenErrorResponse) or obj.error != "unauthorized_client":
            return super().response(obj)

        # Guard: only the "unknown / wrong client" flavour is remapped. The
        # grant-type flavour of unauthorized_client stays 400.
        description = obj.error_description
        if not description or "Invalid client_id" not in description:
            return super().response(obj)

        remapped_error = TokenErrorResponse(
            error="invalid_client",
            error_description=description,
            error_uri=obj.error_uri,
        )
        # Token endpoint responses must not be cached.
        no_cache_headers = {
            "Cache-Control": "no-store",
            "Pragma": "no-cache",
        }
        return PydanticJSONResponse(
            content=remapped_error,
            status_code=401,
            headers=no_cache_headers,
        )
|
|
123
391
|
|
|
124
392
|
|
|
125
393
|
class OAuthProxy(OAuthProvider):
|
|
@@ -201,7 +469,6 @@ class OAuthProxy(OAuthProvider):
|
|
|
201
469
|
State Management
|
|
202
470
|
---------------
|
|
203
471
|
The proxy maintains minimal but crucial state:
|
|
204
|
-
- _clients: DCR registrations (all use ProxyDCRClient for flexibility)
|
|
205
472
|
- _oauth_transactions: Active authorization flows with client context
|
|
206
473
|
- _client_codes: Authorization codes with PKCE challenges and upstream tokens
|
|
207
474
|
- _access_tokens, _refresh_tokens: Token storage for revocation
|
|
@@ -257,7 +524,11 @@ class OAuthProxy(OAuthProvider):
|
|
|
257
524
|
# Extra parameters to forward to token endpoint
|
|
258
525
|
extra_token_params: dict[str, str] | None = None,
|
|
259
526
|
# Client storage
|
|
260
|
-
client_storage:
|
|
527
|
+
client_storage: AsyncKeyValue | None = None,
|
|
528
|
+
# JWT signing key (optional, ephemeral if not provided)
|
|
529
|
+
jwt_signing_key: str | bytes | None = None,
|
|
530
|
+
# Token encryption key (optional, ephemeral if not provided)
|
|
531
|
+
token_encryption_key: str | bytes | None = None,
|
|
261
532
|
):
|
|
262
533
|
"""Initialize the OAuth proxy provider.
|
|
263
534
|
|
|
@@ -291,9 +562,13 @@ class OAuthProxy(OAuthProvider):
|
|
|
291
562
|
Example: {"audience": "https://api.example.com"}
|
|
292
563
|
extra_token_params: Additional parameters to forward to the upstream token endpoint.
|
|
293
564
|
Useful for provider-specific parameters during token exchange.
|
|
294
|
-
client_storage:
|
|
295
|
-
|
|
296
|
-
|
|
565
|
+
client_storage: An AsyncKeyValue-compatible store for client registrations, registrations are stored in memory if not provided
|
|
566
|
+
jwt_signing_key: Optional secret for signing FastMCP JWT tokens (accepts any string or bytes).
|
|
567
|
+
Default: ephemeral (random salt at startup, won't survive restart).
|
|
568
|
+
Production: provide explicit key from environment variable.
|
|
569
|
+
token_encryption_key: Optional secret for encrypting upstream tokens at rest (accepts any string or bytes).
|
|
570
|
+
Default: ephemeral (random salt at startup, won't survive restart).
|
|
571
|
+
Production: provide explicit key from environment variable.
|
|
297
572
|
"""
|
|
298
573
|
# Always enable DCR since we implement it locally for MCP clients
|
|
299
574
|
client_registration_options = ClientRegistrationOptions(
|
|
@@ -330,6 +605,15 @@ class OAuthProxy(OAuthProvider):
|
|
|
330
605
|
self._redirect_path = (
|
|
331
606
|
redirect_path if redirect_path.startswith("/") else f"/{redirect_path}"
|
|
332
607
|
)
|
|
608
|
+
|
|
609
|
+
if (
|
|
610
|
+
isinstance(allowed_client_redirect_uris, list)
|
|
611
|
+
and not allowed_client_redirect_uris
|
|
612
|
+
):
|
|
613
|
+
logger.warning(
|
|
614
|
+
"allowed_client_redirect_uris is empty list; no redirect URIs will be accepted. "
|
|
615
|
+
"This will block all OAuth clients."
|
|
616
|
+
)
|
|
333
617
|
self._allowed_client_redirect_uris = allowed_client_redirect_uris
|
|
334
618
|
|
|
335
619
|
# PKCE configuration
|
|
@@ -342,11 +626,69 @@ class OAuthProxy(OAuthProvider):
|
|
|
342
626
|
self._extra_authorize_params = extra_authorize_params or {}
|
|
343
627
|
self._extra_token_params = extra_token_params or {}
|
|
344
628
|
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
629
|
+
self._client_storage: AsyncKeyValue = client_storage or MemoryStore()
|
|
630
|
+
|
|
631
|
+
# Warn if using MemoryStore in production
|
|
632
|
+
if isinstance(client_storage, MemoryStore):
|
|
633
|
+
logger.warning(
|
|
634
|
+
"Using in-memory storage - all OAuth state (clients, tokens) will be lost on restart. "
|
|
635
|
+
"Additionally, without explicit jwt_signing_key and token_encryption_key, "
|
|
636
|
+
"keys are ephemeral and tokens won't survive restart even with persistent storage. "
|
|
637
|
+
"For production, configure persistent storage AND explicit keys."
|
|
638
|
+
)
|
|
639
|
+
|
|
640
|
+
# Cache HTTPS check to avoid repeated logging
|
|
641
|
+
self._is_https = str(self.base_url).startswith("https://")
|
|
642
|
+
if not self._is_https:
|
|
643
|
+
logger.warning(
|
|
644
|
+
"Using non-secure cookies for development; deploy with HTTPS for production."
|
|
645
|
+
)
|
|
646
|
+
|
|
647
|
+
self._client_store = PydanticAdapter[ProxyDCRClient](
|
|
648
|
+
key_value=self._client_storage,
|
|
649
|
+
pydantic_model=ProxyDCRClient,
|
|
650
|
+
default_collection="mcp-oauth-proxy-clients",
|
|
651
|
+
raise_on_validation_error=True,
|
|
652
|
+
)
|
|
653
|
+
|
|
654
|
+
# OAuth transaction storage for IdP callback forwarding
|
|
655
|
+
# Reuse client_storage with different collections for state management
|
|
656
|
+
self._transaction_store = PydanticAdapter[OAuthTransaction](
|
|
657
|
+
key_value=self._client_storage,
|
|
658
|
+
pydantic_model=OAuthTransaction,
|
|
659
|
+
default_collection="mcp-oauth-transactions",
|
|
660
|
+
raise_on_validation_error=True,
|
|
661
|
+
)
|
|
662
|
+
|
|
663
|
+
self._code_store = PydanticAdapter[ClientCode](
|
|
664
|
+
key_value=self._client_storage,
|
|
665
|
+
pydantic_model=ClientCode,
|
|
666
|
+
default_collection="mcp-authorization-codes",
|
|
667
|
+
raise_on_validation_error=True,
|
|
668
|
+
)
|
|
669
|
+
|
|
670
|
+
# Storage for upstream tokens (encrypted at rest)
|
|
671
|
+
self._upstream_token_store = PydanticAdapter[UpstreamTokenSet](
|
|
672
|
+
key_value=self._client_storage,
|
|
673
|
+
pydantic_model=UpstreamTokenSet,
|
|
674
|
+
default_collection="mcp-upstream-tokens",
|
|
675
|
+
raise_on_validation_error=True,
|
|
676
|
+
)
|
|
677
|
+
|
|
678
|
+
# Storage for JTI mappings (FastMCP token -> upstream token)
|
|
679
|
+
self._jti_mapping_store = PydanticAdapter[JTIMapping](
|
|
680
|
+
key_value=self._client_storage,
|
|
681
|
+
pydantic_model=JTIMapping,
|
|
682
|
+
default_collection="mcp-jti-mappings",
|
|
683
|
+
raise_on_validation_error=True,
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
# JWT issuer and encryption (initialized lazily on first use)
|
|
687
|
+
self._custom_jwt_key = jwt_signing_key
|
|
688
|
+
self._custom_encryption_key = token_encryption_key
|
|
689
|
+
self._jwt_issuer: JWTIssuer | None = None
|
|
690
|
+
self._token_encryption: TokenEncryption | None = None
|
|
691
|
+
self._jwt_initialized = False
|
|
350
692
|
|
|
351
693
|
# Local state for token bookkeeping only (no client caching)
|
|
352
694
|
self._access_tokens: dict[str, AccessToken] = {}
|
|
@@ -356,12 +698,6 @@ class OAuthProxy(OAuthProvider):
|
|
|
356
698
|
self._access_to_refresh: dict[str, str] = {}
|
|
357
699
|
self._refresh_to_access: dict[str, str] = {}
|
|
358
700
|
|
|
359
|
-
# OAuth transaction storage for IdP callback forwarding
|
|
360
|
-
self._oauth_transactions: dict[
|
|
361
|
-
str, dict[str, Any]
|
|
362
|
-
] = {} # txn_id -> transaction_data
|
|
363
|
-
self._client_codes: dict[str, dict[str, Any]] = {} # client_code -> code_data
|
|
364
|
-
|
|
365
701
|
# Use the provided token validator
|
|
366
702
|
self._token_validator = token_verifier
|
|
367
703
|
|
|
@@ -389,6 +725,87 @@ class OAuthProxy(OAuthProvider):
|
|
|
389
725
|
|
|
390
726
|
return code_verifier, code_challenge
|
|
391
727
|
|
|
728
|
+
# -------------------------------------------------------------------------
|
|
729
|
+
# JWT Token Factory Initialization
|
|
730
|
+
# -------------------------------------------------------------------------
|
|
731
|
+
|
|
732
|
+
async def _ensure_jwt_initialized(self) -> None:
    """Lazily create the JWT issuer and the upstream-token encryption helper.

    This runs once per instance; later calls return immediately.

    Key derivation, applied separately to the signing key and the encryption
    key:

    - Explicit secret given to the constructor (``jwt_signing_key`` /
      ``token_encryption_key``): the key is derived from it with a fixed salt.
      It is stable across restarts, so tokens survive a restart provided the
      storage is persistent too.
    - No explicit secret: the key is derived from the upstream client secret
      plus a random salt generated here and never persisted. Every restart
      therefore invalidates previously issued tokens and encrypted data. This
      is acceptable for development and testing.

    The signing key uses ``info=b"HS256"``. The encryption key uses
    ``info=b"Fernet"`` and is base64url-encoded before being handed to
    ``TokenEncryption``.
    """
    if self._jwt_initialized:
        return

    # Random per-instance salt (NOT persisted); anything derived from it
    # changes on every restart.
    instance_salt = secrets.token_urlsafe(32)

    from fastmcp.server.auth.jwt_issuer import derive_key_from_secret

    def _derive(
        explicit_secret: str | bytes | None, salt_prefix: str, info: bytes
    ) -> tuple[bytes, bool]:
        """Derive one key; return it plus whether an explicit secret was used."""
        # A falsy explicit secret (None, "" or b"") falls back to the ephemeral key.
        if explicit_secret:
            stable_key = derive_key_from_secret(
                secret=explicit_secret,
                salt=salt_prefix,
                info=info,
            )
            return stable_key, True
        ephemeral_key = derive_key_from_secret(
            secret=self._upstream_client_secret.get_secret_value(),
            salt=f"{salt_prefix}-{instance_salt}",
            info=info,
        )
        return ephemeral_key, False

    # --- JWT signing ---------------------------------------------------------
    signing_key, signing_is_explicit = _derive(
        self._custom_jwt_key, "fastmcp-jwt-signing-v1", b"HS256"
    )
    if signing_is_explicit:
        logger.info("Using explicit JWT signing key (will survive restarts)")
    else:
        logger.info(
            "Using ephemeral JWT signing key - tokens will NOT survive server restart. "
            "For production, provide explicit jwt_signing_key parameter and use persistent storage."
        )

    base_url = str(self.base_url)
    self._jwt_issuer = JWTIssuer(
        issuer=base_url,
        audience=f"{base_url.rstrip('/')}/mcp",
        signing_key=signing_key,
    )

    # --- Upstream token encryption ---------------------------------------------
    raw_encryption_key, encryption_is_explicit = _derive(
        self._custom_encryption_key, "fastmcp-token-encryption-v1", b"Fernet"
    )
    if encryption_is_explicit:
        logger.info("Using explicit token encryption key (will survive restarts)")
    else:
        logger.info(
            "Using ephemeral token encryption key - encrypted tokens will NOT survive server restart. "
            "For production, provide explicit token_encryption_key parameter and use persistent storage."
        )

    # Fernet expects its key base64url-encoded.
    self._token_encryption = TokenEncryption(
        base64.urlsafe_b64encode(raw_encryption_key)
    )
    self._jwt_initialized = True
|
|
808
|
+
|
|
392
809
|
# -------------------------------------------------------------------------
|
|
393
810
|
# Client Registration (Local Implementation)
|
|
394
811
|
# -------------------------------------------------------------------------
|
|
@@ -400,19 +817,13 @@ class OAuthProxy(OAuthProvider):
|
|
|
400
817
|
For unregistered clients, returns None (which will raise an error in the SDK).
|
|
401
818
|
"""
|
|
402
819
|
# Load from storage
|
|
403
|
-
|
|
404
|
-
if not data:
|
|
820
|
+
if not (client := await self._client_store.get(key=client_id)):
|
|
405
821
|
return None
|
|
406
822
|
|
|
407
|
-
if
|
|
408
|
-
|
|
409
|
-
allowed_redirect_uri_patterns=data.get(
|
|
410
|
-
"allowed_redirect_uri_patterns", self._allowed_client_redirect_uris
|
|
411
|
-
),
|
|
412
|
-
**client_data,
|
|
413
|
-
)
|
|
823
|
+
if client.allowed_redirect_uri_patterns is None:
|
|
824
|
+
client.allowed_redirect_uri_patterns = self._allowed_client_redirect_uris
|
|
414
825
|
|
|
415
|
-
return
|
|
826
|
+
return client
|
|
416
827
|
|
|
417
828
|
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
|
|
418
829
|
"""Register a client locally
|
|
@@ -424,7 +835,7 @@ class OAuthProxy(OAuthProvider):
|
|
|
424
835
|
"""
|
|
425
836
|
|
|
426
837
|
# Create a ProxyDCRClient with configured redirect URI validation
|
|
427
|
-
proxy_client = ProxyDCRClient(
|
|
838
|
+
proxy_client: ProxyDCRClient = ProxyDCRClient(
|
|
428
839
|
client_id=client_info.client_id,
|
|
429
840
|
client_secret=client_info.client_secret,
|
|
430
841
|
redirect_uris=client_info.redirect_uris or [AnyUrl("http://localhost")],
|
|
@@ -433,14 +844,13 @@ class OAuthProxy(OAuthProvider):
|
|
|
433
844
|
scope=client_info.scope or self._default_scope_str,
|
|
434
845
|
token_endpoint_auth_method="none",
|
|
435
846
|
allowed_redirect_uri_patterns=self._allowed_client_redirect_uris,
|
|
847
|
+
client_name=getattr(client_info, "client_name", None),
|
|
436
848
|
)
|
|
437
849
|
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
}
|
|
443
|
-
await self._client_storage.set(client_info.client_id, storage_data)
|
|
850
|
+
await self._client_store.put(
|
|
851
|
+
key=client_info.client_id,
|
|
852
|
+
value=proxy_client,
|
|
853
|
+
)
|
|
444
854
|
|
|
445
855
|
# Log redirect URIs to help users discover what patterns they might need
|
|
446
856
|
if client_info.redirect_uris:
|
|
@@ -466,13 +876,12 @@ class OAuthProxy(OAuthProvider):
|
|
|
466
876
|
client: OAuthClientInformationFull,
|
|
467
877
|
params: AuthorizationParams,
|
|
468
878
|
) -> str:
|
|
469
|
-
"""Start OAuth transaction and
|
|
879
|
+
"""Start OAuth transaction and route through consent interstitial.
|
|
470
880
|
|
|
471
|
-
|
|
472
|
-
1. Store transaction with client details and PKCE
|
|
473
|
-
2.
|
|
474
|
-
3.
|
|
475
|
-
4. Redirect to IdP with our fixed callback URL and proxy's PKCE
|
|
881
|
+
Flow:
|
|
882
|
+
1. Store transaction with client details and PKCE (if forwarding)
|
|
883
|
+
2. Return local /consent URL; browser visits consent first
|
|
884
|
+
3. Consent handler redirects to upstream IdP if approved/already approved
|
|
476
885
|
"""
|
|
477
886
|
# Generate transaction ID for this authorization request
|
|
478
887
|
txn_id = secrets.token_urlsafe(32)
|
|
@@ -488,75 +897,32 @@ class OAuthProxy(OAuthProvider):
|
|
|
488
897
|
)
|
|
489
898
|
|
|
490
899
|
# Store transaction data for IdP callback processing
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
# Build query parameters for upstream IdP authorization request
|
|
508
|
-
# Use our fixed IdP callback and transaction ID as state
|
|
509
|
-
query_params: dict[str, Any] = {
|
|
510
|
-
"response_type": "code",
|
|
511
|
-
"client_id": self._upstream_client_id,
|
|
512
|
-
"redirect_uri": f"{str(self.base_url).rstrip('/')}{self._redirect_path}",
|
|
513
|
-
"state": txn_id, # Use txn_id as IdP state
|
|
514
|
-
}
|
|
515
|
-
|
|
516
|
-
# Add scopes - use client scopes or fallback to required scopes
|
|
517
|
-
scopes_to_use = params.scopes or self.required_scopes or []
|
|
518
|
-
|
|
519
|
-
if scopes_to_use:
|
|
520
|
-
query_params["scope"] = " ".join(scopes_to_use)
|
|
521
|
-
|
|
522
|
-
# Forward proxy's PKCE challenge to upstream if enabled
|
|
523
|
-
if proxy_code_challenge:
|
|
524
|
-
query_params["code_challenge"] = proxy_code_challenge
|
|
525
|
-
query_params["code_challenge_method"] = "S256"
|
|
526
|
-
logger.debug(
|
|
527
|
-
"Forwarding proxy PKCE challenge to upstream for transaction %s",
|
|
528
|
-
txn_id,
|
|
529
|
-
)
|
|
530
|
-
|
|
531
|
-
# Forward resource parameter if provided (RFC 8707)
|
|
532
|
-
if params.resource:
|
|
533
|
-
query_params["resource"] = params.resource
|
|
534
|
-
logger.debug(
|
|
535
|
-
"Forwarding resource indicator '%s' to upstream for transaction %s",
|
|
536
|
-
params.resource,
|
|
537
|
-
txn_id,
|
|
538
|
-
)
|
|
539
|
-
|
|
540
|
-
# Add any extra authorization parameters configured for this proxy
|
|
541
|
-
if self._extra_authorize_params:
|
|
542
|
-
query_params.update(self._extra_authorize_params)
|
|
543
|
-
logger.debug(
|
|
544
|
-
"Adding extra authorization parameters for transaction %s: %s",
|
|
545
|
-
txn_id,
|
|
546
|
-
list(self._extra_authorize_params.keys()),
|
|
547
|
-
)
|
|
900
|
+
await self._transaction_store.put(
|
|
901
|
+
key=txn_id,
|
|
902
|
+
value=OAuthTransaction(
|
|
903
|
+
txn_id=txn_id,
|
|
904
|
+
client_id=client.client_id,
|
|
905
|
+
client_redirect_uri=str(params.redirect_uri),
|
|
906
|
+
client_state=params.state or "",
|
|
907
|
+
code_challenge=params.code_challenge,
|
|
908
|
+
code_challenge_method=getattr(params, "code_challenge_method", "S256"),
|
|
909
|
+
scopes=params.scopes or [],
|
|
910
|
+
created_at=time.time(),
|
|
911
|
+
resource=getattr(params, "resource", None),
|
|
912
|
+
proxy_code_verifier=proxy_code_verifier,
|
|
913
|
+
),
|
|
914
|
+
ttl=15 * 60, # Auto-expire after 15 minutes
|
|
915
|
+
)
|
|
548
916
|
|
|
549
|
-
|
|
550
|
-
separator = "&" if "?" in self._upstream_authorization_endpoint else "?"
|
|
551
|
-
upstream_url = f"{self._upstream_authorization_endpoint}{separator}{urlencode(query_params)}"
|
|
917
|
+
consent_url = f"{str(self.base_url).rstrip('/')}/consent?txn_id={txn_id}"
|
|
552
918
|
|
|
553
919
|
logger.debug(
|
|
554
|
-
"Starting OAuth transaction %s for client %s, redirecting to
|
|
920
|
+
"Starting OAuth transaction %s for client %s, redirecting to consent page (PKCE forwarding: %s)",
|
|
555
921
|
txn_id,
|
|
556
922
|
client.client_id,
|
|
557
923
|
"enabled" if proxy_code_challenge else "disabled",
|
|
558
924
|
)
|
|
559
|
-
return
|
|
925
|
+
return consent_url
|
|
560
926
|
|
|
561
927
|
# -------------------------------------------------------------------------
|
|
562
928
|
# Authorization Code Handling
|
|
@@ -573,22 +939,22 @@ class OAuthProxy(OAuthProvider):
|
|
|
573
939
|
with PKCE challenge for validation.
|
|
574
940
|
"""
|
|
575
941
|
# Look up client code data
|
|
576
|
-
|
|
577
|
-
if not
|
|
942
|
+
code_model = await self._code_store.get(key=authorization_code)
|
|
943
|
+
if not code_model:
|
|
578
944
|
logger.debug("Authorization code not found: %s", authorization_code)
|
|
579
945
|
return None
|
|
580
946
|
|
|
581
947
|
# Check if code expired
|
|
582
|
-
if time.time() >
|
|
948
|
+
if time.time() > code_model.expires_at:
|
|
583
949
|
logger.debug("Authorization code expired: %s", authorization_code)
|
|
584
|
-
self.
|
|
950
|
+
await self._code_store.delete(key=authorization_code)
|
|
585
951
|
return None
|
|
586
952
|
|
|
587
953
|
# Verify client ID matches
|
|
588
|
-
if
|
|
954
|
+
if code_model.client_id != client.client_id:
|
|
589
955
|
logger.debug(
|
|
590
956
|
"Authorization code client ID mismatch: %s vs %s",
|
|
591
|
-
|
|
957
|
+
code_model.client_id,
|
|
592
958
|
client.client_id,
|
|
593
959
|
)
|
|
594
960
|
return None
|
|
@@ -597,11 +963,11 @@ class OAuthProxy(OAuthProvider):
|
|
|
597
963
|
return AuthorizationCode(
|
|
598
964
|
code=authorization_code,
|
|
599
965
|
client_id=client.client_id,
|
|
600
|
-
redirect_uri=
|
|
966
|
+
redirect_uri=code_model.redirect_uri,
|
|
601
967
|
redirect_uri_provided_explicitly=True,
|
|
602
|
-
scopes=
|
|
603
|
-
expires_at=
|
|
604
|
-
code_challenge=
|
|
968
|
+
scopes=code_model.scopes,
|
|
969
|
+
expires_at=code_model.expires_at,
|
|
970
|
+
code_challenge=code_model.code_challenge or "",
|
|
605
971
|
)
|
|
606
972
|
|
|
607
973
|
async def exchange_authorization_code(
|
|
@@ -609,63 +975,166 @@ class OAuthProxy(OAuthProvider):
|
|
|
609
975
|
client: OAuthClientInformationFull,
|
|
610
976
|
authorization_code: AuthorizationCode,
|
|
611
977
|
) -> OAuthToken:
|
|
612
|
-
"""Exchange authorization code for
|
|
978
|
+
"""Exchange authorization code for FastMCP-issued tokens.
|
|
613
979
|
|
|
614
|
-
|
|
615
|
-
|
|
980
|
+
Implements the token factory pattern:
|
|
981
|
+
1. Retrieves upstream tokens from stored authorization code
|
|
982
|
+
2. Extracts user identity from upstream token
|
|
983
|
+
3. Encrypts and stores upstream tokens
|
|
984
|
+
4. Issues FastMCP-signed JWT tokens
|
|
985
|
+
5. Returns FastMCP tokens (NOT upstream tokens)
|
|
986
|
+
|
|
987
|
+
PKCE validation is handled by the MCP framework before this method is called.
|
|
616
988
|
"""
|
|
989
|
+
# Ensure JWT issuer is initialized
|
|
990
|
+
await self._ensure_jwt_initialized()
|
|
991
|
+
assert self._jwt_issuer is not None
|
|
992
|
+
assert self._token_encryption is not None
|
|
993
|
+
|
|
617
994
|
# Look up stored code data
|
|
618
|
-
|
|
619
|
-
if not
|
|
995
|
+
code_model = await self._code_store.get(key=authorization_code.code)
|
|
996
|
+
if not code_model:
|
|
620
997
|
logger.error(
|
|
621
998
|
"Authorization code not found in client codes: %s",
|
|
622
999
|
authorization_code.code,
|
|
623
1000
|
)
|
|
624
1001
|
raise TokenError("invalid_grant", "Authorization code not found")
|
|
625
1002
|
|
|
626
|
-
# Get stored
|
|
627
|
-
idp_tokens =
|
|
1003
|
+
# Get stored upstream tokens
|
|
1004
|
+
idp_tokens = code_model.idp_tokens
|
|
628
1005
|
|
|
629
1006
|
# Clean up client code (one-time use)
|
|
630
|
-
self.
|
|
1007
|
+
await self._code_store.delete(key=authorization_code.code)
|
|
631
1008
|
|
|
632
|
-
#
|
|
633
|
-
|
|
634
|
-
|
|
1009
|
+
# Generate IDs for token storage
|
|
1010
|
+
upstream_token_id = secrets.token_urlsafe(32)
|
|
1011
|
+
access_jti = secrets.token_urlsafe(32)
|
|
1012
|
+
refresh_jti = (
|
|
1013
|
+
secrets.token_urlsafe(32) if idp_tokens.get("refresh_token") else None
|
|
1014
|
+
)
|
|
1015
|
+
|
|
1016
|
+
# Calculate token expiry times
|
|
635
1017
|
expires_in = int(
|
|
636
1018
|
idp_tokens.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
|
|
637
1019
|
)
|
|
638
|
-
expires_at = int(time.time() + expires_in)
|
|
639
1020
|
|
|
640
|
-
#
|
|
641
|
-
|
|
642
|
-
|
|
1021
|
+
# Calculate refresh token expiry if provided by upstream
|
|
1022
|
+
# Some providers include refresh_expires_in, some don't
|
|
1023
|
+
refresh_expires_in = None
|
|
1024
|
+
refresh_token_expires_at = None
|
|
1025
|
+
if idp_tokens.get("refresh_token"):
|
|
1026
|
+
if "refresh_expires_in" in idp_tokens:
|
|
1027
|
+
refresh_expires_in = int(idp_tokens["refresh_expires_in"])
|
|
1028
|
+
refresh_token_expires_at = time.time() + refresh_expires_in
|
|
1029
|
+
logger.debug(
|
|
1030
|
+
"Upstream refresh token expires in %d seconds", refresh_expires_in
|
|
1031
|
+
)
|
|
1032
|
+
else:
|
|
1033
|
+
# Default to 30 days if upstream doesn't specify
|
|
1034
|
+
# This is conservative - most providers use longer expiry
|
|
1035
|
+
refresh_expires_in = 60 * 60 * 24 * 30 # 30 days
|
|
1036
|
+
refresh_token_expires_at = time.time() + refresh_expires_in
|
|
1037
|
+
logger.debug(
|
|
1038
|
+
"Upstream refresh token expiry unknown, using 30-day default"
|
|
1039
|
+
)
|
|
1040
|
+
|
|
1041
|
+
# Encrypt and store upstream tokens
|
|
1042
|
+
upstream_token_set = UpstreamTokenSet(
|
|
1043
|
+
upstream_token_id=upstream_token_id,
|
|
1044
|
+
access_token=self._token_encryption.encrypt(idp_tokens["access_token"]),
|
|
1045
|
+
refresh_token=self._token_encryption.encrypt(idp_tokens["refresh_token"])
|
|
1046
|
+
if idp_tokens.get("refresh_token")
|
|
1047
|
+
else None,
|
|
1048
|
+
refresh_token_expires_at=refresh_token_expires_at,
|
|
1049
|
+
expires_at=time.time() + expires_in,
|
|
1050
|
+
token_type=idp_tokens.get("token_type", "Bearer"),
|
|
1051
|
+
scope=" ".join(authorization_code.scopes),
|
|
1052
|
+
client_id=client.client_id,
|
|
1053
|
+
created_at=time.time(),
|
|
1054
|
+
raw_token_data=idp_tokens,
|
|
1055
|
+
)
|
|
1056
|
+
await self._upstream_token_store.put(
|
|
1057
|
+
key=upstream_token_id,
|
|
1058
|
+
value=upstream_token_set,
|
|
1059
|
+
ttl=expires_in, # Auto-expire when access token expires
|
|
1060
|
+
)
|
|
1061
|
+
logger.debug("Stored encrypted upstream tokens (jti=%s)", access_jti[:8])
|
|
1062
|
+
|
|
1063
|
+
# Issue minimal FastMCP access token (just a reference via JTI)
|
|
1064
|
+
fastmcp_access_token = self._jwt_issuer.issue_access_token(
|
|
643
1065
|
client_id=client.client_id,
|
|
644
1066
|
scopes=authorization_code.scopes,
|
|
645
|
-
|
|
1067
|
+
jti=access_jti,
|
|
1068
|
+
expires_in=expires_in,
|
|
646
1069
|
)
|
|
647
|
-
self._access_tokens[access_token_value] = access_token
|
|
648
1070
|
|
|
649
|
-
#
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
1071
|
+
# Issue minimal FastMCP refresh token if upstream provided one
|
|
1072
|
+
# Use upstream refresh token expiry to align lifetimes
|
|
1073
|
+
fastmcp_refresh_token = None
|
|
1074
|
+
if refresh_jti and refresh_expires_in:
|
|
1075
|
+
fastmcp_refresh_token = self._jwt_issuer.issue_refresh_token(
|
|
653
1076
|
client_id=client.client_id,
|
|
654
1077
|
scopes=authorization_code.scopes,
|
|
655
|
-
|
|
1078
|
+
jti=refresh_jti,
|
|
1079
|
+
expires_in=refresh_expires_in,
|
|
1080
|
+
)
|
|
1081
|
+
|
|
1082
|
+
# Store JTI mappings
|
|
1083
|
+
await self._jti_mapping_store.put(
|
|
1084
|
+
key=access_jti,
|
|
1085
|
+
value=JTIMapping(
|
|
1086
|
+
jti=access_jti,
|
|
1087
|
+
upstream_token_id=upstream_token_id,
|
|
1088
|
+
created_at=time.time(),
|
|
1089
|
+
),
|
|
1090
|
+
ttl=expires_in, # Auto-expire with access token
|
|
1091
|
+
)
|
|
1092
|
+
if refresh_jti:
|
|
1093
|
+
await self._jti_mapping_store.put(
|
|
1094
|
+
key=refresh_jti,
|
|
1095
|
+
value=JTIMapping(
|
|
1096
|
+
jti=refresh_jti,
|
|
1097
|
+
upstream_token_id=upstream_token_id,
|
|
1098
|
+
created_at=time.time(),
|
|
1099
|
+
),
|
|
1100
|
+
ttl=60 * 60 * 24 * 30, # Auto-expire with refresh token (30 days)
|
|
656
1101
|
)
|
|
657
|
-
self._refresh_tokens[refresh_token_value] = refresh_token
|
|
658
1102
|
|
|
1103
|
+
# Store FastMCP access token for MCP framework validation
|
|
1104
|
+
self._access_tokens[fastmcp_access_token] = AccessToken(
|
|
1105
|
+
token=fastmcp_access_token,
|
|
1106
|
+
client_id=client.client_id,
|
|
1107
|
+
scopes=authorization_code.scopes,
|
|
1108
|
+
expires_at=int(time.time() + expires_in),
|
|
1109
|
+
)
|
|
1110
|
+
|
|
1111
|
+
# Store FastMCP refresh token if provided
|
|
1112
|
+
if fastmcp_refresh_token:
|
|
1113
|
+
self._refresh_tokens[fastmcp_refresh_token] = RefreshToken(
|
|
1114
|
+
token=fastmcp_refresh_token,
|
|
1115
|
+
client_id=client.client_id,
|
|
1116
|
+
scopes=authorization_code.scopes,
|
|
1117
|
+
expires_at=None,
|
|
1118
|
+
)
|
|
659
1119
|
# Maintain token relationships for cleanup
|
|
660
|
-
self._access_to_refresh[
|
|
661
|
-
self._refresh_to_access[
|
|
1120
|
+
self._access_to_refresh[fastmcp_access_token] = fastmcp_refresh_token
|
|
1121
|
+
self._refresh_to_access[fastmcp_refresh_token] = fastmcp_access_token
|
|
662
1122
|
|
|
663
1123
|
logger.debug(
|
|
664
|
-
"
|
|
1124
|
+
"Issued FastMCP tokens for client=%s (access_jti=%s, refresh_jti=%s)",
|
|
665
1125
|
client.client_id,
|
|
1126
|
+
access_jti[:8],
|
|
1127
|
+
refresh_jti[:8] if refresh_jti else "none",
|
|
666
1128
|
)
|
|
667
1129
|
|
|
668
|
-
|
|
1130
|
+
# Return FastMCP-issued tokens (NOT upstream tokens!)
|
|
1131
|
+
return OAuthToken(
|
|
1132
|
+
access_token=fastmcp_access_token,
|
|
1133
|
+
token_type="Bearer",
|
|
1134
|
+
expires_in=expires_in,
|
|
1135
|
+
refresh_token=fastmcp_refresh_token,
|
|
1136
|
+
scope=" ".join(authorization_code.scopes),
|
|
1137
|
+
)
|
|
669
1138
|
|
|
670
1139
|
# -------------------------------------------------------------------------
|
|
671
1140
|
# Refresh Token Flow
|
|
@@ -685,9 +1154,54 @@ class OAuthProxy(OAuthProvider):
|
|
|
685
1154
|
refresh_token: RefreshToken,
|
|
686
1155
|
scopes: list[str],
|
|
687
1156
|
) -> OAuthToken:
|
|
688
|
-
"""Exchange refresh token for new access token
|
|
1157
|
+
"""Exchange FastMCP refresh token for new FastMCP access token.
|
|
1158
|
+
|
|
1159
|
+
Implements two-tier refresh:
|
|
1160
|
+
1. Verify FastMCP refresh token
|
|
1161
|
+
2. Look up upstream token via JTI mapping
|
|
1162
|
+
3. Refresh upstream token with upstream provider
|
|
1163
|
+
4. Update stored upstream token
|
|
1164
|
+
5. Issue new FastMCP access token
|
|
1165
|
+
6. Keep same FastMCP refresh token (unless upstream rotates)
|
|
1166
|
+
"""
|
|
1167
|
+
# Ensure JWT issuer is initialized
|
|
1168
|
+
await self._ensure_jwt_initialized()
|
|
1169
|
+
assert self._jwt_issuer is not None
|
|
1170
|
+
assert self._token_encryption is not None
|
|
1171
|
+
|
|
1172
|
+
# Verify FastMCP refresh token
|
|
1173
|
+
try:
|
|
1174
|
+
refresh_payload = self._jwt_issuer.verify_token(refresh_token.token)
|
|
1175
|
+
refresh_jti = refresh_payload["jti"]
|
|
1176
|
+
except Exception as e:
|
|
1177
|
+
logger.debug("FastMCP refresh token validation failed: %s", e)
|
|
1178
|
+
raise TokenError("invalid_grant", "Invalid refresh token") from e
|
|
1179
|
+
|
|
1180
|
+
# Look up upstream token via JTI mapping
|
|
1181
|
+
jti_mapping = await self._jti_mapping_store.get(key=refresh_jti)
|
|
1182
|
+
if not jti_mapping:
|
|
1183
|
+
logger.error("JTI mapping not found for refresh token: %s", refresh_jti[:8])
|
|
1184
|
+
raise TokenError("invalid_grant", "Refresh token mapping not found")
|
|
1185
|
+
|
|
1186
|
+
upstream_token_set = await self._upstream_token_store.get(
|
|
1187
|
+
key=jti_mapping.upstream_token_id
|
|
1188
|
+
)
|
|
1189
|
+
if not upstream_token_set:
|
|
1190
|
+
logger.error(
|
|
1191
|
+
"Upstream token set not found: %s", jti_mapping.upstream_token_id[:8]
|
|
1192
|
+
)
|
|
1193
|
+
raise TokenError("invalid_grant", "Upstream token not found")
|
|
1194
|
+
|
|
1195
|
+
# Decrypt upstream refresh token
|
|
1196
|
+
if not upstream_token_set.refresh_token:
|
|
1197
|
+
logger.error("No upstream refresh token available")
|
|
1198
|
+
raise TokenError("invalid_grant", "Refresh not supported for this token")
|
|
689
1199
|
|
|
690
|
-
|
|
1200
|
+
upstream_refresh_token = self._token_encryption.decrypt(
|
|
1201
|
+
upstream_token_set.refresh_token
|
|
1202
|
+
)
|
|
1203
|
+
|
|
1204
|
+
# Refresh upstream token using authlib
|
|
691
1205
|
oauth_client = AsyncOAuth2Client(
|
|
692
1206
|
client_id=self._upstream_client_id,
|
|
693
1207
|
client_secret=self._upstream_client_secret.get_secret_value(),
|
|
@@ -696,77 +1210,217 @@ class OAuthProxy(OAuthProvider):
|
|
|
696
1210
|
)
|
|
697
1211
|
|
|
698
1212
|
try:
|
|
699
|
-
logger.debug("
|
|
700
|
-
|
|
701
|
-
# Let authlib handle the refresh token exchange
|
|
1213
|
+
logger.debug("Refreshing upstream token (jti=%s)", refresh_jti[:8])
|
|
702
1214
|
token_response: dict[str, Any] = await oauth_client.refresh_token( # type: ignore[misc]
|
|
703
1215
|
url=self._upstream_token_endpoint,
|
|
704
|
-
refresh_token=
|
|
1216
|
+
refresh_token=upstream_refresh_token,
|
|
705
1217
|
scope=" ".join(scopes) if scopes else None,
|
|
706
1218
|
)
|
|
707
|
-
|
|
708
|
-
logger.debug(
|
|
709
|
-
"Successfully refreshed access token via authlib (client: %s)",
|
|
710
|
-
client.client_id,
|
|
711
|
-
)
|
|
712
|
-
|
|
1219
|
+
logger.debug("Successfully refreshed upstream token")
|
|
713
1220
|
except Exception as e:
|
|
714
|
-
logger.error("
|
|
715
|
-
raise TokenError(
|
|
716
|
-
"invalid_grant", f"Upstream refresh token exchange failed: {e}"
|
|
717
|
-
) from e
|
|
1221
|
+
logger.error("Upstream token refresh failed: %s", e)
|
|
1222
|
+
raise TokenError("invalid_grant", f"Upstream refresh failed: {e}") from e
|
|
718
1223
|
|
|
719
|
-
# Update
|
|
720
|
-
|
|
721
|
-
expires_in = int(
|
|
1224
|
+
# Update stored upstream token
|
|
1225
|
+
new_expires_in = int(
|
|
722
1226
|
token_response.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
|
|
723
1227
|
)
|
|
1228
|
+
upstream_token_set.access_token = self._token_encryption.encrypt(
|
|
1229
|
+
token_response["access_token"]
|
|
1230
|
+
)
|
|
1231
|
+
upstream_token_set.expires_at = time.time() + new_expires_in
|
|
1232
|
+
|
|
1233
|
+
# Handle upstream refresh token rotation and expiry
|
|
1234
|
+
new_refresh_expires_in = None
|
|
1235
|
+
if new_upstream_refresh := token_response.get("refresh_token"):
|
|
1236
|
+
if new_upstream_refresh != upstream_refresh_token:
|
|
1237
|
+
upstream_token_set.refresh_token = self._token_encryption.encrypt(
|
|
1238
|
+
new_upstream_refresh
|
|
1239
|
+
)
|
|
1240
|
+
logger.debug("Upstream refresh token rotated")
|
|
724
1241
|
|
|
725
|
-
|
|
726
|
-
|
|
1242
|
+
# Update refresh token expiry if provided
|
|
1243
|
+
if "refresh_expires_in" in token_response:
|
|
1244
|
+
new_refresh_expires_in = int(token_response["refresh_expires_in"])
|
|
1245
|
+
upstream_token_set.refresh_token_expires_at = (
|
|
1246
|
+
time.time() + new_refresh_expires_in
|
|
1247
|
+
)
|
|
1248
|
+
logger.debug(
|
|
1249
|
+
"Upstream refresh token expires in %d seconds",
|
|
1250
|
+
new_refresh_expires_in,
|
|
1251
|
+
)
|
|
1252
|
+
elif upstream_token_set.refresh_token_expires_at:
|
|
1253
|
+
# Keep existing expiry if upstream doesn't provide new one
|
|
1254
|
+
new_refresh_expires_in = int(
|
|
1255
|
+
upstream_token_set.refresh_token_expires_at - time.time()
|
|
1256
|
+
)
|
|
1257
|
+
else:
|
|
1258
|
+
# Default to 30 days if unknown
|
|
1259
|
+
new_refresh_expires_in = 60 * 60 * 24 * 30
|
|
1260
|
+
upstream_token_set.refresh_token_expires_at = (
|
|
1261
|
+
time.time() + new_refresh_expires_in
|
|
1262
|
+
)
|
|
1263
|
+
|
|
1264
|
+
upstream_token_set.raw_token_data = token_response
|
|
1265
|
+
await self._upstream_token_store.put(
|
|
1266
|
+
key=upstream_token_set.upstream_token_id,
|
|
1267
|
+
value=upstream_token_set,
|
|
1268
|
+
ttl=new_expires_in, # Auto-expire when refreshed access token expires
|
|
1269
|
+
)
|
|
1270
|
+
|
|
1271
|
+
# Issue new minimal FastMCP access token (just a reference via JTI)
|
|
1272
|
+
new_access_jti = secrets.token_urlsafe(32)
|
|
1273
|
+
new_fastmcp_access = self._jwt_issuer.issue_access_token(
|
|
727
1274
|
client_id=client.client_id,
|
|
728
1275
|
scopes=scopes,
|
|
729
|
-
|
|
1276
|
+
jti=new_access_jti,
|
|
1277
|
+
expires_in=new_expires_in,
|
|
730
1278
|
)
|
|
731
1279
|
|
|
732
|
-
#
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
1280
|
+
# Store new access token JTI mapping
|
|
1281
|
+
await self._jti_mapping_store.put(
|
|
1282
|
+
key=new_access_jti,
|
|
1283
|
+
value=JTIMapping(
|
|
1284
|
+
jti=new_access_jti,
|
|
1285
|
+
upstream_token_id=upstream_token_set.upstream_token_id,
|
|
1286
|
+
created_at=time.time(),
|
|
1287
|
+
),
|
|
1288
|
+
ttl=new_expires_in, # Auto-expire with refreshed access token
|
|
1289
|
+
)
|
|
1290
|
+
|
|
1291
|
+
# Issue NEW minimal FastMCP refresh token (rotation for security)
|
|
1292
|
+
# Use upstream refresh token expiry to align lifetimes
|
|
1293
|
+
new_refresh_jti = secrets.token_urlsafe(32)
|
|
1294
|
+
new_fastmcp_refresh = self._jwt_issuer.issue_refresh_token(
|
|
1295
|
+
client_id=client.client_id,
|
|
1296
|
+
scopes=scopes,
|
|
1297
|
+
jti=new_refresh_jti,
|
|
1298
|
+
expires_in=new_refresh_expires_in
|
|
1299
|
+
or 60 * 60 * 24 * 30, # Fallback to 30 days
|
|
1300
|
+
)
|
|
751
1301
|
|
|
752
|
-
|
|
1302
|
+
# Store new refresh token JTI mapping with aligned expiry
|
|
1303
|
+
refresh_ttl = new_refresh_expires_in or 60 * 60 * 24 * 30
|
|
1304
|
+
await self._jti_mapping_store.put(
|
|
1305
|
+
key=new_refresh_jti,
|
|
1306
|
+
value=JTIMapping(
|
|
1307
|
+
jti=new_refresh_jti,
|
|
1308
|
+
upstream_token_id=upstream_token_set.upstream_token_id,
|
|
1309
|
+
created_at=time.time(),
|
|
1310
|
+
),
|
|
1311
|
+
ttl=refresh_ttl, # Align with upstream refresh token expiry
|
|
1312
|
+
)
|
|
1313
|
+
|
|
1314
|
+
# Invalidate old refresh token (refresh token rotation - enforces one-time use)
|
|
1315
|
+
await self._jti_mapping_store.delete(key=refresh_jti)
|
|
1316
|
+
logger.debug(
|
|
1317
|
+
"Rotated refresh token (old JTI invalidated - one-time use enforced)"
|
|
1318
|
+
)
|
|
1319
|
+
|
|
1320
|
+
# Update local token tracking
|
|
1321
|
+
self._access_tokens[new_fastmcp_access] = AccessToken(
|
|
1322
|
+
token=new_fastmcp_access,
|
|
1323
|
+
client_id=client.client_id,
|
|
1324
|
+
scopes=scopes,
|
|
1325
|
+
expires_at=int(time.time() + new_expires_in),
|
|
1326
|
+
)
|
|
1327
|
+
self._refresh_tokens[new_fastmcp_refresh] = RefreshToken(
|
|
1328
|
+
token=new_fastmcp_refresh,
|
|
1329
|
+
client_id=client.client_id,
|
|
1330
|
+
scopes=scopes,
|
|
1331
|
+
expires_at=None,
|
|
1332
|
+
)
|
|
1333
|
+
|
|
1334
|
+
# Update token relationship mappings
|
|
1335
|
+
self._access_to_refresh[new_fastmcp_access] = new_fastmcp_refresh
|
|
1336
|
+
self._refresh_to_access[new_fastmcp_refresh] = new_fastmcp_access
|
|
1337
|
+
|
|
1338
|
+
# Clean up old token from in-memory tracking
|
|
1339
|
+
self._refresh_tokens.pop(refresh_token.token, None)
|
|
1340
|
+
old_access = self._refresh_to_access.pop(refresh_token.token, None)
|
|
1341
|
+
if old_access:
|
|
1342
|
+
self._access_tokens.pop(old_access, None)
|
|
1343
|
+
self._access_to_refresh.pop(old_access, None)
|
|
1344
|
+
|
|
1345
|
+
logger.info(
|
|
1346
|
+
"Issued new FastMCP tokens (rotated refresh) for client=%s (access_jti=%s, refresh_jti=%s)",
|
|
1347
|
+
client.client_id,
|
|
1348
|
+
new_access_jti[:8],
|
|
1349
|
+
new_refresh_jti[:8],
|
|
1350
|
+
)
|
|
1351
|
+
|
|
1352
|
+
# Return new FastMCP tokens (both access AND refresh are new)
|
|
1353
|
+
return OAuthToken(
|
|
1354
|
+
access_token=new_fastmcp_access,
|
|
1355
|
+
token_type="Bearer",
|
|
1356
|
+
expires_in=new_expires_in,
|
|
1357
|
+
refresh_token=new_fastmcp_refresh, # NEW refresh token (rotated)
|
|
1358
|
+
scope=" ".join(scopes),
|
|
1359
|
+
)
|
|
753
1360
|
|
|
754
1361
|
# -------------------------------------------------------------------------
|
|
755
1362
|
# Token Validation
|
|
756
1363
|
# -------------------------------------------------------------------------
|
|
757
1364
|
|
|
758
1365
|
async def load_access_token(self, token: str) -> AccessToken | None:
|
|
759
|
-
"""Validate
|
|
1366
|
+
"""Validate FastMCP JWT by swapping for upstream token.
|
|
1367
|
+
|
|
1368
|
+
This implements the token swap pattern:
|
|
1369
|
+
1. Verify FastMCP JWT signature (proves it's our token)
|
|
1370
|
+
2. Look up upstream token via JTI mapping
|
|
1371
|
+
3. Decrypt upstream token
|
|
1372
|
+
4. Validate upstream token with provider (GitHub API, JWT validation, etc.)
|
|
1373
|
+
5. Return upstream validation result
|
|
760
1374
|
|
|
761
|
-
|
|
762
|
-
|
|
1375
|
+
The FastMCP JWT is a reference token - all authorization data comes
|
|
1376
|
+
from validating the upstream token via the TokenVerifier.
|
|
763
1377
|
"""
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
1378
|
+
# Ensure JWT issuer and encryption are initialized
|
|
1379
|
+
await self._ensure_jwt_initialized()
|
|
1380
|
+
assert self._jwt_issuer is not None
|
|
1381
|
+
assert self._token_encryption is not None
|
|
1382
|
+
|
|
1383
|
+
try:
|
|
1384
|
+
# 1. Verify FastMCP JWT signature and claims
|
|
1385
|
+
payload = self._jwt_issuer.verify_token(token)
|
|
1386
|
+
jti = payload["jti"]
|
|
1387
|
+
|
|
1388
|
+
# 2. Look up upstream token via JTI mapping
|
|
1389
|
+
jti_mapping = await self._jti_mapping_store.get(key=jti)
|
|
1390
|
+
if not jti_mapping:
|
|
1391
|
+
logger.debug("JTI mapping not found: %s", jti)
|
|
1392
|
+
return None
|
|
1393
|
+
|
|
1394
|
+
upstream_token_set = await self._upstream_token_store.get(
|
|
1395
|
+
key=jti_mapping.upstream_token_id
|
|
1396
|
+
)
|
|
1397
|
+
if not upstream_token_set:
|
|
1398
|
+
logger.debug(
|
|
1399
|
+
"Upstream token not found: %s", jti_mapping.upstream_token_id
|
|
1400
|
+
)
|
|
1401
|
+
return None
|
|
1402
|
+
|
|
1403
|
+
# 3. Decrypt upstream token
|
|
1404
|
+
upstream_token = self._token_encryption.decrypt(
|
|
1405
|
+
upstream_token_set.access_token
|
|
1406
|
+
)
|
|
1407
|
+
|
|
1408
|
+
# 4. Validate with upstream provider (delegated to TokenVerifier)
|
|
1409
|
+
# This calls the real token validator (GitHub API, JWKS, etc.)
|
|
1410
|
+
validated = await self._token_validator.verify_token(upstream_token)
|
|
1411
|
+
|
|
1412
|
+
if not validated:
|
|
1413
|
+
logger.debug("Upstream token validation failed")
|
|
1414
|
+
return None
|
|
1415
|
+
|
|
1416
|
+
logger.debug(
|
|
1417
|
+
"Token swap successful for JTI=%s (upstream validated)", jti[:8]
|
|
1418
|
+
)
|
|
1419
|
+
return validated
|
|
1420
|
+
|
|
1421
|
+
except Exception as e:
|
|
1422
|
+
logger.debug("Token swap validation failed: %s", e)
|
|
1423
|
+
return None
|
|
770
1424
|
|
|
771
1425
|
# -------------------------------------------------------------------------
|
|
772
1426
|
# Token Revocation
|
|
@@ -819,7 +1473,6 @@ class OAuthProxy(OAuthProvider):
|
|
|
819
1473
|
def get_routes(
|
|
820
1474
|
self,
|
|
821
1475
|
mcp_path: str | None = None,
|
|
822
|
-
mcp_endpoint: Any | None = None,
|
|
823
1476
|
) -> list[Route]:
|
|
824
1477
|
"""Get OAuth routes with custom proxy token handler.
|
|
825
1478
|
|
|
@@ -828,10 +1481,10 @@ class OAuthProxy(OAuthProvider):
|
|
|
828
1481
|
|
|
829
1482
|
Args:
|
|
830
1483
|
mcp_path: The path where the MCP endpoint is mounted (e.g., "/mcp")
|
|
831
|
-
|
|
1484
|
+
This is used to advertise the resource URL in metadata.
|
|
832
1485
|
"""
|
|
833
1486
|
# Get standard OAuth routes from parent class
|
|
834
|
-
routes = super().get_routes(mcp_path
|
|
1487
|
+
routes = super().get_routes(mcp_path)
|
|
835
1488
|
custom_routes = []
|
|
836
1489
|
token_route_found = False
|
|
837
1490
|
|
|
@@ -844,9 +1497,7 @@ class OAuthProxy(OAuthProvider):
|
|
|
844
1497
|
f"Route {i}: {route} - path: {getattr(route, 'path', 'N/A')}, methods: {getattr(route, 'methods', 'N/A')}"
|
|
845
1498
|
)
|
|
846
1499
|
|
|
847
|
-
#
|
|
848
|
-
custom_routes.append(route)
|
|
849
|
-
|
|
1500
|
+
# Replace the token endpoint with our custom handler that returns proper OAuth 2.1 error codes
|
|
850
1501
|
if (
|
|
851
1502
|
isinstance(route, Route)
|
|
852
1503
|
and route.path == "/token"
|
|
@@ -854,6 +1505,22 @@ class OAuthProxy(OAuthProvider):
|
|
|
854
1505
|
and "POST" in route.methods
|
|
855
1506
|
):
|
|
856
1507
|
token_route_found = True
|
|
1508
|
+
# Replace with our OAuth 2.1 compliant token handler
|
|
1509
|
+
token_handler = TokenHandler(
|
|
1510
|
+
provider=self, client_authenticator=ClientAuthenticator(self)
|
|
1511
|
+
)
|
|
1512
|
+
custom_routes.append(
|
|
1513
|
+
Route(
|
|
1514
|
+
path="/token",
|
|
1515
|
+
endpoint=cors_middleware(
|
|
1516
|
+
token_handler.handle, ["POST", "OPTIONS"]
|
|
1517
|
+
),
|
|
1518
|
+
methods=["POST", "OPTIONS"],
|
|
1519
|
+
)
|
|
1520
|
+
)
|
|
1521
|
+
else:
|
|
1522
|
+
# Keep all other standard OAuth routes unchanged
|
|
1523
|
+
custom_routes.append(route)
|
|
857
1524
|
|
|
858
1525
|
# Add OAuth callback endpoint for forwarding to client callbacks
|
|
859
1526
|
custom_routes.append(
|
|
@@ -864,8 +1531,18 @@ class OAuthProxy(OAuthProvider):
|
|
|
864
1531
|
)
|
|
865
1532
|
)
|
|
866
1533
|
|
|
1534
|
+
# Add consent endpoints
|
|
1535
|
+
custom_routes.append(
|
|
1536
|
+
Route(path="/consent", endpoint=self._show_consent_page, methods=["GET"])
|
|
1537
|
+
)
|
|
1538
|
+
custom_routes.append(
|
|
1539
|
+
Route(
|
|
1540
|
+
path="/consent/submit", endpoint=self._submit_consent, methods=["POST"]
|
|
1541
|
+
)
|
|
1542
|
+
)
|
|
1543
|
+
|
|
867
1544
|
logger.debug(
|
|
868
|
-
f"✅ OAuth routes configured: token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback)"
|
|
1545
|
+
f"✅ OAuth routes configured: token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback + consent)"
|
|
869
1546
|
)
|
|
870
1547
|
return custom_routes
|
|
871
1548
|
|
|
@@ -907,13 +1584,14 @@ class OAuthProxy(OAuthProvider):
|
|
|
907
1584
|
)
|
|
908
1585
|
|
|
909
1586
|
# Look up transaction data
|
|
910
|
-
|
|
911
|
-
if not
|
|
1587
|
+
transaction_model = await self._transaction_store.get(key=txn_id)
|
|
1588
|
+
if not transaction_model:
|
|
912
1589
|
logger.error("IdP callback with invalid transaction ID: %s", txn_id)
|
|
913
1590
|
return RedirectResponse(
|
|
914
1591
|
url="data:text/html,<h1>OAuth Error</h1><p>Invalid or expired transaction</p>",
|
|
915
1592
|
status_code=302,
|
|
916
1593
|
)
|
|
1594
|
+
transaction = transaction_model.model_dump()
|
|
917
1595
|
|
|
918
1596
|
# Exchange IdP code for tokens (server-side)
|
|
919
1597
|
oauth_client = AsyncOAuth2Client(
|
|
@@ -977,19 +1655,24 @@ class OAuthProxy(OAuthProvider):
|
|
|
977
1655
|
code_expires_at = int(time.time() + DEFAULT_AUTH_CODE_EXPIRY_SECONDS)
|
|
978
1656
|
|
|
979
1657
|
# Store client code with PKCE challenge and IdP tokens
|
|
980
|
-
self.
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
1658
|
+
await self._code_store.put(
|
|
1659
|
+
key=client_code,
|
|
1660
|
+
value=ClientCode(
|
|
1661
|
+
code=client_code,
|
|
1662
|
+
client_id=transaction["client_id"],
|
|
1663
|
+
redirect_uri=transaction["client_redirect_uri"],
|
|
1664
|
+
code_challenge=transaction["code_challenge"],
|
|
1665
|
+
code_challenge_method=transaction["code_challenge_method"],
|
|
1666
|
+
scopes=transaction["scopes"],
|
|
1667
|
+
idp_tokens=idp_tokens,
|
|
1668
|
+
expires_at=code_expires_at,
|
|
1669
|
+
created_at=time.time(),
|
|
1670
|
+
),
|
|
1671
|
+
ttl=DEFAULT_AUTH_CODE_EXPIRY_SECONDS, # Auto-expire after 5 minutes
|
|
1672
|
+
)
|
|
990
1673
|
|
|
991
1674
|
# Clean up transaction
|
|
992
|
-
self.
|
|
1675
|
+
await self._transaction_store.delete(key=txn_id)
|
|
993
1676
|
|
|
994
1677
|
# Build client callback URL with our code and original state
|
|
995
1678
|
client_redirect_uri = transaction["client_redirect_uri"]
|
|
@@ -1016,3 +1699,315 @@ class OAuthProxy(OAuthProvider):
|
|
|
1016
1699
|
url="data:text/html,<h1>OAuth Error</h1><p>Internal server error during IdP callback</p>",
|
|
1017
1700
|
status_code=302,
|
|
1018
1701
|
)
|
|
1702
|
+
|
|
1703
|
+
# -------------------------------------------------------------------------
|
|
1704
|
+
# Consent Interstitial
|
|
1705
|
+
# -------------------------------------------------------------------------
|
|
1706
|
+
|
|
1707
|
+
def _normalize_uri(self, uri: str) -> str:
    """Return a canonical ``scheme://netloc/path`` form of *uri* for consent keys.

    The scheme and network location are lower-cased. The query string and
    fragment are dropped. A single trailing slash is removed from the path
    unless the path is exactly ``/``.
    """
    parts = urlparse(uri)
    scheme = parts.scheme.lower()
    netloc = parts.netloc.lower()
    path = parts.path or ""
    # Keep a bare "/" path intact; only trim the slash from longer paths.
    if len(path) > 1 and path.endswith("/"):
        path = path[:-1]
    return f"{scheme}://{netloc}{path}"
|
|
1715
|
+
|
|
1716
|
+
def _make_client_key(self, client_id: str, redirect_uri: str | AnyUrl) -> str:
    """Build the consent-tracking key for a (client_id, redirect URI) pair.

    The redirect URI is stringified and passed through ``_normalize_uri``
    before being joined to the client id with a ``:`` separator.
    """
    canonical_redirect = self._normalize_uri(str(redirect_uri))
    key = f"{client_id}:{canonical_redirect}"
    return key
|
|
1720
|
+
|
|
1721
|
+
def _cookie_name(self, base_name: str) -> str:
    """Return the prefixed cookie name to use for *base_name*.

    When the server is served over HTTPS, the ``__Host-`` prefix is used.
    Browsers only accept that prefix on Secure cookies with ``Path=/`` and
    no Domain. Over plain HTTP (e.g. local development), a ``__`` prefix is
    used instead.
    """
    prefix = "__Host-" if self._is_https else "__"
    return f"{prefix}{base_name}"
|
|
1726
|
+
|
|
1727
|
+
def _sign_cookie(self, payload: str) -> str:
    """Append an HMAC-SHA256 signature to *payload*.

    The HMAC key is the upstream client secret. The result has the form
    ``<payload>.<base64(signature)>``. *payload* itself is not re-encoded
    here, so callers are expected to pass a cookie-safe string (for
    example an already base64-encoded value).
    """
    secret = self._upstream_client_secret.get_secret_value()
    mac = hmac.new(secret.encode(), msg=payload.encode(), digestmod=hashlib.sha256)
    encoded_sig = base64.b64encode(mac.digest()).decode()
    return payload + "." + encoded_sig
|
|
1737
|
+
|
|
1738
|
+
def _verify_cookie(self, signed_value: str) -> str | None:
|
|
1739
|
+
"""Verify and extract payload from signed cookie.
|
|
1740
|
+
|
|
1741
|
+
Returns: payload if signature valid, None otherwise
|
|
1742
|
+
"""
|
|
1743
|
+
try:
|
|
1744
|
+
if "." not in signed_value:
|
|
1745
|
+
return None
|
|
1746
|
+
payload, signature_b64 = signed_value.rsplit(".", 1)
|
|
1747
|
+
|
|
1748
|
+
# Verify signature
|
|
1749
|
+
key = self._upstream_client_secret.get_secret_value().encode()
|
|
1750
|
+
expected_sig = hmac.new(key, payload.encode(), hashlib.sha256).digest()
|
|
1751
|
+
provided_sig = base64.b64decode(signature_b64.encode())
|
|
1752
|
+
|
|
1753
|
+
# Constant-time comparison
|
|
1754
|
+
if not hmac.compare_digest(expected_sig, provided_sig):
|
|
1755
|
+
return None
|
|
1756
|
+
|
|
1757
|
+
return payload
|
|
1758
|
+
except Exception:
|
|
1759
|
+
return None
|
|
1760
|
+
|
|
1761
|
+
def _decode_list_cookie(self, request: Request, base_name: str) -> list[str]:
    """Read a signed list cookie and return its items as strings.

    The cookie is looked up under the name returned by ``_cookie_name``.
    Failing that, the plain-HTTP ``__<base_name>`` variant is tried. The
    stored value must pass ``_verify_cookie`` and its payload must
    base64-decode to a JSON array.

    Any failure yields ``[]``: a missing cookie, a bad signature, an
    undecodable payload, or JSON that is not a list.
    """
    preferred_name = self._cookie_name(base_name)
    jar = request.cookies
    signed = jar.get(preferred_name) or jar.get(f"__{base_name}")
    if not signed:
        return []

    try:
        verified = self._verify_cookie(signed)
        if not verified:
            logger.debug("Cookie signature verification failed for %s", preferred_name)
            return []
        decoded = json.loads(base64.b64decode(verified.encode()).decode())
    except Exception:
        logger.debug("Failed to decode cookie %s; treating as empty", preferred_name)
        return []

    if not isinstance(decoded, list):
        return []
    return [str(item) for item in decoded]
|
|
1783
|
+
|
|
1784
|
+
def _encode_list_cookie(self, values: list[str]) -> str:
    """Serialize *values* into a signed cookie value.

    The list is dumped as compact JSON, base64-encoded, and then signed
    with ``_sign_cookie``, giving ``<base64(json)>.<signature>``.
    """
    compact_json = json.dumps(values, separators=(",", ":"))
    encoded = base64.b64encode(compact_json.encode()).decode()
    return self._sign_cookie(encoded)
|
|
1792
|
+
|
|
1793
|
+
def _set_list_cookie(
    self,
    response: HTMLResponse | RedirectResponse,
    base_name: str,
    value_b64: str,
    max_age: int,
) -> None:
    """Attach a signed list cookie to *response*.

    Args:
        response: Response that will carry the ``Set-Cookie`` header.
        base_name: Unprefixed cookie name; the actual name comes from
            ``_cookie_name``.
        value_b64: Already encoded and signed cookie value (for example
            the output of ``_encode_list_cookie``).
        max_age: Cookie lifetime in seconds.

    The cookie is HttpOnly, ``SameSite=Lax`` and scoped to ``/``. It is
    marked Secure only when ``self._is_https`` is true.
    """
    cookie_options = {
        "max_age": max_age,
        "secure": self._is_https,
        "httponly": True,
        "samesite": "lax",
        "path": "/",
    }
    response.set_cookie(self._cookie_name(base_name), value_b64, **cookie_options)
|
|
1810
|
+
|
|
1811
|
+
def _build_upstream_authorize_url(
    self, txn_id: str, transaction: dict[str, Any]
) -> str:
    """Build the upstream IdP authorization URL for a stored transaction.

    The query always carries ``response_type=code``, the upstream client
    id, the proxy's callback URL and ``state=txn_id``. Optional parameters
    are added in this order:

    - ``scope``: the transaction's scopes, else ``self.required_scopes``,
      space-joined.
    - ``code_challenge`` / ``code_challenge_method=S256``: only when the
      transaction holds a ``proxy_code_verifier``.
    - ``resource``: only when present on the transaction.
    - ``self._extra_authorize_params``: added last, so it may override
      any of the above.
    """
    params: dict[str, Any] = {
        "response_type": "code",
        "client_id": self._upstream_client_id,
        "redirect_uri": f"{str(self.base_url).rstrip('/')}{self._redirect_path}",
        "state": txn_id,
    }

    scopes = transaction.get("scopes") or self.required_scopes or []
    if scopes:
        params["scope"] = " ".join(scopes)

    verifier = transaction.get("proxy_code_verifier")
    if verifier:
        # S256 challenge: unpadded URL-safe base64 of SHA-256(verifier).
        digest = hashlib.sha256(verifier.encode()).digest()
        params.update(
            code_challenge=urlsafe_b64encode(digest).decode().rstrip("="),
            code_challenge_method="S256",
        )

    resource = transaction.get("resource")
    if resource:
        params["resource"] = resource

    if self._extra_authorize_params:
        params.update(self._extra_authorize_params)

    endpoint = self._upstream_authorization_endpoint
    joiner = "&" if "?" in endpoint else "?"
    return f"{endpoint}{joiner}{urlencode(params)}"
|
|
1846
|
+
|
|
1847
|
+
async def _show_consent_page(
    self, request: Request
) -> HTMLResponse | RedirectResponse:
    """Render the consent interstitial for a pending authorization.

    Reads ``txn_id`` from the query string and loads the matching
    transaction. A missing or unknown ``txn_id`` yields a 400 HTML error
    page. Otherwise, the (client_id, redirect URI) key is looked up in this
    browser's signed consent cookies:

    - Listed in ``MCP_APPROVED_CLIENTS``: redirect straight to the upstream
      authorization URL.
    - Listed in ``MCP_DENIED_CLIENTS``: redirect back to the client's
      redirect URI with ``error=access_denied`` and the client's state.
    - Otherwise: store a fresh CSRF token with a 15-minute lifetime on the
      transaction (re-putting it with a 15-minute TTL), echo the token into
      the ``MCP_CONSENT_STATE`` cookie, and return the consent HTML page.
    """
    from fastmcp.server.server import FastMCP

    # Seconds: CSRF token lifetime, transaction TTL and consent cookie max-age.
    consent_window = 15 * 60

    def _bad_transaction():
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
        )

    txn_id = request.query_params.get("txn_id")
    if not txn_id:
        return _bad_transaction()
    txn_model = await self._transaction_store.get(key=txn_id)
    if not txn_model:
        return _bad_transaction()

    txn = txn_model.model_dump()
    redirect_uri = txn["client_redirect_uri"]
    client_key = self._make_client_key(txn["client_id"], redirect_uri)

    approved = set(self._decode_list_cookie(request, "MCP_APPROVED_CLIENTS"))
    denied = set(self._decode_list_cookie(request, "MCP_DENIED_CLIENTS"))

    if client_key in approved:
        return RedirectResponse(
            url=self._build_upstream_authorize_url(txn_id, txn), status_code=302
        )

    if client_key in denied:
        denial_query = urlencode(
            {"error": "access_denied", "state": txn.get("client_state") or ""}
        )
        joiner = "&" if "?" in redirect_uri else "?"
        return RedirectResponse(
            url=f"{redirect_uri}{joiner}{denial_query}", status_code=302
        )

    # No remembered decision: issue a CSRF token, persist it on the
    # transaction (refreshing its TTL), then render the consent form.
    csrf_token = secrets.token_urlsafe(32)
    txn_model.csrf_token = csrf_token
    txn_model.csrf_expires_at = time.time() + consent_window
    await self._transaction_store.put(
        key=txn_id, value=txn_model, ttl=consent_window
    )

    client = await self.get_client(txn["client_id"])
    client_name = getattr(client, "client_name", None) if client else None

    # Optional branding taken from the FastMCP server attached to the app.
    server = getattr(request.app.state, "fastmcp_server", None)
    server_name = server_icon_url = server_website_url = None
    if isinstance(server, FastMCP):
        server_name = server.name
        icons = server.icons
        server_icon_url = icons[0].src if icons else None
        server_website_url = server.website_url

    html = create_consent_html(
        client_id=txn["client_id"],
        redirect_uri=redirect_uri,
        scopes=txn.get("scopes") or [],
        txn_id=txn_id,
        csrf_token=csrf_token,
        client_name=client_name,
        server_name=server_name,
        server_icon_url=server_icon_url,
        server_website_url=server_website_url,
    )
    response = create_secure_html_response(html)
    # The same token also goes into a short-lived signed cookie.
    self._set_list_cookie(
        response,
        "MCP_CONSENT_STATE",
        self._encode_list_cookie([csrf_token]),
        max_age=consent_window,
    )
    return response
|
|
1938
|
+
|
|
1939
|
+
async def _submit_consent(
    self, request: Request
) -> RedirectResponse | HTMLResponse:
    """Handle the consent form POST: validate it, remember the choice, redirect.

    Validation requirements:

    - The form must carry a ``txn_id`` naming a live transaction.
    - The form's ``csrf_token`` must match the token stored on that
      transaction, and that token must not have expired.
    - The same token must also appear in this browser's signed
      ``MCP_CONSENT_STATE`` cookie, which ``_show_consent_page`` sets.

    The cookie check (double-submit) binds the submission to the browser
    that rendered the consent page. Without it, anyone who obtains a
    txn_id/csrf_token pair could approve on a victim's behalf with a
    cross-site form post.

    Actions:

    - ``"approve"``: add the client key to ``MCP_APPROVED_CLIENTS`` and
      redirect to the upstream authorization URL.
    - ``"deny"``: add the client key to ``MCP_DENIED_CLIENTS`` and
      redirect to the client's redirect URI with ``error=access_denied``.

    In both cases the consent-state cookie is overwritten with an empty
    list. Any validation failure or unknown action yields a 400 HTML
    error page.
    """
    form = await request.form()
    txn_id = str(form.get("txn_id", ""))
    action = str(form.get("action", ""))
    csrf_token = str(form.get("csrf_token", ""))

    if not txn_id:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
        )

    txn_model = await self._transaction_store.get(key=txn_id)
    if not txn_model:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
        )

    txn = txn_model.model_dump()
    expected_csrf = txn.get("csrf_token")
    expires_at = float(txn.get("csrf_expires_at") or 0)

    def _token_matches(candidate: object) -> bool:
        # Constant-time comparison. Compare bytes because compare_digest
        # rejects non-ASCII str arguments, and the form value is
        # attacker-controlled.
        return hmac.compare_digest(csrf_token.encode(), str(candidate).encode())

    # Double-submit check: the token must also be present in the signed
    # consent-state cookie that was set on *this* browser with the form.
    cookie_tokens = self._decode_list_cookie(request, "MCP_CONSENT_STATE")

    if (
        not expected_csrf
        or time.time() > expires_at
        or not _token_matches(expected_csrf)
        or not any(_token_matches(token) for token in cookie_tokens)
    ):
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired consent token</p>", status_code=400
        )

    client_key = self._make_client_key(txn["client_id"], txn["client_redirect_uri"])

    if action == "approve":
        approved = set(self._decode_list_cookie(request, "MCP_APPROVED_CLIENTS"))
        if client_key not in approved:
            approved.add(client_key)
        approved_b64 = self._encode_list_cookie(sorted(approved))

        upstream_url = self._build_upstream_authorize_url(txn_id, txn)
        response = RedirectResponse(url=upstream_url, status_code=302)
        self._set_list_cookie(
            response, "MCP_APPROVED_CLIENTS", approved_b64, max_age=365 * 24 * 3600
        )
        # Clear the CSRF cookie so the token cannot be replayed from this browser.
        self._set_list_cookie(
            response, "MCP_CONSENT_STATE", self._encode_list_cookie([]), max_age=60
        )
        return response

    elif action == "deny":
        denied = set(self._decode_list_cookie(request, "MCP_DENIED_CLIENTS"))
        if client_key not in denied:
            denied.add(client_key)
        denied_b64 = self._encode_list_cookie(sorted(denied))

        callback_params = {
            "error": "access_denied",
            "state": txn.get("client_state") or "",
        }
        sep = "&" if "?" in txn["client_redirect_uri"] else "?"
        client_callback_url = (
            f"{txn['client_redirect_uri']}{sep}{urlencode(callback_params)}"
        )
        response = RedirectResponse(url=client_callback_url, status_code=302)
        self._set_list_cookie(
            response, "MCP_DENIED_CLIENTS", denied_b64, max_age=365 * 24 * 3600
        )
        self._set_list_cookie(
            response, "MCP_CONSENT_STATE", self._encode_list_cookie([]), max_age=60
        )
        return response

    else:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid action</p>", status_code=400
        )
|