fastmcp 2.12.4__py3-none-any.whl → 2.13.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastmcp/cli/cli.py +6 -6
- fastmcp/cli/install/claude_code.py +3 -3
- fastmcp/cli/install/claude_desktop.py +3 -3
- fastmcp/cli/install/cursor.py +7 -7
- fastmcp/cli/install/gemini_cli.py +3 -3
- fastmcp/cli/install/mcp_json.py +3 -3
- fastmcp/cli/run.py +13 -8
- fastmcp/client/auth/oauth.py +100 -208
- fastmcp/client/client.py +11 -11
- fastmcp/client/logging.py +18 -14
- fastmcp/client/oauth_callback.py +81 -171
- fastmcp/client/transports.py +76 -22
- fastmcp/contrib/component_manager/component_service.py +6 -6
- fastmcp/contrib/mcp_mixin/README.md +32 -1
- fastmcp/contrib/mcp_mixin/mcp_mixin.py +14 -2
- fastmcp/experimental/utilities/openapi/json_schema_converter.py +4 -0
- fastmcp/experimental/utilities/openapi/parser.py +23 -3
- fastmcp/prompts/prompt.py +13 -6
- fastmcp/prompts/prompt_manager.py +16 -101
- fastmcp/resources/resource.py +13 -6
- fastmcp/resources/resource_manager.py +5 -164
- fastmcp/resources/template.py +107 -17
- fastmcp/server/auth/auth.py +40 -32
- fastmcp/server/auth/jwt_issuer.py +289 -0
- fastmcp/server/auth/oauth_proxy.py +1238 -234
- fastmcp/server/auth/oidc_proxy.py +8 -6
- fastmcp/server/auth/providers/auth0.py +12 -6
- fastmcp/server/auth/providers/aws.py +13 -2
- fastmcp/server/auth/providers/azure.py +137 -124
- fastmcp/server/auth/providers/descope.py +4 -6
- fastmcp/server/auth/providers/github.py +13 -7
- fastmcp/server/auth/providers/google.py +13 -7
- fastmcp/server/auth/providers/introspection.py +281 -0
- fastmcp/server/auth/providers/jwt.py +8 -2
- fastmcp/server/auth/providers/scalekit.py +179 -0
- fastmcp/server/auth/providers/supabase.py +172 -0
- fastmcp/server/auth/providers/workos.py +16 -13
- fastmcp/server/context.py +89 -34
- fastmcp/server/http.py +53 -16
- fastmcp/server/low_level.py +121 -2
- fastmcp/server/middleware/caching.py +469 -0
- fastmcp/server/middleware/error_handling.py +6 -2
- fastmcp/server/middleware/logging.py +48 -37
- fastmcp/server/middleware/middleware.py +28 -15
- fastmcp/server/middleware/rate_limiting.py +3 -3
- fastmcp/server/proxy.py +6 -6
- fastmcp/server/server.py +638 -183
- fastmcp/settings.py +22 -9
- fastmcp/tools/tool.py +7 -3
- fastmcp/tools/tool_manager.py +22 -108
- fastmcp/tools/tool_transform.py +3 -3
- fastmcp/utilities/cli.py +2 -2
- fastmcp/utilities/components.py +5 -0
- fastmcp/utilities/inspect.py +77 -21
- fastmcp/utilities/logging.py +118 -8
- fastmcp/utilities/mcp_server_config/v1/environments/uv.py +6 -6
- fastmcp/utilities/mcp_server_config/v1/mcp_server_config.py +3 -3
- fastmcp/utilities/mcp_server_config/v1/schema.json +3 -0
- fastmcp/utilities/tests.py +87 -4
- fastmcp/utilities/types.py +1 -1
- fastmcp/utilities/ui.py +497 -0
- {fastmcp-2.12.4.dist-info → fastmcp-2.13.0rc1.dist-info}/METADATA +8 -4
- {fastmcp-2.12.4.dist-info → fastmcp-2.13.0rc1.dist-info}/RECORD +66 -62
- fastmcp/cli/claude.py +0 -135
- fastmcp/utilities/storage.py +0 -204
- {fastmcp-2.12.4.dist-info → fastmcp-2.13.0rc1.dist-info}/WHEEL +0 -0
- {fastmcp-2.12.4.dist-info → fastmcp-2.13.0rc1.dist-info}/entry_points.txt +0 -0
- {fastmcp-2.12.4.dist-info → fastmcp-2.13.0rc1.dist-info}/licenses/LICENSE +0 -0
|
@@ -18,16 +18,26 @@ production use with enterprise identity providers.
|
|
|
18
18
|
|
|
19
19
|
from __future__ import annotations
|
|
20
20
|
|
|
21
|
+
import base64
|
|
21
22
|
import hashlib
|
|
23
|
+
import hmac
|
|
24
|
+
import json
|
|
22
25
|
import secrets
|
|
23
26
|
import time
|
|
24
27
|
from base64 import urlsafe_b64encode
|
|
25
28
|
from typing import TYPE_CHECKING, Any, Final
|
|
26
|
-
from urllib.parse import urlencode
|
|
29
|
+
from urllib.parse import urlencode, urlparse
|
|
27
30
|
|
|
28
31
|
import httpx
|
|
29
32
|
from authlib.common.security import generate_token
|
|
30
33
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
34
|
+
from key_value.aio.adapters.pydantic import PydanticAdapter
|
|
35
|
+
from key_value.aio.protocols import AsyncKeyValue
|
|
36
|
+
from key_value.aio.stores.memory import MemoryStore
|
|
37
|
+
from mcp.server.auth.handlers.token import TokenErrorResponse, TokenSuccessResponse
|
|
38
|
+
from mcp.server.auth.handlers.token import TokenHandler as _SDKTokenHandler
|
|
39
|
+
from mcp.server.auth.json_response import PydanticJSONResponse
|
|
40
|
+
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
|
|
31
41
|
from mcp.server.auth.provider import (
|
|
32
42
|
AccessToken,
|
|
33
43
|
AuthorizationCode,
|
|
@@ -35,21 +45,36 @@ from mcp.server.auth.provider import (
|
|
|
35
45
|
RefreshToken,
|
|
36
46
|
TokenError,
|
|
37
47
|
)
|
|
48
|
+
from mcp.server.auth.routes import cors_middleware
|
|
38
49
|
from mcp.server.auth.settings import (
|
|
39
50
|
ClientRegistrationOptions,
|
|
40
51
|
RevocationOptions,
|
|
41
52
|
)
|
|
42
53
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
|
43
|
-
from pydantic import AnyHttpUrl, AnyUrl, SecretStr
|
|
54
|
+
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, SecretStr
|
|
44
55
|
from starlette.requests import Request
|
|
45
|
-
from starlette.responses import RedirectResponse
|
|
56
|
+
from starlette.responses import HTMLResponse, RedirectResponse
|
|
46
57
|
from starlette.routing import Route
|
|
47
58
|
|
|
48
|
-
import fastmcp
|
|
49
59
|
from fastmcp.server.auth.auth import OAuthProvider, TokenVerifier
|
|
50
|
-
from fastmcp.server.auth.
|
|
60
|
+
from fastmcp.server.auth.jwt_issuer import (
|
|
61
|
+
JWTIssuer,
|
|
62
|
+
TokenEncryption,
|
|
63
|
+
)
|
|
64
|
+
from fastmcp.server.auth.redirect_validation import (
|
|
65
|
+
validate_redirect_uri,
|
|
66
|
+
)
|
|
51
67
|
from fastmcp.utilities.logging import get_logger
|
|
52
|
-
from fastmcp.utilities.
|
|
68
|
+
from fastmcp.utilities.ui import (
|
|
69
|
+
BUTTON_STYLES,
|
|
70
|
+
DETAIL_BOX_STYLES,
|
|
71
|
+
INFO_BOX_STYLES,
|
|
72
|
+
TOOLTIP_STYLES,
|
|
73
|
+
create_detail_box,
|
|
74
|
+
create_logo,
|
|
75
|
+
create_page,
|
|
76
|
+
create_secure_html_response,
|
|
77
|
+
)
|
|
53
78
|
|
|
54
79
|
if TYPE_CHECKING:
|
|
55
80
|
pass
|
|
@@ -57,6 +82,95 @@ if TYPE_CHECKING:
|
|
|
57
82
|
logger = get_logger(__name__)
|
|
58
83
|
|
|
59
84
|
|
|
85
|
+
# -------------------------------------------------------------------------
# Constants
# -------------------------------------------------------------------------

# Default lifetimes, in seconds.
DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS: Final[int] = 3_600  # one hour
DEFAULT_AUTH_CODE_EXPIRY_SECONDS: Final[int] = 300  # five minutes

# Timeout, in seconds, for HTTP client requests.
HTTP_TIMEOUT_SECONDS: Final[int] = 30
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# -------------------------------------------------------------------------
|
|
98
|
+
# Pydantic Models
|
|
99
|
+
# -------------------------------------------------------------------------
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class OAuthTransaction(BaseModel):
    """Server-side record of one in-progress OAuth authorization flow.

    Holds the MCP client's request context for the flow, plus the CSRF token
    (and its expiry) used to protect the consent step, following MCP security
    best practices.
    """

    # Identity of this transaction and the requesting MCP client.
    txn_id: str
    client_id: str
    client_redirect_uri: str
    client_state: str

    # Client PKCE parameters; the challenge is required but may be None.
    code_challenge: str | None
    code_challenge_method: str

    scopes: list[str]
    created_at: float  # NOTE(review): presumably a Unix timestamp - confirm with writer

    # Optional fields, all defaulting to None.
    resource: str | None = Field(default=None)
    proxy_code_verifier: str | None = Field(default=None)
    csrf_token: str | None = Field(default=None)
    csrf_expires_at: float | None = Field(default=None)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class ClientCode(BaseModel):
    """Authorization code issued to an MCP client after the upstream callback.

    Stored server-side once the upstream IdP has called back. It carries the
    upstream tokens, bound to the client's PKCE challenge for the subsequent
    token exchange.
    """

    code: str  # the authorization code value itself
    client_id: str  # MCP client the code was issued to
    redirect_uri: str  # redirect URI associated with the code
    code_challenge: str | None  # client's PKCE challenge; may be None
    code_challenge_method: str  # PKCE method accompanying code_challenge
    scopes: list[str]  # scopes associated with the code
    idp_tokens: dict[str, Any]  # tokens obtained from the upstream IdP
    expires_at: float  # NOTE(review): presumably Unix time - confirm with writer
    created_at: float  # NOTE(review): presumably Unix time - confirm with writer
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class UpstreamTokenSet(BaseModel):
    """Tokens obtained from the upstream identity provider.

    These come from the upstream provider (Google, GitHub, etc.). By design
    they are stored encrypted at rest and never exposed to MCP clients, which
    is why the token fields are ``bytes`` (ciphertext) rather than ``str``.
    """

    upstream_token_id: str  # unique ID for this token set
    access_token: bytes  # encrypted upstream access token
    refresh_token: bytes | None  # encrypted upstream refresh token, if any
    refresh_token_expires_at: float | None  # Unix time the refresh token expires, if known
    expires_at: float  # Unix time the access token expires
    token_type: str  # usually "Bearer"
    scope: str  # space-separated scopes
    client_id: str  # MCP client this token set is bound to
    created_at: float  # Unix timestamp of creation
    raw_token_data: dict[str, Any] = Field(default_factory=dict)  # full upstream token response
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class JTIMapping(BaseModel):
    """Link from a FastMCP-issued JWT's ``jti`` to an upstream token set.

    FastMCP JWTs can be validated statelessly. This mapping still allows the
    corresponding ``UpstreamTokenSet`` to be found when a tool needs to call
    upstream APIs.
    """

    jti: str  # JWT ID of the FastMCP-issued token
    upstream_token_id: str  # references UpstreamTokenSet.upstream_token_id
    created_at: float  # Unix timestamp of creation
|
|
172
|
+
|
|
173
|
+
|
|
60
174
|
class ProxyDCRClient(OAuthClientInformationFull):
|
|
61
175
|
"""Client for DCR proxy with configurable redirect URI validation.
|
|
62
176
|
|
|
@@ -83,18 +197,8 @@ class ProxyDCRClient(OAuthClientInformationFull):
|
|
|
83
197
|
arise from accepting arbitrary redirect URIs.
|
|
84
198
|
"""
|
|
85
199
|
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
):
|
|
89
|
-
"""Initialize with allowed redirect URI patterns.
|
|
90
|
-
|
|
91
|
-
Args:
|
|
92
|
-
allowed_redirect_uri_patterns: List of allowed redirect URI patterns with wildcard support.
|
|
93
|
-
If None, defaults to localhost-only patterns.
|
|
94
|
-
If empty list, allows all redirect URIs.
|
|
95
|
-
"""
|
|
96
|
-
super().__init__(*args, **kwargs)
|
|
97
|
-
self._allowed_redirect_uri_patterns = allowed_redirect_uri_patterns
|
|
200
|
+
allowed_redirect_uri_patterns: list[str] | None = Field(default=None)
|
|
201
|
+
client_name: str | None = Field(default=None)
|
|
98
202
|
|
|
99
203
|
def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
|
|
100
204
|
"""Validate redirect URI against allowed patterns.
|
|
@@ -106,7 +210,10 @@ class ProxyDCRClient(OAuthClientInformationFull):
|
|
|
106
210
|
"""
|
|
107
211
|
if redirect_uri is not None:
|
|
108
212
|
# Validate against allowed patterns
|
|
109
|
-
if validate_redirect_uri(
|
|
213
|
+
if validate_redirect_uri(
|
|
214
|
+
redirect_uri=redirect_uri,
|
|
215
|
+
allowed_patterns=self.allowed_redirect_uri_patterns,
|
|
216
|
+
):
|
|
110
217
|
return redirect_uri
|
|
111
218
|
# Fall back to normal validation if not in allowed patterns
|
|
112
219
|
return super().validate_redirect_uri(redirect_uri)
|
|
@@ -114,12 +221,173 @@ class ProxyDCRClient(OAuthClientInformationFull):
|
|
|
114
221
|
return super().validate_redirect_uri(redirect_uri)
|
|
115
222
|
|
|
116
223
|
|
|
117
|
-
#
|
|
118
|
-
|
|
119
|
-
|
|
224
|
+
# -------------------------------------------------------------------------
|
|
225
|
+
# Helper Functions
|
|
226
|
+
# -------------------------------------------------------------------------
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def create_consent_html(
    client_id: str,
    redirect_uri: str,
    scopes: list[str],
    txn_id: str,
    csrf_token: str,
    client_name: str | None = None,
    title: str = "Authorization Consent",
    server_name: str | None = None,
    server_icon_url: str | None = None,
    server_website_url: str | None = None,
) -> str:
    """Create a styled HTML consent page for OAuth authorization requests.

    Args:
        client_id: ID of the MCP client requesting access.
        redirect_uri: Redirect URI the client registered for this request.
        scopes: Requested scopes; rendered as "None" when empty.
        txn_id: Transaction ID, submitted back via a hidden form field.
        csrf_token: CSRF token, submitted back via a hidden form field.
        client_name: Optional human-readable client name; shown in place of
            the client ID in the warning banner when provided.
        title: Page title passed to ``create_page``.
        server_name: Display name of this server (defaults to "FastMCP").
        server_icon_url: Optional logo URL passed to ``create_logo``.
        server_website_url: When set, the server name is rendered as a link.

    Returns:
        The full HTML document produced by ``create_page``.
    """
    import html as html_module

    scopes_display = ", ".join(scopes) if scopes else "None"

    client_display = html_module.escape(client_name or client_id)
    server_name_escaped = html_module.escape(server_name or "FastMCP")

    # Escape the hidden-field values as well. They are normally
    # server-generated tokens, but escaping guarantees an unexpected
    # character can never break out of the attribute value.
    txn_id_escaped = html_module.escape(txn_id, quote=True)
    csrf_token_escaped = html_module.escape(csrf_token, quote=True)

    # Make the server name a hyperlink if a website URL is available.
    if server_website_url:
        website_url_escaped = html_module.escape(server_website_url)
        server_display = f'<a href="{website_url_escaped}" target="_blank" rel="noopener noreferrer">{server_name_escaped}</a>'
    else:
        server_display = server_name_escaped

    warning_box = f"""
    <div class="warning-box">
        <p><strong>{client_display}</strong> is requesting access to <strong>{server_display}</strong>.</p>
        <p>Review the details below before approving.</p>
    </div>
    """

    # NOTE(review): these values are passed unescaped; this assumes
    # create_detail_box escapes its inputs - confirm in fastmcp.utilities.ui.
    detail_rows = []
    if client_name:
        detail_rows.append(("Client Name", client_name))
    detail_rows.extend(
        [
            ("Client ID", client_id),
            ("Redirect URI", redirect_uri),
            ("Requested Scopes", scopes_display),
        ]
    )
    detail_box = create_detail_box(detail_rows)

    form = f"""
    <form id="consentForm" method="POST" action="/consent/submit">
        <input type="hidden" name="txn_id" value="{txn_id_escaped}" />
        <input type="hidden" name="csrf_token" value="{csrf_token_escaped}" />
        <div class="button-group">
            <button type="submit" name="action" value="approve" class="btn-approve">Approve</button>
            <button type="submit" name="action" value="deny" class="btn-deny">Deny</button>
        </div>
    </form>
    """

    help_link = """
    <div class="help-link-container">
        <span class="help-link">
            Why am I seeing this?
            <span class="tooltip">
                This FastMCP server requires your consent to allow a new client
                to connect. This protects you from <a
                href="https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#confused-deputy-problem"
                target="_blank" class="tooltip-link">confused deputy
                attacks</a>, where malicious clients could impersonate you
                and steal access.<br><br>
                <a
                href="https://gofastmcp.com/servers/auth/oauth-proxy#confused-deputy-attacks"
                target="_blank" class="tooltip-link">Learn more about
                FastMCP security →</a>
            </span>
        </span>
    </div>
    """

    content = f"""
    <div class="container">
        {create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")}
        <h1>Authorization Consent</h1>
        {warning_box}
        {detail_box}
        {form}
    </div>
    {help_link}
    """

    additional_styles = (
        INFO_BOX_STYLES + DETAIL_BOX_STYLES + BUTTON_STYLES + TOOLTIP_STYLES
    )

    # form-action must permit the consent form submission.
    csp_policy = "default-src 'none'; style-src 'unsafe-inline'; img-src https:; base-uri 'none'; form-action *"

    return create_page(
        content=content,
        title=title,
        additional_styles=additional_styles,
        csp_policy=csp_policy,
    )
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
# -------------------------------------------------------------------------
|
|
340
|
+
# Handler Classes
|
|
341
|
+
# -------------------------------------------------------------------------
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
class TokenHandler(_SDKTokenHandler):
    """Token endpoint handler that reports unknown-client failures as 401.

    The MCP SDK's handler answers every client-authentication problem with
    HTTP 400 and the ``unauthorized_client`` error code. This subclass
    rewrites one specific case: an ``unauthorized_client`` error whose
    description contains "Invalid client_id". That case is reported as
    ``invalid_client`` with HTTP 401.

    RFC 6749 §5.2 defines ``invalid_client`` as "Client authentication
    failed" and allows the server to answer it with HTTP 401. An
    ``unauthorized_client`` error that is about the grant type is left
    untouched and keeps the SDK's 400 response.
    """

    def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
        """Return the HTTP response for *obj*, rewriting unknown-client errors."""
        # Anything other than an "Invalid client_id" unauthorized_client error
        # is delegated unchanged to the SDK implementation.
        if not (
            isinstance(obj, TokenErrorResponse)
            and obj.error == "unauthorized_client"
            and obj.error_description
            and "Invalid client_id" in obj.error_description
        ):
            return super().response(obj)

        rewritten_error = TokenErrorResponse(
            error="invalid_client",
            error_description=obj.error_description,
            error_uri=obj.error_uri,
        )
        no_cache_headers = {
            "Cache-Control": "no-store",
            "Pragma": "no-cache",
        }
        return PydanticJSONResponse(
            content=rewritten_error,
            status_code=401,
            headers=no_cache_headers,
        )
|
|
123
391
|
|
|
124
392
|
|
|
125
393
|
class OAuthProxy(OAuthProvider):
|
|
@@ -201,7 +469,6 @@ class OAuthProxy(OAuthProvider):
|
|
|
201
469
|
State Management
|
|
202
470
|
---------------
|
|
203
471
|
The proxy maintains minimal but crucial state:
|
|
204
|
-
- _clients: DCR registrations (all use ProxyDCRClient for flexibility)
|
|
205
472
|
- _oauth_transactions: Active authorization flows with client context
|
|
206
473
|
- _client_codes: Authorization codes with PKCE challenges and upstream tokens
|
|
207
474
|
- _access_tokens, _refresh_tokens: Token storage for revocation
|
|
@@ -257,7 +524,11 @@ class OAuthProxy(OAuthProvider):
|
|
|
257
524
|
# Extra parameters to forward to token endpoint
|
|
258
525
|
extra_token_params: dict[str, str] | None = None,
|
|
259
526
|
# Client storage
|
|
260
|
-
client_storage:
|
|
527
|
+
client_storage: AsyncKeyValue | None = None,
|
|
528
|
+
# JWT signing key (optional, ephemeral if not provided)
|
|
529
|
+
jwt_signing_key: str | bytes | None = None,
|
|
530
|
+
# Token encryption key (optional, ephemeral if not provided)
|
|
531
|
+
token_encryption_key: str | bytes | None = None,
|
|
261
532
|
):
|
|
262
533
|
"""Initialize the OAuth proxy provider.
|
|
263
534
|
|
|
@@ -291,9 +562,13 @@ class OAuthProxy(OAuthProvider):
|
|
|
291
562
|
Example: {"audience": "https://api.example.com"}
|
|
292
563
|
extra_token_params: Additional parameters to forward to the upstream token endpoint.
|
|
293
564
|
Useful for provider-specific parameters during token exchange.
|
|
294
|
-
client_storage:
|
|
295
|
-
|
|
296
|
-
|
|
565
|
+
client_storage: An AsyncKeyValue-compatible store for client registrations, registrations are stored in memory if not provided
|
|
566
|
+
jwt_signing_key: Optional secret for signing FastMCP JWT tokens (accepts any string or bytes).
|
|
567
|
+
Default: ephemeral (random salt at startup, won't survive restart).
|
|
568
|
+
Production: provide explicit key from environment variable.
|
|
569
|
+
token_encryption_key: Optional secret for encrypting upstream tokens at rest (accepts any string or bytes).
|
|
570
|
+
Default: ephemeral (random salt at startup, won't survive restart).
|
|
571
|
+
Production: provide explicit key from environment variable.
|
|
297
572
|
"""
|
|
298
573
|
# Always enable DCR since we implement it locally for MCP clients
|
|
299
574
|
client_registration_options = ClientRegistrationOptions(
|
|
@@ -330,7 +605,25 @@ class OAuthProxy(OAuthProvider):
|
|
|
330
605
|
self._redirect_path = (
|
|
331
606
|
redirect_path if redirect_path.startswith("/") else f"/{redirect_path}"
|
|
332
607
|
)
|
|
333
|
-
|
|
608
|
+
# Redirect URI validation (consent flow provides primary protection)
|
|
609
|
+
if allowed_client_redirect_uris is None:
|
|
610
|
+
logger.info(
|
|
611
|
+
"allowed_client_redirect_uris not specified; accepting all redirect URIs. "
|
|
612
|
+
"Consent flow provides protection against confused deputy attacks. "
|
|
613
|
+
"Configure allowed patterns for defense-in-depth."
|
|
614
|
+
)
|
|
615
|
+
self._allowed_client_redirect_uris = None
|
|
616
|
+
elif (
|
|
617
|
+
isinstance(allowed_client_redirect_uris, list)
|
|
618
|
+
and not allowed_client_redirect_uris
|
|
619
|
+
):
|
|
620
|
+
logger.warning(
|
|
621
|
+
"allowed_client_redirect_uris is empty list; no redirect URIs will be accepted. "
|
|
622
|
+
"This will block all OAuth clients."
|
|
623
|
+
)
|
|
624
|
+
self._allowed_client_redirect_uris = []
|
|
625
|
+
else:
|
|
626
|
+
self._allowed_client_redirect_uris = allowed_client_redirect_uris
|
|
334
627
|
|
|
335
628
|
# PKCE configuration
|
|
336
629
|
self._forward_pkce = forward_pkce
|
|
@@ -342,11 +635,69 @@ class OAuthProxy(OAuthProvider):
|
|
|
342
635
|
self._extra_authorize_params = extra_authorize_params or {}
|
|
343
636
|
self._extra_token_params = extra_token_params or {}
|
|
344
637
|
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
638
|
+
self._client_storage: AsyncKeyValue = client_storage or MemoryStore()
|
|
639
|
+
|
|
640
|
+
# Warn if using MemoryStore in production
|
|
641
|
+
if isinstance(client_storage, MemoryStore):
|
|
642
|
+
logger.warning(
|
|
643
|
+
"Using in-memory storage - all OAuth state (clients, tokens) will be lost on restart. "
|
|
644
|
+
"Additionally, without explicit jwt_signing_key and token_encryption_key, "
|
|
645
|
+
"keys are ephemeral and tokens won't survive restart even with persistent storage. "
|
|
646
|
+
"For production, configure persistent storage AND explicit keys."
|
|
647
|
+
)
|
|
648
|
+
|
|
649
|
+
# Cache HTTPS check to avoid repeated logging
|
|
650
|
+
self._is_https = str(self.base_url).startswith("https://")
|
|
651
|
+
if not self._is_https:
|
|
652
|
+
logger.warning(
|
|
653
|
+
"Using non-secure cookies for development; deploy with HTTPS for production."
|
|
654
|
+
)
|
|
655
|
+
|
|
656
|
+
self._client_store = PydanticAdapter[ProxyDCRClient](
|
|
657
|
+
key_value=self._client_storage,
|
|
658
|
+
pydantic_model=ProxyDCRClient,
|
|
659
|
+
default_collection="mcp-oauth-proxy-clients",
|
|
660
|
+
raise_on_validation_error=True,
|
|
661
|
+
)
|
|
662
|
+
|
|
663
|
+
# OAuth transaction storage for IdP callback forwarding
|
|
664
|
+
# Reuse client_storage with different collections for state management
|
|
665
|
+
self._transaction_store = PydanticAdapter[OAuthTransaction](
|
|
666
|
+
key_value=self._client_storage,
|
|
667
|
+
pydantic_model=OAuthTransaction,
|
|
668
|
+
default_collection="mcp-oauth-transactions",
|
|
669
|
+
raise_on_validation_error=True,
|
|
670
|
+
)
|
|
671
|
+
|
|
672
|
+
self._code_store = PydanticAdapter[ClientCode](
|
|
673
|
+
key_value=self._client_storage,
|
|
674
|
+
pydantic_model=ClientCode,
|
|
675
|
+
default_collection="mcp-authorization-codes",
|
|
676
|
+
raise_on_validation_error=True,
|
|
677
|
+
)
|
|
678
|
+
|
|
679
|
+
# Storage for upstream tokens (encrypted at rest)
|
|
680
|
+
self._upstream_token_store = PydanticAdapter[UpstreamTokenSet](
|
|
681
|
+
key_value=self._client_storage,
|
|
682
|
+
pydantic_model=UpstreamTokenSet,
|
|
683
|
+
default_collection="mcp-upstream-tokens",
|
|
684
|
+
raise_on_validation_error=True,
|
|
685
|
+
)
|
|
686
|
+
|
|
687
|
+
# Storage for JTI mappings (FastMCP token -> upstream token)
|
|
688
|
+
self._jti_mapping_store = PydanticAdapter[JTIMapping](
|
|
689
|
+
key_value=self._client_storage,
|
|
690
|
+
pydantic_model=JTIMapping,
|
|
691
|
+
default_collection="mcp-jti-mappings",
|
|
692
|
+
raise_on_validation_error=True,
|
|
693
|
+
)
|
|
694
|
+
|
|
695
|
+
# JWT issuer and encryption (initialized lazily on first use)
|
|
696
|
+
self._custom_jwt_key = jwt_signing_key
|
|
697
|
+
self._custom_encryption_key = token_encryption_key
|
|
698
|
+
self._jwt_issuer: JWTIssuer | None = None
|
|
699
|
+
self._token_encryption: TokenEncryption | None = None
|
|
700
|
+
self._jwt_initialized = False
|
|
350
701
|
|
|
351
702
|
# Local state for token bookkeeping only (no client caching)
|
|
352
703
|
self._access_tokens: dict[str, AccessToken] = {}
|
|
@@ -356,12 +707,6 @@ class OAuthProxy(OAuthProvider):
|
|
|
356
707
|
self._access_to_refresh: dict[str, str] = {}
|
|
357
708
|
self._refresh_to_access: dict[str, str] = {}
|
|
358
709
|
|
|
359
|
-
# OAuth transaction storage for IdP callback forwarding
|
|
360
|
-
self._oauth_transactions: dict[
|
|
361
|
-
str, dict[str, Any]
|
|
362
|
-
] = {} # txn_id -> transaction_data
|
|
363
|
-
self._client_codes: dict[str, dict[str, Any]] = {} # client_code -> code_data
|
|
364
|
-
|
|
365
710
|
# Use the provided token validator
|
|
366
711
|
self._token_validator = token_verifier
|
|
367
712
|
|
|
@@ -389,6 +734,87 @@ class OAuthProxy(OAuthProvider):
|
|
|
389
734
|
|
|
390
735
|
return code_verifier, code_challenge
|
|
391
736
|
|
|
737
|
+
# -------------------------------------------------------------------------
|
|
738
|
+
# JWT Token Factory Initialization
|
|
739
|
+
# -------------------------------------------------------------------------
|
|
740
|
+
|
|
741
|
+
async def _ensure_jwt_initialized(self) -> None:
    """Lazily create the JWT issuer and the upstream-token encryptor.

    Returns immediately once ``self._jwt_initialized`` has been set.

    Key derivation:
    - Explicit keys (``jwt_signing_key`` / ``token_encryption_key``): each
      secret is passed through ``derive_key_from_secret`` with a fixed salt.
      The derived key is therefore stable across restarts, and with
      persistent storage issued tokens survive a restart.
    - No explicit key: the upstream client secret is derived with a salt
      that embeds a random value generated here and never persisted. The key
      changes on every restart, so existing tokens become invalid. This is
      intended for development and testing.
    """
    if self._jwt_initialized:
        return

    # Random per-instance salt for the ephemeral-key path (not persisted).
    instance_salt = secrets.token_urlsafe(32)

    from fastmcp.server.auth.jwt_issuer import derive_key_from_secret

    # ---- JWT signing key ------------------------------------------------
    if not self._custom_jwt_key:
        signing_key = derive_key_from_secret(
            secret=self._upstream_client_secret.get_secret_value(),
            salt=f"fastmcp-jwt-signing-v1-{instance_salt}",
            info=b"HS256",
        )
        logger.info(
            "Using ephemeral JWT signing key - tokens will NOT survive server restart. "
            "For production, provide explicit jwt_signing_key parameter."
        )
    else:
        signing_key = derive_key_from_secret(
            secret=self._custom_jwt_key,
            salt="fastmcp-jwt-signing-v1",
            info=b"HS256",
        )
        logger.info("Using explicit JWT signing key (will survive restarts)")

    base_url_text = str(self.base_url)
    self._jwt_issuer = JWTIssuer(
        issuer=base_url_text,
        audience=base_url_text.rstrip("/") + "/mcp",
        signing_key=signing_key,
    )

    # ---- Upstream-token encryption key ----------------------------------
    # Fernet expects its key url-safe base64 encoded, hence the encode step.
    if not self._custom_encryption_key:
        raw_key = derive_key_from_secret(
            secret=self._upstream_client_secret.get_secret_value(),
            salt=f"fastmcp-token-encryption-v1-{instance_salt}",
            info=b"Fernet",
        )
        fernet_key = base64.urlsafe_b64encode(raw_key)
        logger.info(
            "Using ephemeral token encryption key - encrypted tokens will NOT survive server restart. "
            "For production, provide explicit token_encryption_key parameter."
        )
    else:
        raw_key = derive_key_from_secret(
            secret=self._custom_encryption_key,
            salt="fastmcp-token-encryption-v1",
            info=b"Fernet",
        )
        fernet_key = base64.urlsafe_b64encode(raw_key)
        logger.info("Using explicit token encryption key (will survive restarts)")

    self._token_encryption = TokenEncryption(fernet_key)
    self._jwt_initialized = True
|
|
817
|
+
|
|
392
818
|
# -------------------------------------------------------------------------
|
|
393
819
|
# Client Registration (Local Implementation)
|
|
394
820
|
# -------------------------------------------------------------------------
|
|
@@ -400,19 +826,13 @@ class OAuthProxy(OAuthProvider):
|
|
|
400
826
|
For unregistered clients, returns None (which will raise an error in the SDK).
|
|
401
827
|
"""
|
|
402
828
|
# Load from storage
|
|
403
|
-
|
|
404
|
-
if not data:
|
|
829
|
+
if not (client := await self._client_store.get(key=client_id)):
|
|
405
830
|
return None
|
|
406
831
|
|
|
407
|
-
if
|
|
408
|
-
|
|
409
|
-
allowed_redirect_uri_patterns=data.get(
|
|
410
|
-
"allowed_redirect_uri_patterns", self._allowed_client_redirect_uris
|
|
411
|
-
),
|
|
412
|
-
**client_data,
|
|
413
|
-
)
|
|
832
|
+
if client.allowed_redirect_uri_patterns is None:
|
|
833
|
+
client.allowed_redirect_uri_patterns = self._allowed_client_redirect_uris
|
|
414
834
|
|
|
415
|
-
return
|
|
835
|
+
return client
|
|
416
836
|
|
|
417
837
|
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
|
|
418
838
|
"""Register a client locally
|
|
@@ -424,7 +844,7 @@ class OAuthProxy(OAuthProvider):
|
|
|
424
844
|
"""
|
|
425
845
|
|
|
426
846
|
# Create a ProxyDCRClient with configured redirect URI validation
|
|
427
|
-
proxy_client = ProxyDCRClient(
|
|
847
|
+
proxy_client: ProxyDCRClient = ProxyDCRClient(
|
|
428
848
|
client_id=client_info.client_id,
|
|
429
849
|
client_secret=client_info.client_secret,
|
|
430
850
|
redirect_uris=client_info.redirect_uris or [AnyUrl("http://localhost")],
|
|
@@ -433,14 +853,13 @@ class OAuthProxy(OAuthProvider):
|
|
|
433
853
|
scope=client_info.scope or self._default_scope_str,
|
|
434
854
|
token_endpoint_auth_method="none",
|
|
435
855
|
allowed_redirect_uri_patterns=self._allowed_client_redirect_uris,
|
|
856
|
+
client_name=getattr(client_info, "client_name", None),
|
|
436
857
|
)
|
|
437
858
|
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
}
|
|
443
|
-
await self._client_storage.set(client_info.client_id, storage_data)
|
|
859
|
+
await self._client_store.put(
|
|
860
|
+
key=client_info.client_id,
|
|
861
|
+
value=proxy_client,
|
|
862
|
+
)
|
|
444
863
|
|
|
445
864
|
# Log redirect URIs to help users discover what patterns they might need
|
|
446
865
|
if client_info.redirect_uris:
|
|
@@ -466,13 +885,12 @@ class OAuthProxy(OAuthProvider):
|
|
|
466
885
|
client: OAuthClientInformationFull,
|
|
467
886
|
params: AuthorizationParams,
|
|
468
887
|
) -> str:
|
|
469
|
-
"""Start OAuth transaction and
|
|
888
|
+
"""Start OAuth transaction and route through consent interstitial.
|
|
470
889
|
|
|
471
|
-
|
|
472
|
-
1. Store transaction with client details and PKCE
|
|
473
|
-
2.
|
|
474
|
-
3.
|
|
475
|
-
4. Redirect to IdP with our fixed callback URL and proxy's PKCE
|
|
890
|
+
Flow:
|
|
891
|
+
1. Store transaction with client details and PKCE (if forwarding)
|
|
892
|
+
2. Return local /consent URL; browser visits consent first
|
|
893
|
+
3. Consent handler redirects to upstream IdP if approved/already approved
|
|
476
894
|
"""
|
|
477
895
|
# Generate transaction ID for this authorization request
|
|
478
896
|
txn_id = secrets.token_urlsafe(32)
|
|
@@ -488,75 +906,32 @@ class OAuthProxy(OAuthProvider):
|
|
|
488
906
|
)
|
|
489
907
|
|
|
490
908
|
# Store transaction data for IdP callback processing
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
# Build query parameters for upstream IdP authorization request
|
|
508
|
-
# Use our fixed IdP callback and transaction ID as state
|
|
509
|
-
query_params: dict[str, Any] = {
|
|
510
|
-
"response_type": "code",
|
|
511
|
-
"client_id": self._upstream_client_id,
|
|
512
|
-
"redirect_uri": f"{str(self.base_url).rstrip('/')}{self._redirect_path}",
|
|
513
|
-
"state": txn_id, # Use txn_id as IdP state
|
|
514
|
-
}
|
|
515
|
-
|
|
516
|
-
# Add scopes - use client scopes or fallback to required scopes
|
|
517
|
-
scopes_to_use = params.scopes or self.required_scopes or []
|
|
518
|
-
|
|
519
|
-
if scopes_to_use:
|
|
520
|
-
query_params["scope"] = " ".join(scopes_to_use)
|
|
521
|
-
|
|
522
|
-
# Forward proxy's PKCE challenge to upstream if enabled
|
|
523
|
-
if proxy_code_challenge:
|
|
524
|
-
query_params["code_challenge"] = proxy_code_challenge
|
|
525
|
-
query_params["code_challenge_method"] = "S256"
|
|
526
|
-
logger.debug(
|
|
527
|
-
"Forwarding proxy PKCE challenge to upstream for transaction %s",
|
|
528
|
-
txn_id,
|
|
529
|
-
)
|
|
530
|
-
|
|
531
|
-
# Forward resource parameter if provided (RFC 8707)
|
|
532
|
-
if params.resource:
|
|
533
|
-
query_params["resource"] = params.resource
|
|
534
|
-
logger.debug(
|
|
535
|
-
"Forwarding resource indicator '%s' to upstream for transaction %s",
|
|
536
|
-
params.resource,
|
|
537
|
-
txn_id,
|
|
538
|
-
)
|
|
539
|
-
|
|
540
|
-
# Add any extra authorization parameters configured for this proxy
|
|
541
|
-
if self._extra_authorize_params:
|
|
542
|
-
query_params.update(self._extra_authorize_params)
|
|
543
|
-
logger.debug(
|
|
544
|
-
"Adding extra authorization parameters for transaction %s: %s",
|
|
545
|
-
txn_id,
|
|
546
|
-
list(self._extra_authorize_params.keys()),
|
|
547
|
-
)
|
|
909
|
+
await self._transaction_store.put(
|
|
910
|
+
key=txn_id,
|
|
911
|
+
value=OAuthTransaction(
|
|
912
|
+
txn_id=txn_id,
|
|
913
|
+
client_id=client.client_id,
|
|
914
|
+
client_redirect_uri=str(params.redirect_uri),
|
|
915
|
+
client_state=params.state or "",
|
|
916
|
+
code_challenge=params.code_challenge,
|
|
917
|
+
code_challenge_method=getattr(params, "code_challenge_method", "S256"),
|
|
918
|
+
scopes=params.scopes or [],
|
|
919
|
+
created_at=time.time(),
|
|
920
|
+
resource=getattr(params, "resource", None),
|
|
921
|
+
proxy_code_verifier=proxy_code_verifier,
|
|
922
|
+
),
|
|
923
|
+
ttl=15 * 60, # Auto-expire after 15 minutes
|
|
924
|
+
)
|
|
548
925
|
|
|
549
|
-
|
|
550
|
-
separator = "&" if "?" in self._upstream_authorization_endpoint else "?"
|
|
551
|
-
upstream_url = f"{self._upstream_authorization_endpoint}{separator}{urlencode(query_params)}"
|
|
926
|
+
consent_url = f"{str(self.base_url).rstrip('/')}/consent?txn_id={txn_id}"
|
|
552
927
|
|
|
553
928
|
logger.debug(
|
|
554
|
-
"Starting OAuth transaction %s for client %s, redirecting to
|
|
929
|
+
"Starting OAuth transaction %s for client %s, redirecting to consent page (PKCE forwarding: %s)",
|
|
555
930
|
txn_id,
|
|
556
931
|
client.client_id,
|
|
557
932
|
"enabled" if proxy_code_challenge else "disabled",
|
|
558
933
|
)
|
|
559
|
-
return
|
|
934
|
+
return consent_url
|
|
560
935
|
|
|
561
936
|
# -------------------------------------------------------------------------
|
|
562
937
|
# Authorization Code Handling
|
|
@@ -573,22 +948,22 @@ class OAuthProxy(OAuthProvider):
|
|
|
573
948
|
with PKCE challenge for validation.
|
|
574
949
|
"""
|
|
575
950
|
# Look up client code data
|
|
576
|
-
|
|
577
|
-
if not
|
|
951
|
+
code_model = await self._code_store.get(key=authorization_code)
|
|
952
|
+
if not code_model:
|
|
578
953
|
logger.debug("Authorization code not found: %s", authorization_code)
|
|
579
954
|
return None
|
|
580
955
|
|
|
581
956
|
# Check if code expired
|
|
582
|
-
if time.time() >
|
|
957
|
+
if time.time() > code_model.expires_at:
|
|
583
958
|
logger.debug("Authorization code expired: %s", authorization_code)
|
|
584
|
-
self.
|
|
959
|
+
await self._code_store.delete(key=authorization_code)
|
|
585
960
|
return None
|
|
586
961
|
|
|
587
962
|
# Verify client ID matches
|
|
588
|
-
if
|
|
963
|
+
if code_model.client_id != client.client_id:
|
|
589
964
|
logger.debug(
|
|
590
965
|
"Authorization code client ID mismatch: %s vs %s",
|
|
591
|
-
|
|
966
|
+
code_model.client_id,
|
|
592
967
|
client.client_id,
|
|
593
968
|
)
|
|
594
969
|
return None
|
|
@@ -597,11 +972,11 @@ class OAuthProxy(OAuthProvider):
|
|
|
597
972
|
return AuthorizationCode(
|
|
598
973
|
code=authorization_code,
|
|
599
974
|
client_id=client.client_id,
|
|
600
|
-
redirect_uri=
|
|
975
|
+
redirect_uri=code_model.redirect_uri,
|
|
601
976
|
redirect_uri_provided_explicitly=True,
|
|
602
|
-
scopes=
|
|
603
|
-
expires_at=
|
|
604
|
-
code_challenge=
|
|
977
|
+
scopes=code_model.scopes,
|
|
978
|
+
expires_at=code_model.expires_at,
|
|
979
|
+
code_challenge=code_model.code_challenge or "",
|
|
605
980
|
)
|
|
606
981
|
|
|
607
982
|
async def exchange_authorization_code(
|
|
@@ -609,63 +984,166 @@ class OAuthProxy(OAuthProvider):
|
|
|
609
984
|
client: OAuthClientInformationFull,
|
|
610
985
|
authorization_code: AuthorizationCode,
|
|
611
986
|
) -> OAuthToken:
|
|
612
|
-
"""Exchange authorization code for
|
|
987
|
+
"""Exchange authorization code for FastMCP-issued tokens.
|
|
613
988
|
|
|
614
|
-
|
|
615
|
-
|
|
989
|
+
Implements the token factory pattern:
|
|
990
|
+
1. Retrieves upstream tokens from stored authorization code
|
|
991
|
+
2. Extracts user identity from upstream token
|
|
992
|
+
3. Encrypts and stores upstream tokens
|
|
993
|
+
4. Issues FastMCP-signed JWT tokens
|
|
994
|
+
5. Returns FastMCP tokens (NOT upstream tokens)
|
|
995
|
+
|
|
996
|
+
PKCE validation is handled by the MCP framework before this method is called.
|
|
616
997
|
"""
|
|
998
|
+
# Ensure JWT issuer is initialized
|
|
999
|
+
await self._ensure_jwt_initialized()
|
|
1000
|
+
assert self._jwt_issuer is not None
|
|
1001
|
+
assert self._token_encryption is not None
|
|
1002
|
+
|
|
617
1003
|
# Look up stored code data
|
|
618
|
-
|
|
619
|
-
if not
|
|
1004
|
+
code_model = await self._code_store.get(key=authorization_code.code)
|
|
1005
|
+
if not code_model:
|
|
620
1006
|
logger.error(
|
|
621
1007
|
"Authorization code not found in client codes: %s",
|
|
622
1008
|
authorization_code.code,
|
|
623
1009
|
)
|
|
624
1010
|
raise TokenError("invalid_grant", "Authorization code not found")
|
|
625
1011
|
|
|
626
|
-
# Get stored
|
|
627
|
-
idp_tokens =
|
|
1012
|
+
# Get stored upstream tokens
|
|
1013
|
+
idp_tokens = code_model.idp_tokens
|
|
628
1014
|
|
|
629
1015
|
# Clean up client code (one-time use)
|
|
630
|
-
self.
|
|
1016
|
+
await self._code_store.delete(key=authorization_code.code)
|
|
631
1017
|
|
|
632
|
-
#
|
|
633
|
-
|
|
634
|
-
|
|
1018
|
+
# Generate IDs for token storage
|
|
1019
|
+
upstream_token_id = secrets.token_urlsafe(32)
|
|
1020
|
+
access_jti = secrets.token_urlsafe(32)
|
|
1021
|
+
refresh_jti = (
|
|
1022
|
+
secrets.token_urlsafe(32) if idp_tokens.get("refresh_token") else None
|
|
1023
|
+
)
|
|
1024
|
+
|
|
1025
|
+
# Calculate token expiry times
|
|
635
1026
|
expires_in = int(
|
|
636
1027
|
idp_tokens.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
|
|
637
1028
|
)
|
|
638
|
-
expires_at = int(time.time() + expires_in)
|
|
639
1029
|
|
|
640
|
-
#
|
|
641
|
-
|
|
642
|
-
|
|
1030
|
+
# Calculate refresh token expiry if provided by upstream
|
|
1031
|
+
# Some providers include refresh_expires_in, some don't
|
|
1032
|
+
refresh_expires_in = None
|
|
1033
|
+
refresh_token_expires_at = None
|
|
1034
|
+
if idp_tokens.get("refresh_token"):
|
|
1035
|
+
if "refresh_expires_in" in idp_tokens:
|
|
1036
|
+
refresh_expires_in = int(idp_tokens["refresh_expires_in"])
|
|
1037
|
+
refresh_token_expires_at = time.time() + refresh_expires_in
|
|
1038
|
+
logger.debug(
|
|
1039
|
+
"Upstream refresh token expires in %d seconds", refresh_expires_in
|
|
1040
|
+
)
|
|
1041
|
+
else:
|
|
1042
|
+
# Default to 30 days if upstream doesn't specify
|
|
1043
|
+
# This is conservative - most providers use longer expiry
|
|
1044
|
+
refresh_expires_in = 60 * 60 * 24 * 30 # 30 days
|
|
1045
|
+
refresh_token_expires_at = time.time() + refresh_expires_in
|
|
1046
|
+
logger.debug(
|
|
1047
|
+
"Upstream refresh token expiry unknown, using 30-day default"
|
|
1048
|
+
)
|
|
1049
|
+
|
|
1050
|
+
# Encrypt and store upstream tokens
|
|
1051
|
+
upstream_token_set = UpstreamTokenSet(
|
|
1052
|
+
upstream_token_id=upstream_token_id,
|
|
1053
|
+
access_token=self._token_encryption.encrypt(idp_tokens["access_token"]),
|
|
1054
|
+
refresh_token=self._token_encryption.encrypt(idp_tokens["refresh_token"])
|
|
1055
|
+
if idp_tokens.get("refresh_token")
|
|
1056
|
+
else None,
|
|
1057
|
+
refresh_token_expires_at=refresh_token_expires_at,
|
|
1058
|
+
expires_at=time.time() + expires_in,
|
|
1059
|
+
token_type=idp_tokens.get("token_type", "Bearer"),
|
|
1060
|
+
scope=" ".join(authorization_code.scopes),
|
|
1061
|
+
client_id=client.client_id,
|
|
1062
|
+
created_at=time.time(),
|
|
1063
|
+
raw_token_data=idp_tokens,
|
|
1064
|
+
)
|
|
1065
|
+
await self._upstream_token_store.put(
|
|
1066
|
+
key=upstream_token_id,
|
|
1067
|
+
value=upstream_token_set,
|
|
1068
|
+
ttl=expires_in, # Auto-expire when access token expires
|
|
1069
|
+
)
|
|
1070
|
+
logger.debug("Stored encrypted upstream tokens (jti=%s)", access_jti[:8])
|
|
1071
|
+
|
|
1072
|
+
# Issue minimal FastMCP access token (just a reference via JTI)
|
|
1073
|
+
fastmcp_access_token = self._jwt_issuer.issue_access_token(
|
|
643
1074
|
client_id=client.client_id,
|
|
644
1075
|
scopes=authorization_code.scopes,
|
|
645
|
-
|
|
1076
|
+
jti=access_jti,
|
|
1077
|
+
expires_in=expires_in,
|
|
646
1078
|
)
|
|
647
|
-
self._access_tokens[access_token_value] = access_token
|
|
648
1079
|
|
|
649
|
-
#
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
1080
|
+
# Issue minimal FastMCP refresh token if upstream provided one
|
|
1081
|
+
# Use upstream refresh token expiry to align lifetimes
|
|
1082
|
+
fastmcp_refresh_token = None
|
|
1083
|
+
if refresh_jti and refresh_expires_in:
|
|
1084
|
+
fastmcp_refresh_token = self._jwt_issuer.issue_refresh_token(
|
|
653
1085
|
client_id=client.client_id,
|
|
654
1086
|
scopes=authorization_code.scopes,
|
|
655
|
-
|
|
1087
|
+
jti=refresh_jti,
|
|
1088
|
+
expires_in=refresh_expires_in,
|
|
1089
|
+
)
|
|
1090
|
+
|
|
1091
|
+
# Store JTI mappings
|
|
1092
|
+
await self._jti_mapping_store.put(
|
|
1093
|
+
key=access_jti,
|
|
1094
|
+
value=JTIMapping(
|
|
1095
|
+
jti=access_jti,
|
|
1096
|
+
upstream_token_id=upstream_token_id,
|
|
1097
|
+
created_at=time.time(),
|
|
1098
|
+
),
|
|
1099
|
+
ttl=expires_in, # Auto-expire with access token
|
|
1100
|
+
)
|
|
1101
|
+
if refresh_jti:
|
|
1102
|
+
await self._jti_mapping_store.put(
|
|
1103
|
+
key=refresh_jti,
|
|
1104
|
+
value=JTIMapping(
|
|
1105
|
+
jti=refresh_jti,
|
|
1106
|
+
upstream_token_id=upstream_token_id,
|
|
1107
|
+
created_at=time.time(),
|
|
1108
|
+
),
|
|
1109
|
+
ttl=60 * 60 * 24 * 30, # Auto-expire with refresh token (30 days)
|
|
656
1110
|
)
|
|
657
|
-
self._refresh_tokens[refresh_token_value] = refresh_token
|
|
658
1111
|
|
|
1112
|
+
# Store FastMCP access token for MCP framework validation
|
|
1113
|
+
self._access_tokens[fastmcp_access_token] = AccessToken(
|
|
1114
|
+
token=fastmcp_access_token,
|
|
1115
|
+
client_id=client.client_id,
|
|
1116
|
+
scopes=authorization_code.scopes,
|
|
1117
|
+
expires_at=int(time.time() + expires_in),
|
|
1118
|
+
)
|
|
1119
|
+
|
|
1120
|
+
# Store FastMCP refresh token if provided
|
|
1121
|
+
if fastmcp_refresh_token:
|
|
1122
|
+
self._refresh_tokens[fastmcp_refresh_token] = RefreshToken(
|
|
1123
|
+
token=fastmcp_refresh_token,
|
|
1124
|
+
client_id=client.client_id,
|
|
1125
|
+
scopes=authorization_code.scopes,
|
|
1126
|
+
expires_at=None,
|
|
1127
|
+
)
|
|
659
1128
|
# Maintain token relationships for cleanup
|
|
660
|
-
self._access_to_refresh[
|
|
661
|
-
self._refresh_to_access[
|
|
1129
|
+
self._access_to_refresh[fastmcp_access_token] = fastmcp_refresh_token
|
|
1130
|
+
self._refresh_to_access[fastmcp_refresh_token] = fastmcp_access_token
|
|
662
1131
|
|
|
663
1132
|
logger.debug(
|
|
664
|
-
"
|
|
1133
|
+
"Issued FastMCP tokens for client=%s (access_jti=%s, refresh_jti=%s)",
|
|
665
1134
|
client.client_id,
|
|
1135
|
+
access_jti[:8],
|
|
1136
|
+
refresh_jti[:8] if refresh_jti else "none",
|
|
666
1137
|
)
|
|
667
1138
|
|
|
668
|
-
|
|
1139
|
+
# Return FastMCP-issued tokens (NOT upstream tokens!)
|
|
1140
|
+
return OAuthToken(
|
|
1141
|
+
access_token=fastmcp_access_token,
|
|
1142
|
+
token_type="Bearer",
|
|
1143
|
+
expires_in=expires_in,
|
|
1144
|
+
refresh_token=fastmcp_refresh_token,
|
|
1145
|
+
scope=" ".join(authorization_code.scopes),
|
|
1146
|
+
)
|
|
669
1147
|
|
|
670
1148
|
# -------------------------------------------------------------------------
|
|
671
1149
|
# Refresh Token Flow
|
|
@@ -685,9 +1163,54 @@ class OAuthProxy(OAuthProvider):
|
|
|
685
1163
|
refresh_token: RefreshToken,
|
|
686
1164
|
scopes: list[str],
|
|
687
1165
|
) -> OAuthToken:
|
|
688
|
-
"""Exchange refresh token for new access token
|
|
1166
|
+
"""Exchange FastMCP refresh token for new FastMCP access token.
|
|
1167
|
+
|
|
1168
|
+
Implements two-tier refresh:
|
|
1169
|
+
1. Verify FastMCP refresh token
|
|
1170
|
+
2. Look up upstream token via JTI mapping
|
|
1171
|
+
3. Refresh upstream token with upstream provider
|
|
1172
|
+
4. Update stored upstream token
|
|
1173
|
+
5. Issue new FastMCP access token
|
|
1174
|
+
6. Keep same FastMCP refresh token (unless upstream rotates)
|
|
1175
|
+
"""
|
|
1176
|
+
# Ensure JWT issuer is initialized
|
|
1177
|
+
await self._ensure_jwt_initialized()
|
|
1178
|
+
assert self._jwt_issuer is not None
|
|
1179
|
+
assert self._token_encryption is not None
|
|
1180
|
+
|
|
1181
|
+
# Verify FastMCP refresh token
|
|
1182
|
+
try:
|
|
1183
|
+
refresh_payload = self._jwt_issuer.verify_token(refresh_token.token)
|
|
1184
|
+
refresh_jti = refresh_payload["jti"]
|
|
1185
|
+
except Exception as e:
|
|
1186
|
+
logger.debug("FastMCP refresh token validation failed: %s", e)
|
|
1187
|
+
raise TokenError("invalid_grant", "Invalid refresh token") from e
|
|
1188
|
+
|
|
1189
|
+
# Look up upstream token via JTI mapping
|
|
1190
|
+
jti_mapping = await self._jti_mapping_store.get(key=refresh_jti)
|
|
1191
|
+
if not jti_mapping:
|
|
1192
|
+
logger.error("JTI mapping not found for refresh token: %s", refresh_jti[:8])
|
|
1193
|
+
raise TokenError("invalid_grant", "Refresh token mapping not found")
|
|
1194
|
+
|
|
1195
|
+
upstream_token_set = await self._upstream_token_store.get(
|
|
1196
|
+
key=jti_mapping.upstream_token_id
|
|
1197
|
+
)
|
|
1198
|
+
if not upstream_token_set:
|
|
1199
|
+
logger.error(
|
|
1200
|
+
"Upstream token set not found: %s", jti_mapping.upstream_token_id[:8]
|
|
1201
|
+
)
|
|
1202
|
+
raise TokenError("invalid_grant", "Upstream token not found")
|
|
1203
|
+
|
|
1204
|
+
# Decrypt upstream refresh token
|
|
1205
|
+
if not upstream_token_set.refresh_token:
|
|
1206
|
+
logger.error("No upstream refresh token available")
|
|
1207
|
+
raise TokenError("invalid_grant", "Refresh not supported for this token")
|
|
689
1208
|
|
|
690
|
-
|
|
1209
|
+
upstream_refresh_token = self._token_encryption.decrypt(
|
|
1210
|
+
upstream_token_set.refresh_token
|
|
1211
|
+
)
|
|
1212
|
+
|
|
1213
|
+
# Refresh upstream token using authlib
|
|
691
1214
|
oauth_client = AsyncOAuth2Client(
|
|
692
1215
|
client_id=self._upstream_client_id,
|
|
693
1216
|
client_secret=self._upstream_client_secret.get_secret_value(),
|
|
@@ -696,77 +1219,217 @@ class OAuthProxy(OAuthProvider):
|
|
|
696
1219
|
)
|
|
697
1220
|
|
|
698
1221
|
try:
|
|
699
|
-
logger.debug("
|
|
700
|
-
|
|
701
|
-
# Let authlib handle the refresh token exchange
|
|
1222
|
+
logger.debug("Refreshing upstream token (jti=%s)", refresh_jti[:8])
|
|
702
1223
|
token_response: dict[str, Any] = await oauth_client.refresh_token( # type: ignore[misc]
|
|
703
1224
|
url=self._upstream_token_endpoint,
|
|
704
|
-
refresh_token=
|
|
1225
|
+
refresh_token=upstream_refresh_token,
|
|
705
1226
|
scope=" ".join(scopes) if scopes else None,
|
|
706
1227
|
)
|
|
707
|
-
|
|
708
|
-
logger.debug(
|
|
709
|
-
"Successfully refreshed access token via authlib (client: %s)",
|
|
710
|
-
client.client_id,
|
|
711
|
-
)
|
|
712
|
-
|
|
1228
|
+
logger.debug("Successfully refreshed upstream token")
|
|
713
1229
|
except Exception as e:
|
|
714
|
-
logger.error("
|
|
715
|
-
raise TokenError(
|
|
716
|
-
"invalid_grant", f"Upstream refresh token exchange failed: {e}"
|
|
717
|
-
) from e
|
|
1230
|
+
logger.error("Upstream token refresh failed: %s", e)
|
|
1231
|
+
raise TokenError("invalid_grant", f"Upstream refresh failed: {e}") from e
|
|
718
1232
|
|
|
719
|
-
# Update
|
|
720
|
-
|
|
721
|
-
expires_in = int(
|
|
1233
|
+
# Update stored upstream token
|
|
1234
|
+
new_expires_in = int(
|
|
722
1235
|
token_response.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
|
|
723
1236
|
)
|
|
1237
|
+
upstream_token_set.access_token = self._token_encryption.encrypt(
|
|
1238
|
+
token_response["access_token"]
|
|
1239
|
+
)
|
|
1240
|
+
upstream_token_set.expires_at = time.time() + new_expires_in
|
|
1241
|
+
|
|
1242
|
+
# Handle upstream refresh token rotation and expiry
|
|
1243
|
+
new_refresh_expires_in = None
|
|
1244
|
+
if new_upstream_refresh := token_response.get("refresh_token"):
|
|
1245
|
+
if new_upstream_refresh != upstream_refresh_token:
|
|
1246
|
+
upstream_token_set.refresh_token = self._token_encryption.encrypt(
|
|
1247
|
+
new_upstream_refresh
|
|
1248
|
+
)
|
|
1249
|
+
logger.debug("Upstream refresh token rotated")
|
|
724
1250
|
|
|
725
|
-
|
|
726
|
-
|
|
1251
|
+
# Update refresh token expiry if provided
|
|
1252
|
+
if "refresh_expires_in" in token_response:
|
|
1253
|
+
new_refresh_expires_in = int(token_response["refresh_expires_in"])
|
|
1254
|
+
upstream_token_set.refresh_token_expires_at = (
|
|
1255
|
+
time.time() + new_refresh_expires_in
|
|
1256
|
+
)
|
|
1257
|
+
logger.debug(
|
|
1258
|
+
"Upstream refresh token expires in %d seconds",
|
|
1259
|
+
new_refresh_expires_in,
|
|
1260
|
+
)
|
|
1261
|
+
elif upstream_token_set.refresh_token_expires_at:
|
|
1262
|
+
# Keep existing expiry if upstream doesn't provide new one
|
|
1263
|
+
new_refresh_expires_in = int(
|
|
1264
|
+
upstream_token_set.refresh_token_expires_at - time.time()
|
|
1265
|
+
)
|
|
1266
|
+
else:
|
|
1267
|
+
# Default to 30 days if unknown
|
|
1268
|
+
new_refresh_expires_in = 60 * 60 * 24 * 30
|
|
1269
|
+
upstream_token_set.refresh_token_expires_at = (
|
|
1270
|
+
time.time() + new_refresh_expires_in
|
|
1271
|
+
)
|
|
1272
|
+
|
|
1273
|
+
upstream_token_set.raw_token_data = token_response
|
|
1274
|
+
await self._upstream_token_store.put(
|
|
1275
|
+
key=upstream_token_set.upstream_token_id,
|
|
1276
|
+
value=upstream_token_set,
|
|
1277
|
+
ttl=new_expires_in, # Auto-expire when refreshed access token expires
|
|
1278
|
+
)
|
|
1279
|
+
|
|
1280
|
+
# Issue new minimal FastMCP access token (just a reference via JTI)
|
|
1281
|
+
new_access_jti = secrets.token_urlsafe(32)
|
|
1282
|
+
new_fastmcp_access = self._jwt_issuer.issue_access_token(
|
|
727
1283
|
client_id=client.client_id,
|
|
728
1284
|
scopes=scopes,
|
|
729
|
-
|
|
1285
|
+
jti=new_access_jti,
|
|
1286
|
+
expires_in=new_expires_in,
|
|
730
1287
|
)
|
|
731
1288
|
|
|
732
|
-
#
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
1289
|
+
# Store new access token JTI mapping
|
|
1290
|
+
await self._jti_mapping_store.put(
|
|
1291
|
+
key=new_access_jti,
|
|
1292
|
+
value=JTIMapping(
|
|
1293
|
+
jti=new_access_jti,
|
|
1294
|
+
upstream_token_id=upstream_token_set.upstream_token_id,
|
|
1295
|
+
created_at=time.time(),
|
|
1296
|
+
),
|
|
1297
|
+
ttl=new_expires_in, # Auto-expire with refreshed access token
|
|
1298
|
+
)
|
|
1299
|
+
|
|
1300
|
+
# Issue NEW minimal FastMCP refresh token (rotation for security)
|
|
1301
|
+
# Use upstream refresh token expiry to align lifetimes
|
|
1302
|
+
new_refresh_jti = secrets.token_urlsafe(32)
|
|
1303
|
+
new_fastmcp_refresh = self._jwt_issuer.issue_refresh_token(
|
|
1304
|
+
client_id=client.client_id,
|
|
1305
|
+
scopes=scopes,
|
|
1306
|
+
jti=new_refresh_jti,
|
|
1307
|
+
expires_in=new_refresh_expires_in
|
|
1308
|
+
or 60 * 60 * 24 * 30, # Fallback to 30 days
|
|
1309
|
+
)
|
|
751
1310
|
|
|
752
|
-
|
|
1311
|
+
# Store new refresh token JTI mapping with aligned expiry
|
|
1312
|
+
refresh_ttl = new_refresh_expires_in or 60 * 60 * 24 * 30
|
|
1313
|
+
await self._jti_mapping_store.put(
|
|
1314
|
+
key=new_refresh_jti,
|
|
1315
|
+
value=JTIMapping(
|
|
1316
|
+
jti=new_refresh_jti,
|
|
1317
|
+
upstream_token_id=upstream_token_set.upstream_token_id,
|
|
1318
|
+
created_at=time.time(),
|
|
1319
|
+
),
|
|
1320
|
+
ttl=refresh_ttl, # Align with upstream refresh token expiry
|
|
1321
|
+
)
|
|
1322
|
+
|
|
1323
|
+
# Invalidate old refresh token (refresh token rotation - enforces one-time use)
|
|
1324
|
+
await self._jti_mapping_store.delete(key=refresh_jti)
|
|
1325
|
+
logger.debug(
|
|
1326
|
+
"Rotated refresh token (old JTI invalidated - one-time use enforced)"
|
|
1327
|
+
)
|
|
1328
|
+
|
|
1329
|
+
# Update local token tracking
|
|
1330
|
+
self._access_tokens[new_fastmcp_access] = AccessToken(
|
|
1331
|
+
token=new_fastmcp_access,
|
|
1332
|
+
client_id=client.client_id,
|
|
1333
|
+
scopes=scopes,
|
|
1334
|
+
expires_at=int(time.time() + new_expires_in),
|
|
1335
|
+
)
|
|
1336
|
+
self._refresh_tokens[new_fastmcp_refresh] = RefreshToken(
|
|
1337
|
+
token=new_fastmcp_refresh,
|
|
1338
|
+
client_id=client.client_id,
|
|
1339
|
+
scopes=scopes,
|
|
1340
|
+
expires_at=None,
|
|
1341
|
+
)
|
|
1342
|
+
|
|
1343
|
+
# Update token relationship mappings
|
|
1344
|
+
self._access_to_refresh[new_fastmcp_access] = new_fastmcp_refresh
|
|
1345
|
+
self._refresh_to_access[new_fastmcp_refresh] = new_fastmcp_access
|
|
1346
|
+
|
|
1347
|
+
# Clean up old token from in-memory tracking
|
|
1348
|
+
self._refresh_tokens.pop(refresh_token.token, None)
|
|
1349
|
+
old_access = self._refresh_to_access.pop(refresh_token.token, None)
|
|
1350
|
+
if old_access:
|
|
1351
|
+
self._access_tokens.pop(old_access, None)
|
|
1352
|
+
self._access_to_refresh.pop(old_access, None)
|
|
1353
|
+
|
|
1354
|
+
logger.info(
|
|
1355
|
+
"Issued new FastMCP tokens (rotated refresh) for client=%s (access_jti=%s, refresh_jti=%s)",
|
|
1356
|
+
client.client_id,
|
|
1357
|
+
new_access_jti[:8],
|
|
1358
|
+
new_refresh_jti[:8],
|
|
1359
|
+
)
|
|
1360
|
+
|
|
1361
|
+
# Return new FastMCP tokens (both access AND refresh are new)
|
|
1362
|
+
return OAuthToken(
|
|
1363
|
+
access_token=new_fastmcp_access,
|
|
1364
|
+
token_type="Bearer",
|
|
1365
|
+
expires_in=new_expires_in,
|
|
1366
|
+
refresh_token=new_fastmcp_refresh, # NEW refresh token (rotated)
|
|
1367
|
+
scope=" ".join(scopes),
|
|
1368
|
+
)
|
|
753
1369
|
|
|
754
1370
|
# -------------------------------------------------------------------------
|
|
755
1371
|
# Token Validation
|
|
756
1372
|
# -------------------------------------------------------------------------
|
|
757
1373
|
|
|
758
1374
|
async def load_access_token(self, token: str) -> AccessToken | None:
|
|
759
|
-
"""Validate
|
|
1375
|
+
"""Validate FastMCP JWT by swapping for upstream token.
|
|
1376
|
+
|
|
1377
|
+
This implements the token swap pattern:
|
|
1378
|
+
1. Verify FastMCP JWT signature (proves it's our token)
|
|
1379
|
+
2. Look up upstream token via JTI mapping
|
|
1380
|
+
3. Decrypt upstream token
|
|
1381
|
+
4. Validate upstream token with provider (GitHub API, JWT validation, etc.)
|
|
1382
|
+
5. Return upstream validation result
|
|
760
1383
|
|
|
761
|
-
|
|
762
|
-
|
|
1384
|
+
The FastMCP JWT is a reference token - all authorization data comes
|
|
1385
|
+
from validating the upstream token via the TokenVerifier.
|
|
763
1386
|
"""
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
1387
|
+
# Ensure JWT issuer and encryption are initialized
|
|
1388
|
+
await self._ensure_jwt_initialized()
|
|
1389
|
+
assert self._jwt_issuer is not None
|
|
1390
|
+
assert self._token_encryption is not None
|
|
1391
|
+
|
|
1392
|
+
try:
|
|
1393
|
+
# 1. Verify FastMCP JWT signature and claims
|
|
1394
|
+
payload = self._jwt_issuer.verify_token(token)
|
|
1395
|
+
jti = payload["jti"]
|
|
1396
|
+
|
|
1397
|
+
# 2. Look up upstream token via JTI mapping
|
|
1398
|
+
jti_mapping = await self._jti_mapping_store.get(key=jti)
|
|
1399
|
+
if not jti_mapping:
|
|
1400
|
+
logger.debug("JTI mapping not found: %s", jti)
|
|
1401
|
+
return None
|
|
1402
|
+
|
|
1403
|
+
upstream_token_set = await self._upstream_token_store.get(
|
|
1404
|
+
key=jti_mapping.upstream_token_id
|
|
1405
|
+
)
|
|
1406
|
+
if not upstream_token_set:
|
|
1407
|
+
logger.debug(
|
|
1408
|
+
"Upstream token not found: %s", jti_mapping.upstream_token_id
|
|
1409
|
+
)
|
|
1410
|
+
return None
|
|
1411
|
+
|
|
1412
|
+
# 3. Decrypt upstream token
|
|
1413
|
+
upstream_token = self._token_encryption.decrypt(
|
|
1414
|
+
upstream_token_set.access_token
|
|
1415
|
+
)
|
|
1416
|
+
|
|
1417
|
+
# 4. Validate with upstream provider (delegated to TokenVerifier)
|
|
1418
|
+
# This calls the real token validator (GitHub API, JWKS, etc.)
|
|
1419
|
+
validated = await self._token_validator.verify_token(upstream_token)
|
|
1420
|
+
|
|
1421
|
+
if not validated:
|
|
1422
|
+
logger.debug("Upstream token validation failed")
|
|
1423
|
+
return None
|
|
1424
|
+
|
|
1425
|
+
logger.debug(
|
|
1426
|
+
"Token swap successful for JTI=%s (upstream validated)", jti[:8]
|
|
1427
|
+
)
|
|
1428
|
+
return validated
|
|
1429
|
+
|
|
1430
|
+
except Exception as e:
|
|
1431
|
+
logger.debug("Token swap validation failed: %s", e)
|
|
1432
|
+
return None
|
|
770
1433
|
|
|
771
1434
|
# -------------------------------------------------------------------------
|
|
772
1435
|
# Token Revocation
|
|
@@ -819,7 +1482,6 @@ class OAuthProxy(OAuthProvider):
|
|
|
819
1482
|
def get_routes(
|
|
820
1483
|
self,
|
|
821
1484
|
mcp_path: str | None = None,
|
|
822
|
-
mcp_endpoint: Any | None = None,
|
|
823
1485
|
) -> list[Route]:
|
|
824
1486
|
"""Get OAuth routes with custom proxy token handler.
|
|
825
1487
|
|
|
@@ -828,10 +1490,10 @@ class OAuthProxy(OAuthProvider):
|
|
|
828
1490
|
|
|
829
1491
|
Args:
|
|
830
1492
|
mcp_path: The path where the MCP endpoint is mounted (e.g., "/mcp")
|
|
831
|
-
|
|
1493
|
+
This is used to advertise the resource URL in metadata.
|
|
832
1494
|
"""
|
|
833
1495
|
# Get standard OAuth routes from parent class
|
|
834
|
-
routes = super().get_routes(mcp_path
|
|
1496
|
+
routes = super().get_routes(mcp_path)
|
|
835
1497
|
custom_routes = []
|
|
836
1498
|
token_route_found = False
|
|
837
1499
|
|
|
@@ -844,9 +1506,7 @@ class OAuthProxy(OAuthProvider):
|
|
|
844
1506
|
f"Route {i}: {route} - path: {getattr(route, 'path', 'N/A')}, methods: {getattr(route, 'methods', 'N/A')}"
|
|
845
1507
|
)
|
|
846
1508
|
|
|
847
|
-
#
|
|
848
|
-
custom_routes.append(route)
|
|
849
|
-
|
|
1509
|
+
# Replace the token endpoint with our custom handler that returns proper OAuth 2.1 error codes
|
|
850
1510
|
if (
|
|
851
1511
|
isinstance(route, Route)
|
|
852
1512
|
and route.path == "/token"
|
|
@@ -854,6 +1514,22 @@ class OAuthProxy(OAuthProvider):
|
|
|
854
1514
|
and "POST" in route.methods
|
|
855
1515
|
):
|
|
856
1516
|
token_route_found = True
|
|
1517
|
+
# Replace with our OAuth 2.1 compliant token handler
|
|
1518
|
+
token_handler = TokenHandler(
|
|
1519
|
+
provider=self, client_authenticator=ClientAuthenticator(self)
|
|
1520
|
+
)
|
|
1521
|
+
custom_routes.append(
|
|
1522
|
+
Route(
|
|
1523
|
+
path="/token",
|
|
1524
|
+
endpoint=cors_middleware(
|
|
1525
|
+
token_handler.handle, ["POST", "OPTIONS"]
|
|
1526
|
+
),
|
|
1527
|
+
methods=["POST", "OPTIONS"],
|
|
1528
|
+
)
|
|
1529
|
+
)
|
|
1530
|
+
else:
|
|
1531
|
+
# Keep all other standard OAuth routes unchanged
|
|
1532
|
+
custom_routes.append(route)
|
|
857
1533
|
|
|
858
1534
|
# Add OAuth callback endpoint for forwarding to client callbacks
|
|
859
1535
|
custom_routes.append(
|
|
@@ -864,8 +1540,18 @@ class OAuthProxy(OAuthProvider):
|
|
|
864
1540
|
)
|
|
865
1541
|
)
|
|
866
1542
|
|
|
1543
|
+
# Add consent endpoints
|
|
1544
|
+
custom_routes.append(
|
|
1545
|
+
Route(path="/consent", endpoint=self._show_consent_page, methods=["GET"])
|
|
1546
|
+
)
|
|
1547
|
+
custom_routes.append(
|
|
1548
|
+
Route(
|
|
1549
|
+
path="/consent/submit", endpoint=self._submit_consent, methods=["POST"]
|
|
1550
|
+
)
|
|
1551
|
+
)
|
|
1552
|
+
|
|
867
1553
|
logger.debug(
|
|
868
|
-
f"✅ OAuth routes configured: token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback)"
|
|
1554
|
+
f"✅ OAuth routes configured: token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback + consent)"
|
|
869
1555
|
)
|
|
870
1556
|
return custom_routes
|
|
871
1557
|
|
|
@@ -907,13 +1593,14 @@ class OAuthProxy(OAuthProvider):
|
|
|
907
1593
|
)
|
|
908
1594
|
|
|
909
1595
|
# Look up transaction data
|
|
910
|
-
|
|
911
|
-
if not
|
|
1596
|
+
transaction_model = await self._transaction_store.get(key=txn_id)
|
|
1597
|
+
if not transaction_model:
|
|
912
1598
|
logger.error("IdP callback with invalid transaction ID: %s", txn_id)
|
|
913
1599
|
return RedirectResponse(
|
|
914
1600
|
url="data:text/html,<h1>OAuth Error</h1><p>Invalid or expired transaction</p>",
|
|
915
1601
|
status_code=302,
|
|
916
1602
|
)
|
|
1603
|
+
transaction = transaction_model.model_dump()
|
|
917
1604
|
|
|
918
1605
|
# Exchange IdP code for tokens (server-side)
|
|
919
1606
|
oauth_client = AsyncOAuth2Client(
|
|
@@ -977,19 +1664,24 @@ class OAuthProxy(OAuthProvider):
|
|
|
977
1664
|
code_expires_at = int(time.time() + DEFAULT_AUTH_CODE_EXPIRY_SECONDS)
|
|
978
1665
|
|
|
979
1666
|
# Store client code with PKCE challenge and IdP tokens
|
|
980
|
-
self.
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
1667
|
+
await self._code_store.put(
|
|
1668
|
+
key=client_code,
|
|
1669
|
+
value=ClientCode(
|
|
1670
|
+
code=client_code,
|
|
1671
|
+
client_id=transaction["client_id"],
|
|
1672
|
+
redirect_uri=transaction["client_redirect_uri"],
|
|
1673
|
+
code_challenge=transaction["code_challenge"],
|
|
1674
|
+
code_challenge_method=transaction["code_challenge_method"],
|
|
1675
|
+
scopes=transaction["scopes"],
|
|
1676
|
+
idp_tokens=idp_tokens,
|
|
1677
|
+
expires_at=code_expires_at,
|
|
1678
|
+
created_at=time.time(),
|
|
1679
|
+
),
|
|
1680
|
+
ttl=DEFAULT_AUTH_CODE_EXPIRY_SECONDS, # Auto-expire after 5 minutes
|
|
1681
|
+
)
|
|
990
1682
|
|
|
991
1683
|
# Clean up transaction
|
|
992
|
-
self.
|
|
1684
|
+
await self._transaction_store.delete(key=txn_id)
|
|
993
1685
|
|
|
994
1686
|
# Build client callback URL with our code and original state
|
|
995
1687
|
client_redirect_uri = transaction["client_redirect_uri"]
|
|
@@ -1016,3 +1708,315 @@ class OAuthProxy(OAuthProvider):
|
|
|
1016
1708
|
url="data:text/html,<h1>OAuth Error</h1><p>Internal server error during IdP callback</p>",
|
|
1017
1709
|
status_code=302,
|
|
1018
1710
|
)
|
|
1711
|
+
|
|
1712
|
+
# -------------------------------------------------------------------------
|
|
1713
|
+
# Consent Interstitial
|
|
1714
|
+
# -------------------------------------------------------------------------
|
|
1715
|
+
|
|
1716
|
+
def _normalize_uri(self, uri: str) -> str:
    """Canonicalize a URI for consent bookkeeping.

    The scheme and network location are lowercased and the path is kept
    as-is, except that a single trailing slash is dropped from any path
    longer than "/". Query, params and fragment components are not part of
    the result.
    """
    parts = urlparse(uri)
    scheme = parts.scheme.lower()
    netloc = parts.netloc.lower()
    path = parts.path or ""
    # A bare "/" path is kept; only longer paths lose their trailing slash.
    if len(path) > 1 and path.endswith("/"):
        path = path[:-1]
    return f"{scheme}://{netloc}{path}"
|
|
1724
|
+
|
|
1725
|
+
def _make_client_key(self, client_id: str, redirect_uri: str | AnyUrl) -> str:
    """Build the consent-tracking key for a (client_id, redirect_uri) pair.

    The redirect URI is stringified and canonicalized with
    ``_normalize_uri``, so cosmetic differences such as host case or a
    trailing slash map to the same key.

    Returns:
        ``"<client_id>:<normalized redirect uri>"``
    """
    canonical_uri = self._normalize_uri(str(redirect_uri))
    key = f"{client_id}:{canonical_uri}"
    return key
|
|
1729
|
+
|
|
1730
|
+
def _cookie_name(self, base_name: str) -> str:
    """Return the cookie name to use for ``base_name``.

    Over HTTPS the name gets the ``__Host-`` prefix. Over plain HTTP, which
    is intended for local development, a ``__`` prefix is used instead.
    """
    prefix = "__Host-" if self._is_https else "__"
    return f"{prefix}{base_name}"
|
|
1735
|
+
|
|
1736
|
+
def _sign_cookie(self, payload: str) -> str:
    """Append an HMAC-SHA256 signature to ``payload``.

    The signing key is the upstream OAuth client secret. ``payload`` is
    embedded verbatim; it is not re-encoded here. In this module
    ``_encode_list_cookie`` passes an already base64-encoded string.

    Returns:
        ``"<payload>.<base64(signature)>"``
    """
    secret_bytes = self._upstream_client_secret.get_secret_value().encode()
    mac = hmac.new(secret_bytes, payload.encode(), hashlib.sha256)
    encoded_mac = base64.b64encode(mac.digest()).decode()
    return "{}.{}".format(payload, encoded_mac)
|
|
1746
|
+
|
|
1747
|
+
def _verify_cookie(self, signed_value: str) -> str | None:
|
|
1748
|
+
"""Verify and extract payload from signed cookie.
|
|
1749
|
+
|
|
1750
|
+
Returns: payload if signature valid, None otherwise
|
|
1751
|
+
"""
|
|
1752
|
+
try:
|
|
1753
|
+
if "." not in signed_value:
|
|
1754
|
+
return None
|
|
1755
|
+
payload, signature_b64 = signed_value.rsplit(".", 1)
|
|
1756
|
+
|
|
1757
|
+
# Verify signature
|
|
1758
|
+
key = self._upstream_client_secret.get_secret_value().encode()
|
|
1759
|
+
expected_sig = hmac.new(key, payload.encode(), hashlib.sha256).digest()
|
|
1760
|
+
provided_sig = base64.b64decode(signature_b64.encode())
|
|
1761
|
+
|
|
1762
|
+
# Constant-time comparison
|
|
1763
|
+
if not hmac.compare_digest(expected_sig, provided_sig):
|
|
1764
|
+
return None
|
|
1765
|
+
|
|
1766
|
+
return payload
|
|
1767
|
+
except Exception:
|
|
1768
|
+
return None
|
|
1769
|
+
|
|
1770
|
+
def _decode_list_cookie(self, request: Request, base_name: str) -> list[str]:
    """Read a signed, base64-encoded JSON list cookie.

    The cookie is looked up under the name from ``_cookie_name``, with a
    fallback to the plain ``__<base_name>`` variant used in HTTP
    development. The value must carry a valid signature (checked by
    ``_verify_cookie``) and decode to a JSON array.

    Returns:
        The array items converted to ``str``. Returns ``[]`` when the cookie
        is missing, fails signature verification, cannot be decoded, or does
        not contain a JSON array.
    """
    secure_name = self._cookie_name(base_name)
    raw = request.cookies.get(secure_name) or request.cookies.get(f"__{base_name}")
    if not raw:
        return []

    payload = self._verify_cookie(raw)
    if not payload:
        logger.debug("Cookie signature verification failed for %s", secure_name)
        return []

    try:
        value = json.loads(base64.b64decode(payload.encode()).decode())
    except (ValueError, TypeError):
        # binascii.Error, UnicodeDecodeError and JSONDecodeError are all
        # ValueError subclasses.
        logger.debug("Failed to decode cookie %s; treating as empty", secure_name)
        return []

    if not isinstance(value, list):
        # Previously this path fell through and implicitly returned None,
        # which broke callers that wrap the result in set(...).
        return []
    return [str(item) for item in value]
|
|
1792
|
+
|
|
1793
|
+
def _encode_list_cookie(self, values: list[str]) -> str:
    """Serialize ``values`` for storage in a signed cookie.

    The list is dumped as compact JSON (no whitespace), base64-encoded, and
    signed with ``_sign_cookie``.

    Returns:
        The signed value, ``"<base64 payload>.<signature>"``.
    """
    compact_json = json.dumps(values, separators=(",", ":"))
    encoded_payload = base64.b64encode(compact_json.encode()).decode()
    return self._sign_cookie(encoded_payload)
|
|
1801
|
+
|
|
1802
|
+
def _set_list_cookie(
    self,
    response: HTMLResponse | RedirectResponse,
    base_name: str,
    value_b64: str,
    max_age: int,
) -> None:
    """Attach a consent-tracking cookie to ``response``.

    ``value_b64`` is written unchanged. Callers in this module pass the
    signed output of ``_encode_list_cookie``. The cookie is HttpOnly,
    SameSite=Lax and scoped to ``/``, and it is marked Secure only when the
    server runs over HTTPS.
    """
    cookie_options = {
        "max_age": max_age,
        "secure": self._is_https,
        "httponly": True,
        "samesite": "lax",
        "path": "/",
    }
    response.set_cookie(self._cookie_name(base_name), value_b64, **cookie_options)
|
|
1819
|
+
|
|
1820
|
+
def _build_upstream_authorize_url(
    self, txn_id: str, transaction: dict[str, Any]
) -> str:
    """Build the upstream IdP authorization URL for a stored transaction.

    Parameters, in the order they are encoded:

    * ``response_type=code``, the upstream ``client_id``, and this proxy's
      callback (``base_url`` + ``_redirect_path``) as ``redirect_uri``.
    * ``state``, set to the transaction ID.
    * ``scope``, space-joined from the transaction's scopes, falling back
      to ``required_scopes``. It is omitted when neither is set.
    * An S256 PKCE ``code_challenge``, when the transaction holds a
      ``proxy_code_verifier``.
    * ``resource``, when the transaction holds one.
    * ``_extra_authorize_params``, merged last, which may override any key
      above.
    """
    callback_url = f"{str(self.base_url).rstrip('/')}{self._redirect_path}"
    params: dict[str, Any] = dict(
        response_type="code",
        client_id=self._upstream_client_id,
        redirect_uri=callback_url,
        state=txn_id,
    )

    scopes = transaction.get("scopes") or self.required_scopes or []
    if scopes:
        params["scope"] = " ".join(scopes)

    verifier = transaction.get("proxy_code_verifier")
    if verifier:
        digest = hashlib.sha256(verifier.encode()).digest()
        # RFC 7636 S256: unpadded base64url of SHA-256(verifier).
        params["code_challenge"] = urlsafe_b64encode(digest).decode().rstrip("=")
        params["code_challenge_method"] = "S256"

    resource = transaction.get("resource")
    if resource:
        params["resource"] = resource

    extra_params = self._extra_authorize_params
    if extra_params:
        params.update(extra_params)

    endpoint = self._upstream_authorization_endpoint
    joiner = "?" if "?" not in endpoint else "&"
    return f"{endpoint}{joiner}{urlencode(params)}"
|
|
1855
|
+
|
|
1856
|
+
async def _show_consent_page(
    self, request: Request
) -> HTMLResponse | RedirectResponse:
    """Render the consent interstitial for ``GET /consent?txn_id=...``.

    * A missing or unknown transaction gets a 400 error page.
    * A client key found in the signed ``MCP_APPROVED_CLIENTS`` cookie is
      redirected straight to the upstream IdP.
    * A client key found in the signed ``MCP_DENIED_CLIENTS`` cookie is
      redirected to the client's redirect URI with ``error=access_denied``.
    * Otherwise a new CSRF token, valid for 15 minutes, is saved on the
      transaction and in the ``MCP_CONSENT_STATE`` cookie, and the consent
      HTML is returned.
    """
    from fastmcp.server.server import FastMCP

    invalid_txn_html = "<h1>Error</h1><p>Invalid or expired transaction</p>"

    txn_id = request.query_params.get("txn_id")
    # The store is only consulted when a txn_id was actually supplied.
    txn_model = await self._transaction_store.get(key=txn_id) if txn_id else None
    if not txn_model:
        return create_secure_html_response(invalid_txn_html, status_code=400)

    txn = txn_model.model_dump()
    client_id = txn["client_id"]
    client_redirect_uri = txn["client_redirect_uri"]
    client_key = self._make_client_key(client_id, client_redirect_uri)

    approved_keys = set(self._decode_list_cookie(request, "MCP_APPROVED_CLIENTS"))
    denied_keys = set(self._decode_list_cookie(request, "MCP_DENIED_CLIENTS"))

    if client_key in approved_keys:
        return RedirectResponse(
            url=self._build_upstream_authorize_url(txn_id, txn), status_code=302
        )

    if client_key in denied_keys:
        denial_query = urlencode(
            {"error": "access_denied", "state": txn.get("client_state") or ""}
        )
        joiner = "&" if "?" in client_redirect_uri else "?"
        return RedirectResponse(
            url=f"{client_redirect_uri}{joiner}{denial_query}", status_code=302
        )

    # Consent required: mint a CSRF token and persist it on the transaction.
    ttl_seconds = 15 * 60
    csrf_token = secrets.token_urlsafe(32)
    csrf_expires_at = time.time() + ttl_seconds

    txn_model.csrf_token = csrf_token
    txn_model.csrf_expires_at = csrf_expires_at
    await self._transaction_store.put(key=txn_id, value=txn_model, ttl=ttl_seconds)
    txn.update(csrf_token=csrf_token, csrf_expires_at=csrf_expires_at)

    client = await self.get_client(client_id)
    client_name = getattr(client, "client_name", None) if client else None

    # Optional branding pulled from the FastMCP server attached to the app.
    server = getattr(request.app.state, "fastmcp_server", None)
    server_name = server_icon_url = server_website_url = None
    if isinstance(server, FastMCP):
        server_name = server.name
        icons = server.icons
        server_icon_url = icons[0].src if icons else None
        server_website_url = server.website_url

    html = create_consent_html(
        client_id=client_id,
        redirect_uri=client_redirect_uri,
        scopes=txn.get("scopes") or [],
        txn_id=txn_id,
        csrf_token=csrf_token,
        client_name=client_name,
        server_name=server_name,
        server_icon_url=server_icon_url,
        server_website_url=server_website_url,
    )
    response = create_secure_html_response(html)
    # The same token also goes into a short-lived signed cookie.
    self._set_list_cookie(
        response,
        "MCP_CONSENT_STATE",
        self._encode_list_cookie([csrf_token]),
        max_age=ttl_seconds,
    )
    return response
|
|
1947
|
+
|
|
1948
|
+
async def _submit_consent(
    self, request: Request
) -> RedirectResponse | HTMLResponse:
    """Handle the consent form POST, set cookies, and redirect.

    The submitted ``csrf_token`` is accepted only when all of these hold:

    (a) It equals the token stored on the transaction by
        ``_show_consent_page``, compared in constant time.
    (b) That token has not expired.
    (c) It is present in the signed ``MCP_CONSENT_STATE`` cookie that the
        consent page set in this browser.

    Check (c) binds the submission to the browser that rendered the consent
    page. The cookie is SameSite=Lax, so browsers do not send it on
    cross-site POSTs. A forged cross-site form that merely knows
    ``txn_id``/``csrf_token`` is therefore rejected.

    On ``approve``, the client key is added to the signed approved-clients
    cookie and the browser is redirected to the upstream IdP. On ``deny``,
    the key is added to the signed denied-clients cookie and the browser is
    redirected to the client's redirect URI with ``error=access_denied``.
    """
    form = await request.form()
    txn_id = str(form.get("txn_id", ""))
    action = str(form.get("action", ""))
    csrf_token = str(form.get("csrf_token", ""))

    if not txn_id:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
        )

    txn_model = await self._transaction_store.get(key=txn_id)
    if not txn_model:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
        )

    txn = txn_model.model_dump()
    expected_csrf = txn.get("csrf_token")
    expires_at = float(txn.get("csrf_expires_at") or 0)

    # Constant-time comparison, so response timing does not reveal a partial
    # match. Bytes are compared because compare_digest rejects non-ASCII str
    # arguments, and the form value is untrusted.
    submitted = csrf_token.encode()
    token_matches = bool(expected_csrf) and hmac.compare_digest(
        submitted, str(expected_csrf).encode()
    )

    # Double-submit check against the cookie set by the consent page.
    browser_tokens = self._decode_list_cookie(request, "MCP_CONSENT_STATE") or []
    token_bound = any(
        hmac.compare_digest(submitted, str(token).encode()) for token in browser_tokens
    )

    if not token_matches or not token_bound or time.time() > expires_at:
        if token_matches and not token_bound:
            logger.debug(
                "Consent submission rejected: CSRF token not present in "
                "consent-state cookie"
            )
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired consent token</p>", status_code=400
        )

    client_key = self._make_client_key(txn["client_id"], txn["client_redirect_uri"])

    if action == "approve":
        approved = set(self._decode_list_cookie(request, "MCP_APPROVED_CLIENTS"))
        approved.add(client_key)
        approved_b64 = self._encode_list_cookie(sorted(approved))

        upstream_url = self._build_upstream_authorize_url(txn_id, txn)
        response = RedirectResponse(url=upstream_url, status_code=302)
        self._set_list_cookie(
            response, "MCP_APPROVED_CLIENTS", approved_b64, max_age=365 * 24 * 3600
        )
        # Clear the CSRF cookie by overwriting it with a short-lived empty list.
        self._set_list_cookie(
            response, "MCP_CONSENT_STATE", self._encode_list_cookie([]), max_age=60
        )
        return response

    elif action == "deny":
        denied = set(self._decode_list_cookie(request, "MCP_DENIED_CLIENTS"))
        denied.add(client_key)
        denied_b64 = self._encode_list_cookie(sorted(denied))

        callback_params = {
            "error": "access_denied",
            "state": txn.get("client_state") or "",
        }
        sep = "&" if "?" in txn["client_redirect_uri"] else "?"
        client_callback_url = (
            f"{txn['client_redirect_uri']}{sep}{urlencode(callback_params)}"
        )
        response = RedirectResponse(url=client_callback_url, status_code=302)
        self._set_list_cookie(
            response, "MCP_DENIED_CLIENTS", denied_b64, max_age=365 * 24 * 3600
        )
        self._set_list_cookie(
            response, "MCP_CONSENT_STATE", self._encode_list_cookie([]), max_age=60
        )
        return response

    else:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid action</p>", status_code=400
        )
|