fastmcp 2.12.1__py3-none-any.whl → 2.13.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastmcp/__init__.py +2 -2
- fastmcp/cli/cli.py +56 -36
- fastmcp/cli/install/__init__.py +2 -0
- fastmcp/cli/install/claude_code.py +7 -16
- fastmcp/cli/install/claude_desktop.py +4 -12
- fastmcp/cli/install/cursor.py +20 -30
- fastmcp/cli/install/gemini_cli.py +241 -0
- fastmcp/cli/install/mcp_json.py +4 -12
- fastmcp/cli/run.py +15 -94
- fastmcp/client/__init__.py +9 -9
- fastmcp/client/auth/oauth.py +117 -206
- fastmcp/client/client.py +123 -47
- fastmcp/client/elicitation.py +6 -1
- fastmcp/client/logging.py +18 -14
- fastmcp/client/oauth_callback.py +85 -171
- fastmcp/client/sampling.py +1 -1
- fastmcp/client/transports.py +81 -26
- fastmcp/contrib/component_manager/__init__.py +1 -1
- fastmcp/contrib/component_manager/component_manager.py +2 -2
- fastmcp/contrib/component_manager/component_service.py +7 -7
- fastmcp/contrib/mcp_mixin/README.md +35 -4
- fastmcp/contrib/mcp_mixin/__init__.py +2 -2
- fastmcp/contrib/mcp_mixin/mcp_mixin.py +54 -7
- fastmcp/experimental/sampling/handlers/openai.py +2 -2
- fastmcp/experimental/server/openapi/__init__.py +5 -8
- fastmcp/experimental/server/openapi/components.py +11 -7
- fastmcp/experimental/server/openapi/routing.py +2 -2
- fastmcp/experimental/utilities/openapi/__init__.py +10 -15
- fastmcp/experimental/utilities/openapi/director.py +16 -10
- fastmcp/experimental/utilities/openapi/json_schema_converter.py +6 -2
- fastmcp/experimental/utilities/openapi/models.py +3 -3
- fastmcp/experimental/utilities/openapi/parser.py +37 -16
- fastmcp/experimental/utilities/openapi/schemas.py +33 -7
- fastmcp/mcp_config.py +3 -4
- fastmcp/prompts/__init__.py +1 -1
- fastmcp/prompts/prompt.py +32 -27
- fastmcp/prompts/prompt_manager.py +16 -101
- fastmcp/resources/__init__.py +5 -5
- fastmcp/resources/resource.py +28 -20
- fastmcp/resources/resource_manager.py +9 -168
- fastmcp/resources/template.py +119 -27
- fastmcp/resources/types.py +30 -24
- fastmcp/server/__init__.py +1 -1
- fastmcp/server/auth/__init__.py +9 -5
- fastmcp/server/auth/auth.py +80 -47
- fastmcp/server/auth/handlers/authorize.py +326 -0
- fastmcp/server/auth/jwt_issuer.py +236 -0
- fastmcp/server/auth/middleware.py +96 -0
- fastmcp/server/auth/oauth_proxy.py +1556 -265
- fastmcp/server/auth/oidc_proxy.py +412 -0
- fastmcp/server/auth/providers/auth0.py +193 -0
- fastmcp/server/auth/providers/aws.py +263 -0
- fastmcp/server/auth/providers/azure.py +314 -129
- fastmcp/server/auth/providers/bearer.py +1 -1
- fastmcp/server/auth/providers/debug.py +114 -0
- fastmcp/server/auth/providers/descope.py +229 -0
- fastmcp/server/auth/providers/discord.py +308 -0
- fastmcp/server/auth/providers/github.py +31 -6
- fastmcp/server/auth/providers/google.py +50 -7
- fastmcp/server/auth/providers/in_memory.py +27 -3
- fastmcp/server/auth/providers/introspection.py +281 -0
- fastmcp/server/auth/providers/jwt.py +48 -31
- fastmcp/server/auth/providers/oci.py +233 -0
- fastmcp/server/auth/providers/scalekit.py +238 -0
- fastmcp/server/auth/providers/supabase.py +188 -0
- fastmcp/server/auth/providers/workos.py +37 -15
- fastmcp/server/context.py +194 -67
- fastmcp/server/dependencies.py +56 -16
- fastmcp/server/elicitation.py +1 -1
- fastmcp/server/http.py +57 -18
- fastmcp/server/low_level.py +121 -2
- fastmcp/server/middleware/__init__.py +1 -1
- fastmcp/server/middleware/caching.py +476 -0
- fastmcp/server/middleware/error_handling.py +14 -10
- fastmcp/server/middleware/logging.py +158 -116
- fastmcp/server/middleware/middleware.py +30 -16
- fastmcp/server/middleware/rate_limiting.py +3 -3
- fastmcp/server/middleware/tool_injection.py +116 -0
- fastmcp/server/openapi.py +15 -7
- fastmcp/server/proxy.py +22 -11
- fastmcp/server/server.py +744 -254
- fastmcp/settings.py +65 -15
- fastmcp/tools/__init__.py +1 -1
- fastmcp/tools/tool.py +173 -108
- fastmcp/tools/tool_manager.py +30 -112
- fastmcp/tools/tool_transform.py +13 -11
- fastmcp/utilities/cli.py +67 -28
- fastmcp/utilities/components.py +7 -2
- fastmcp/utilities/inspect.py +79 -23
- fastmcp/utilities/json_schema.py +21 -4
- fastmcp/utilities/json_schema_type.py +4 -4
- fastmcp/utilities/logging.py +182 -10
- fastmcp/utilities/mcp_server_config/__init__.py +3 -3
- fastmcp/utilities/mcp_server_config/v1/environments/base.py +1 -2
- fastmcp/utilities/mcp_server_config/v1/environments/uv.py +10 -45
- fastmcp/utilities/mcp_server_config/v1/mcp_server_config.py +8 -7
- fastmcp/utilities/mcp_server_config/v1/schema.json +5 -1
- fastmcp/utilities/mcp_server_config/v1/sources/base.py +0 -1
- fastmcp/utilities/openapi.py +11 -11
- fastmcp/utilities/tests.py +93 -10
- fastmcp/utilities/types.py +87 -21
- fastmcp/utilities/ui.py +626 -0
- {fastmcp-2.12.1.dist-info → fastmcp-2.13.2.dist-info}/METADATA +141 -60
- fastmcp-2.13.2.dist-info/RECORD +144 -0
- {fastmcp-2.12.1.dist-info → fastmcp-2.13.2.dist-info}/WHEEL +1 -1
- fastmcp/cli/claude.py +0 -144
- fastmcp-2.12.1.dist-info/RECORD +0 -128
- {fastmcp-2.12.1.dist-info → fastmcp-2.13.2.dist-info}/entry_points.txt +0 -0
- {fastmcp-2.12.1.dist-info → fastmcp-2.13.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -18,36 +18,70 @@ production use with enterprise identity providers.
|
|
|
18
18
|
|
|
19
19
|
from __future__ import annotations
|
|
20
20
|
|
|
21
|
+
import base64
|
|
21
22
|
import hashlib
|
|
23
|
+
import hmac
|
|
24
|
+
import json
|
|
22
25
|
import secrets
|
|
23
26
|
import time
|
|
24
27
|
from base64 import urlsafe_b64encode
|
|
25
28
|
from typing import TYPE_CHECKING, Any, Final
|
|
26
|
-
from urllib.parse import urlencode
|
|
29
|
+
from urllib.parse import urlencode, urlparse
|
|
27
30
|
|
|
28
31
|
import httpx
|
|
29
32
|
from authlib.common.security import generate_token
|
|
30
33
|
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
34
|
+
from cryptography.fernet import Fernet
|
|
35
|
+
from key_value.aio.adapters.pydantic import PydanticAdapter
|
|
36
|
+
from key_value.aio.protocols import AsyncKeyValue
|
|
37
|
+
from key_value.aio.stores.disk import DiskStore
|
|
38
|
+
from key_value.aio.wrappers.encryption import FernetEncryptionWrapper
|
|
39
|
+
from mcp.server.auth.handlers.token import TokenErrorResponse, TokenSuccessResponse
|
|
40
|
+
from mcp.server.auth.handlers.token import TokenHandler as _SDKTokenHandler
|
|
41
|
+
from mcp.server.auth.json_response import PydanticJSONResponse
|
|
42
|
+
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
|
|
31
43
|
from mcp.server.auth.provider import (
|
|
32
44
|
AccessToken,
|
|
33
45
|
AuthorizationCode,
|
|
34
46
|
AuthorizationParams,
|
|
47
|
+
AuthorizeError,
|
|
35
48
|
RefreshToken,
|
|
36
49
|
TokenError,
|
|
37
50
|
)
|
|
51
|
+
from mcp.server.auth.routes import cors_middleware
|
|
38
52
|
from mcp.server.auth.settings import (
|
|
39
53
|
ClientRegistrationOptions,
|
|
40
54
|
RevocationOptions,
|
|
41
55
|
)
|
|
42
56
|
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
|
43
|
-
from pydantic import AnyHttpUrl, AnyUrl, SecretStr
|
|
57
|
+
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, SecretStr
|
|
44
58
|
from starlette.requests import Request
|
|
45
|
-
from starlette.responses import RedirectResponse
|
|
59
|
+
from starlette.responses import HTMLResponse, RedirectResponse
|
|
46
60
|
from starlette.routing import Route
|
|
61
|
+
from typing_extensions import override
|
|
47
62
|
|
|
63
|
+
from fastmcp import settings
|
|
48
64
|
from fastmcp.server.auth.auth import OAuthProvider, TokenVerifier
|
|
49
|
-
from fastmcp.server.auth.
|
|
65
|
+
from fastmcp.server.auth.handlers.authorize import AuthorizationHandler
|
|
66
|
+
from fastmcp.server.auth.jwt_issuer import (
|
|
67
|
+
JWTIssuer,
|
|
68
|
+
derive_jwt_key,
|
|
69
|
+
)
|
|
70
|
+
from fastmcp.server.auth.redirect_validation import (
|
|
71
|
+
validate_redirect_uri,
|
|
72
|
+
)
|
|
50
73
|
from fastmcp.utilities.logging import get_logger
|
|
74
|
+
from fastmcp.utilities.ui import (
|
|
75
|
+
BUTTON_STYLES,
|
|
76
|
+
DETAIL_BOX_STYLES,
|
|
77
|
+
DETAILS_STYLES,
|
|
78
|
+
INFO_BOX_STYLES,
|
|
79
|
+
REDIRECT_SECTION_STYLES,
|
|
80
|
+
TOOLTIP_STYLES,
|
|
81
|
+
create_logo,
|
|
82
|
+
create_page,
|
|
83
|
+
create_secure_html_response,
|
|
84
|
+
)
|
|
51
85
|
|
|
52
86
|
if TYPE_CHECKING:
|
|
53
87
|
pass
|
|
@@ -55,6 +89,118 @@ if TYPE_CHECKING:
|
|
|
55
89
|
logger = get_logger(__name__)
|
|
56
90
|
|
|
57
91
|
|
|
92
|
+
# -------------------------------------------------------------------------
|
|
93
|
+
# Constants
|
|
94
|
+
# -------------------------------------------------------------------------
|
|
95
|
+
|
|
96
|
+
# Default token expiration times
|
|
97
|
+
DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS: Final[int] = 60 * 60 # 1 hour
|
|
98
|
+
DEFAULT_AUTH_CODE_EXPIRY_SECONDS: Final[int] = 5 * 60 # 5 minutes
|
|
99
|
+
|
|
100
|
+
# HTTP client timeout
|
|
101
|
+
HTTP_TIMEOUT_SECONDS: Final[int] = 30
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# -------------------------------------------------------------------------
|
|
105
|
+
# Pydantic Models
|
|
106
|
+
# -------------------------------------------------------------------------
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class OAuthTransaction(BaseModel):
    """Server-side record of one in-flight authorization (consent) flow.

    Holds the requesting client's context for the duration of the flow,
    together with the optional CSRF token (and its expiry) used to protect
    the consent form.
    """

    # Identity of the transaction and of the requesting client
    txn_id: str
    client_id: str
    client_redirect_uri: str
    client_state: str

    # PKCE parameters supplied by the client
    code_challenge: str | None
    code_challenge_method: str

    scopes: list[str]
    created_at: float

    # Optional values; all default to None
    resource: str | None = None
    proxy_code_verifier: str | None = None
    csrf_token: str | None = None
    csrf_expires_at: float | None = None
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class ClientCode(BaseModel):
    """Authorization code issued to a client, with PKCE data and IdP tokens.

    Created server-side once the upstream IdP callback has completed. The
    upstream tokens are kept alongside the client's PKCE challenge so they
    can be released during the client's token exchange.
    """

    code: str
    client_id: str
    redirect_uri: str

    # PKCE parameters captured from the client's authorization request
    code_challenge: str | None
    code_challenge_method: str

    scopes: list[str]

    # Token payload obtained from the upstream identity provider
    idp_tokens: dict[str, Any]

    expires_at: float
    created_at: float
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class UpstreamTokenSet(BaseModel):
    """Upstream OAuth tokens obtained from the identity provider.

    The tokens come from the upstream provider (Google, GitHub, etc.) and are
    held in plaintext inside this model; encryption is applied transparently
    by the storage layer (FernetEncryptionWrapper). They are never exposed to
    MCP clients.
    """

    upstream_token_id: str  # Unique ID for this token set
    access_token: str  # Upstream access token
    refresh_token: str | None  # Upstream refresh token
    # Unix timestamp when the refresh token expires (if known)
    refresh_token_expires_at: float | None
    expires_at: float  # Unix timestamp when the access token expires
    token_type: str  # Usually "Bearer"
    scope: str  # Space-separated scopes
    client_id: str  # MCP client this token set is bound to
    created_at: float  # Unix timestamp
    # Full token response
    raw_token_data: dict[str, Any] = Field(default_factory=dict)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class JTIMapping(BaseModel):
    """Link between a FastMCP-issued token's JTI and an upstream token ID.

    JWT validation can stay stateless while the matching upstream token can
    still be looked up when a tool needs to call upstream APIs.
    """

    # JWT ID taken from the FastMCP-issued token
    jti: str
    # Identifier of the associated UpstreamTokenSet
    upstream_token_id: str
    # Unix timestamp
    created_at: float
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class RefreshTokenMetadata(BaseModel):
    """Refresh-token metadata, stored under the hash of the token.

    Only metadata is persisted, never the token itself: if storage is
    compromised, an attacker obtains hashes that cannot be reversed into
    usable tokens.
    """

    client_id: str
    scopes: list[str]
    # Optional expiry; defaults to None
    expires_at: int | None = None
    created_at: float
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _hash_token(token: str) -> str:
|
|
196
|
+
"""Hash a token for secure storage lookup.
|
|
197
|
+
|
|
198
|
+
Uses SHA-256 to create a one-way hash. The original token cannot be
|
|
199
|
+
recovered from the hash, providing defense in depth if storage is compromised.
|
|
200
|
+
"""
|
|
201
|
+
return hashlib.sha256(token.encode()).hexdigest()
|
|
202
|
+
|
|
203
|
+
|
|
58
204
|
class ProxyDCRClient(OAuthClientInformationFull):
|
|
59
205
|
"""Client for DCR proxy with configurable redirect URI validation.
|
|
60
206
|
|
|
@@ -81,18 +227,8 @@ class ProxyDCRClient(OAuthClientInformationFull):
|
|
|
81
227
|
arise from accepting arbitrary redirect URIs.
|
|
82
228
|
"""
|
|
83
229
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
):
|
|
87
|
-
"""Initialize with allowed redirect URI patterns.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
allowed_redirect_uri_patterns: List of allowed redirect URI patterns with wildcard support.
|
|
91
|
-
If None, defaults to localhost-only patterns.
|
|
92
|
-
If empty list, allows all redirect URIs.
|
|
93
|
-
"""
|
|
94
|
-
super().__init__(*args, **kwargs)
|
|
95
|
-
self._allowed_redirect_uri_patterns = allowed_redirect_uri_patterns
|
|
230
|
+
allowed_redirect_uri_patterns: list[str] | None = Field(default=None)
|
|
231
|
+
client_name: str | None = Field(default=None)
|
|
96
232
|
|
|
97
233
|
def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
|
|
98
234
|
"""Validate redirect URI against allowed patterns.
|
|
@@ -104,7 +240,10 @@ class ProxyDCRClient(OAuthClientInformationFull):
|
|
|
104
240
|
"""
|
|
105
241
|
if redirect_uri is not None:
|
|
106
242
|
# Validate against allowed patterns
|
|
107
|
-
if validate_redirect_uri(
|
|
243
|
+
if validate_redirect_uri(
|
|
244
|
+
redirect_uri=redirect_uri,
|
|
245
|
+
allowed_patterns=self.allowed_redirect_uri_patterns,
|
|
246
|
+
):
|
|
108
247
|
return redirect_uri
|
|
109
248
|
# Fall back to normal validation if not in allowed patterns
|
|
110
249
|
return super().validate_redirect_uri(redirect_uri)
|
|
@@ -112,12 +251,321 @@ class ProxyDCRClient(OAuthClientInformationFull):
|
|
|
112
251
|
return super().validate_redirect_uri(redirect_uri)
|
|
113
252
|
|
|
114
253
|
|
|
115
|
-
#
|
|
116
|
-
|
|
117
|
-
|
|
254
|
+
# -------------------------------------------------------------------------
|
|
255
|
+
# Helper Functions
|
|
256
|
+
# -------------------------------------------------------------------------
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def create_consent_html(
    client_id: str,
    redirect_uri: str,
    scopes: list[str],
    txn_id: str,
    csrf_token: str,
    client_name: str | None = None,
    title: str = "Application Access Request",
    server_name: str | None = None,
    server_icon_url: str | None = None,
    server_website_url: str | None = None,
    client_website_url: str | None = None,
    csp_policy: str | None = None,
) -> str:
    """Create a styled HTML consent page for OAuth authorization requests.

    Every caller-supplied value that is interpolated into the markup built
    here (client name/ID, URLs, scopes, transaction ID, CSRF token) is
    HTML-escaped first, so none of them can break out of its element or
    attribute.

    Args:
        client_id: Identifier of the requesting client. Shown as the
            "Application ID" and used as the display name when
            ``client_name`` is not given.
        redirect_uri: Client callback address shown to the user. Its scheme
            is also added to the default CSP ``form-action`` directive when
            it is not http/https.
        scopes: Requested scopes, listed under "Requested Scopes"
            ("None" when empty).
        txn_id: Transaction identifier, emitted as a hidden form field.
        csrf_token: CSRF token, emitted as a hidden form field.
        client_name: Optional human-readable client name.
        title: Page title passed to ``create_page``.
        server_name: Optional server name; falls back to "FastMCP".
        server_icon_url: Optional icon URL passed to ``create_logo``.
        server_website_url: Optional URL; when set, the server name is
            rendered as a link to it.
        client_website_url: Optional client website, shown in the details
            ("N/A" when missing).
        csp_policy: Content Security Policy override.
            If None, uses the built-in CSP policy with appropriate directives.
            If empty string "", disables CSP entirely (no meta tag is rendered).
            If a non-empty string, uses that as the CSP policy value.

    Returns:
        The complete HTML page as a string.
    """
    import html as html_module

    client_display = html_module.escape(client_name or client_id)
    server_name_escaped = html_module.escape(server_name or "FastMCP")

    # Make server name a hyperlink if website URL is available
    if server_website_url:
        website_url_escaped = html_module.escape(server_website_url)
        server_display = f'<a href="{website_url_escaped}" target="_blank" rel="noopener noreferrer" class="server-name-link">{server_name_escaped}</a>'
    else:
        server_display = server_name_escaped

    # Build intro box with call-to-action
    intro_box = f"""
        <div class="info-box">
            <p>The application <strong>{client_display}</strong> wants to access the MCP server <strong>{server_display}</strong>. Please ensure you recognize the callback address below.</p>
        </div>
    """

    # Build redirect URI section (yellow box, centered)
    redirect_uri_escaped = html_module.escape(redirect_uri)
    redirect_section = f"""
        <div class="redirect-section">
            <span class="label">Credentials will be sent to:</span>
            <div class="value">{redirect_uri_escaped}</div>
        </div>
    """

    # Build advanced details with collapsible section.
    # Every value must already be escaped here: the rows are interpolated
    # into the markup verbatim below.
    detail_rows = [
        ("Application Name", client_display),
        ("Application Website", html_module.escape(client_website_url or "N/A")),
        ("Application ID", html_module.escape(client_id)),
        ("Redirect URI", redirect_uri_escaped),
        (
            "Requested Scopes",
            ", ".join(html_module.escape(s) for s in scopes) if scopes else "None",
        ),
    ]

    detail_rows_html = "\n".join(
        [
            f"""
                <div class="detail-row">
                    <div class="detail-label">{label}:</div>
                    <div class="detail-value">{value}</div>
                </div>
            """
            for label, value in detail_rows
        ]
    )

    advanced_details = f"""
        <details>
            <summary>Advanced Details</summary>
            <div class="detail-box">
                {detail_rows_html}
            </div>
        </details>
    """

    # Build form with buttons
    # Use empty action to submit to current URL (/consent or /mcp/consent)
    # The POST handler is registered at the same path as GET
    # html.escape() also escapes quotes, so the hidden values cannot break
    # out of their value="..." attributes.
    txn_id_escaped = html_module.escape(txn_id)
    csrf_token_escaped = html_module.escape(csrf_token)
    form = f"""
        <form id="consentForm" method="POST" action="">
            <input type="hidden" name="txn_id" value="{txn_id_escaped}" />
            <input type="hidden" name="csrf_token" value="{csrf_token_escaped}" />
            <input type="hidden" name="submit" value="true" />
            <div class="button-group">
                <button type="submit" name="action" value="approve" class="btn-approve">Allow Access</button>
                <button type="submit" name="action" value="deny" class="btn-deny">Deny</button>
            </div>
        </form>
    """

    # Build help link with tooltip
    help_link = """
        <div class="help-link-container">
            <span class="help-link">
                Why am I seeing this?
                <span class="tooltip">
                    This FastMCP server requires your consent to allow a new client
                    to connect. This protects you from <a
                    href="https://modelcontextprotocol.io/specification/2025-06-18/basic/security_best_practices#confused-deputy-problem"
                    target="_blank" class="tooltip-link">confused deputy
                    attacks</a>, where malicious clients could impersonate you
                    and steal access.<br><br>
                    <a
                    href="https://gofastmcp.com/servers/auth/oauth-proxy#confused-deputy-attacks"
                    target="_blank" class="tooltip-link">Learn more about
                    FastMCP security →</a>
                </span>
            </span>
        </div>
    """

    # Build the page content
    content = f"""
        <div class="container">
            {create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")}
            <h1>Application Access Request</h1>
            {intro_box}
            {redirect_section}
            {advanced_details}
            {form}
        </div>
        {help_link}
    """

    # Additional styles needed for this page
    additional_styles = (
        INFO_BOX_STYLES
        + REDIRECT_SECTION_STYLES
        + DETAILS_STYLES
        + DETAIL_BOX_STYLES
        + BUTTON_STYLES
        + TOOLTIP_STYLES
    )

    # Determine CSP policy to use
    # If csp_policy is None, build the default CSP policy
    # If csp_policy is empty string, CSP will be disabled entirely in create_page
    # If csp_policy is a non-empty string, use it as-is
    if csp_policy is None:
        # Need to allow form-action for form submission
        # Chrome requires explicit scheme declarations in CSP form-action when redirect chains
        # end in custom protocol schemes (e.g., cursor://). Parse redirect_uri to include its scheme.
        parsed_redirect = urlparse(redirect_uri)
        redirect_scheme = parsed_redirect.scheme.lower()

        # Build form-action directive with standard schemes plus custom protocol if present
        form_action_schemes = ["https:", "http:"]
        if redirect_scheme and redirect_scheme not in ("http", "https"):
            # Custom protocol scheme (e.g., cursor:, vscode:, etc.)
            form_action_schemes.append(f"{redirect_scheme}:")

        form_action_directive = " ".join(form_action_schemes)
        csp_policy = f"default-src 'none'; style-src 'unsafe-inline'; img-src https: data:; base-uri 'none'; form-action {form_action_directive}"

    return create_page(
        content=content,
        title=title,
        additional_styles=additional_styles,
        csp_policy=csp_policy,
    )
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def create_error_html(
    error_title: str,
    error_message: str,
    error_details: dict[str, str] | None = None,
    server_name: str | None = None,
    server_icon_url: str | None = None,
) -> str:
    """Render a styled HTML page describing an OAuth error.

    The title, message and every detail label/value are HTML-escaped before
    being placed in the markup.

    Args:
        error_title: Heading and page title (e.g., "OAuth Error", "Authorization Failed")
        error_message: Main message shown in the error box
        error_details: Optional label -> value mapping rendered in a collapsible
            "Error Details" section (e.g., {"Error Code": "invalid_client"})
        server_name: Optional server name, used as the logo alt text
        server_icon_url: Optional URL to server icon/logo

    Returns:
        Complete HTML page as a string
    """
    import html as html_module

    esc = html_module.escape

    # Message box shown directly under the heading
    message_box = f"""
        <div class="info-box error">
            <p>{esc(error_message)}</p>
        </div>
    """

    # Collapsible details block; stays empty when there is nothing to show
    if error_details:
        rows_markup = "\n".join(
            f"""
                <div class="detail-row">
                    <div class="detail-label">{esc(key)}:</div>
                    <div class="detail-value">{esc(val)}</div>
                </div>
            """
            for key, val in error_details.items()
        )
        details_block = f"""
            <details>
                <summary>Error Details</summary>
                <div class="detail-box">
                    {rows_markup}
                </div>
            </details>
        """
    else:
        details_block = ""

    # Assemble the page body
    page_body = f"""
        <div class="container">
            {create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")}
            <h1>{esc(error_title)}</h1>
            {message_box}
            {details_block}
        </div>
    """

    # Page-specific CSS; the trailing rule makes .info-box.error use the
    # normal text color instead of red
    error_box_override = """
        .info-box.error {
            color: #111827;
        }
    """
    page_styles = (
        INFO_BOX_STYLES + DETAILS_STYLES + DETAIL_BOX_STYLES + error_box_override
    )

    # Error pages contain no forms, so the CSP carries no form-action directive
    return create_page(
        content=page_body,
        title=error_title,
        additional_styles=page_styles,
        csp_policy="default-src 'none'; style-src 'unsafe-inline'; img-src https: data:; base-uri 'none'",
    )
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
# -------------------------------------------------------------------------
|
|
518
|
+
# Handler Classes
|
|
519
|
+
# -------------------------------------------------------------------------
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
class TokenHandler(_SDKTokenHandler):
    """Token handler that reports client-authentication failures as HTTP 401.

    The MCP SDK handler answers every client authentication problem with
    HTTP 400 and the ``unauthorized_client`` error code. This subclass
    rewrites the specific case of an unknown/invalid client ID:

    - the error code becomes ``invalid_client``
    - the status code becomes 401 instead of 400

    OAuth 2.1 Section 5.3: "The authorization server MAY return an HTTP 401
    (Unauthorized) status code to indicate which HTTP authentication schemes
    are supported."

    MCP spec: "Invalid or expired tokens MUST receive a HTTP 401 response."
    """

    def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
        """Build the HTTP response, remapping client auth failures to 401."""
        # ``unauthorized_client`` is ambiguous:
        #   1. client authentication failed (unknown client_id / bad
        #      credentials) -> should be invalid_client with 401
        #   2. client not allowed to use this grant type -> already correct
        #      as unauthorized_client with 400
        # Only case 1 is rewritten; it is recognised by its description text.
        if not isinstance(obj, TokenErrorResponse):
            return super().response(obj)
        if obj.error != "unauthorized_client":
            return super().response(obj)

        description = obj.error_description
        if not description or "Invalid client_id" not in description:
            return super().response(obj)

        # Client auth failure: answer with invalid_client / 401
        remapped = TokenErrorResponse(
            error="invalid_client",
            error_description=description,
            error_uri=obj.error_uri,
        )
        no_cache_headers = {
            "Cache-Control": "no-store",
            "Pragma": "no-cache",
        }
        return PydanticJSONResponse(
            content=remapped,
            status_code=401,
            headers=no_cache_headers,
        )
|
|
121
569
|
|
|
122
570
|
|
|
123
571
|
class OAuthProxy(OAuthProvider):
|
|
@@ -198,15 +646,18 @@ class OAuthProxy(OAuthProvider):
|
|
|
198
646
|
|
|
199
647
|
State Management
|
|
200
648
|
---------------
|
|
201
|
-
The proxy maintains minimal but crucial state:
|
|
202
|
-
- _clients: DCR registrations (all use ProxyDCRClient for flexibility)
|
|
649
|
+
The proxy maintains minimal but crucial state via pluggable storage (client_storage):
|
|
203
650
|
- _oauth_transactions: Active authorization flows with client context
|
|
204
651
|
- _client_codes: Authorization codes with PKCE challenges and upstream tokens
|
|
205
|
-
-
|
|
206
|
-
-
|
|
652
|
+
- _jti_mapping_store: Maps FastMCP token JTIs to upstream token IDs
|
|
653
|
+
- _refresh_token_store: Refresh token metadata (keyed by token hash)
|
|
654
|
+
|
|
655
|
+
All state is stored in the configured client_storage backend (Redis, disk, etc.)
|
|
656
|
+
enabling horizontal scaling across multiple instances.
|
|
207
657
|
|
|
208
658
|
Security Considerations
|
|
209
659
|
----------------------
|
|
660
|
+
- Refresh tokens stored by hash only (defense in depth if storage compromised)
|
|
210
661
|
- PKCE enforced end-to-end (client to proxy, proxy to upstream)
|
|
211
662
|
- Authorization codes are single-use with short expiry
|
|
212
663
|
- Transaction IDs are cryptographically random
|
|
@@ -240,7 +691,7 @@ class OAuthProxy(OAuthProvider):
|
|
|
240
691
|
token_verifier: TokenVerifier,
|
|
241
692
|
# FastMCP server configuration
|
|
242
693
|
base_url: AnyHttpUrl | str,
|
|
243
|
-
redirect_path: str =
|
|
694
|
+
redirect_path: str | None = None,
|
|
244
695
|
issuer_url: AnyHttpUrl | str | None = None,
|
|
245
696
|
service_documentation_url: AnyHttpUrl | str | None = None,
|
|
246
697
|
# Client redirect URI validation
|
|
@@ -250,6 +701,17 @@ class OAuthProxy(OAuthProvider):
|
|
|
250
701
|
forward_pkce: bool = True,
|
|
251
702
|
# Token endpoint authentication
|
|
252
703
|
token_endpoint_auth_method: str | None = None,
|
|
704
|
+
# Extra parameters to forward to authorization endpoint
|
|
705
|
+
extra_authorize_params: dict[str, str] | None = None,
|
|
706
|
+
# Extra parameters to forward to token endpoint
|
|
707
|
+
extra_token_params: dict[str, str] | None = None,
|
|
708
|
+
# Client storage
|
|
709
|
+
client_storage: AsyncKeyValue | None = None,
|
|
710
|
+
# JWT signing key
|
|
711
|
+
jwt_signing_key: str | bytes | None = None,
|
|
712
|
+
# Consent screen configuration
|
|
713
|
+
require_authorization_consent: bool = True,
|
|
714
|
+
consent_csp_policy: str | None = None,
|
|
253
715
|
):
|
|
254
716
|
"""Initialize the OAuth proxy provider.
|
|
255
717
|
|
|
@@ -273,12 +735,34 @@ class OAuthProxy(OAuthProvider):
|
|
|
273
735
|
valid_scopes: List of all the possible valid scopes for a client.
|
|
274
736
|
These are advertised to clients through the `/.well-known` endpoints. Defaults to `required_scopes` if not provided.
|
|
275
737
|
forward_pkce: Whether to forward PKCE to upstream server (default True).
|
|
276
|
-
Enable for providers that support/require PKCE (Google, Azure, etc.).
|
|
738
|
+
Enable for providers that support/require PKCE (Google, Azure, AWS, etc.).
|
|
277
739
|
Disable only if upstream provider doesn't support PKCE.
|
|
278
740
|
token_endpoint_auth_method: Token endpoint authentication method for upstream server.
|
|
279
741
|
Common values: "client_secret_basic", "client_secret_post", "none".
|
|
280
742
|
If None, authlib will use its default (typically "client_secret_basic").
|
|
743
|
+
extra_authorize_params: Additional parameters to forward to the upstream authorization endpoint.
|
|
744
|
+
Useful for provider-specific parameters like Auth0's "audience".
|
|
745
|
+
Example: {"audience": "https://api.example.com"}
|
|
746
|
+
extra_token_params: Additional parameters to forward to the upstream token endpoint.
|
|
747
|
+
Useful for provider-specific parameters during token exchange.
|
|
748
|
+
client_storage: Storage backend for OAuth state (client registrations, tokens).
|
|
749
|
+
If None, an encrypted DiskStore will be created in the data directory.
|
|
750
|
+
jwt_signing_key: Secret for signing FastMCP JWT tokens (any string or bytes).
|
|
751
|
+
If bytes are provided, they will be used as-is.
|
|
752
|
+
If a string is provided, it will be derived into a 32-byte key using PBKDF2 (1.2M iterations).
|
|
753
|
+
If not provided, it will be derived from the upstream client secret using HKDF.
|
|
754
|
+
require_authorization_consent: Whether to require user consent before authorizing clients (default True).
|
|
755
|
+
When True, users see a consent screen before being redirected to the upstream IdP.
|
|
756
|
+
When False, authorization proceeds directly without user confirmation.
|
|
757
|
+
SECURITY WARNING: Only disable for local development or testing environments.
|
|
758
|
+
consent_csp_policy: Content Security Policy for the consent page.
|
|
759
|
+
If None (default), uses the built-in CSP policy with appropriate directives.
|
|
760
|
+
If empty string "", disables CSP entirely (no meta tag is rendered).
|
|
761
|
+
If a non-empty string, uses that as the CSP policy value.
|
|
762
|
+
This allows organizations with their own CSP policies to override or disable
|
|
763
|
+
the built-in CSP directives.
|
|
281
764
|
"""
|
|
765
|
+
|
|
282
766
|
# Always enable DCR since we implement it locally for MCP clients
|
|
283
767
|
client_registration_options = ClientRegistrationOptions(
|
|
284
768
|
enabled=True,
|
|
@@ -300,42 +784,157 @@ class OAuthProxy(OAuthProvider):
|
|
|
300
784
|
)
|
|
301
785
|
|
|
302
786
|
# Store upstream configuration
|
|
303
|
-
self._upstream_authorization_endpoint = upstream_authorization_endpoint
|
|
304
|
-
self._upstream_token_endpoint = upstream_token_endpoint
|
|
305
|
-
self._upstream_client_id = upstream_client_id
|
|
306
|
-
self._upstream_client_secret = SecretStr(
|
|
307
|
-
|
|
308
|
-
|
|
787
|
+
self._upstream_authorization_endpoint: str = upstream_authorization_endpoint
|
|
788
|
+
self._upstream_token_endpoint: str = upstream_token_endpoint
|
|
789
|
+
self._upstream_client_id: str = upstream_client_id
|
|
790
|
+
self._upstream_client_secret: SecretStr = SecretStr(
|
|
791
|
+
secret_value=upstream_client_secret
|
|
792
|
+
)
|
|
793
|
+
self._upstream_revocation_endpoint: str | None = upstream_revocation_endpoint
|
|
794
|
+
self._default_scope_str: str = " ".join(self.required_scopes or [])
|
|
309
795
|
|
|
310
796
|
# Store redirect configuration
|
|
311
|
-
|
|
312
|
-
|
|
797
|
+
if not redirect_path:
|
|
798
|
+
self._redirect_path = "/auth/callback"
|
|
799
|
+
else:
|
|
800
|
+
self._redirect_path = (
|
|
801
|
+
redirect_path if redirect_path.startswith("/") else f"/{redirect_path}"
|
|
802
|
+
)
|
|
803
|
+
|
|
804
|
+
if (
|
|
805
|
+
isinstance(allowed_client_redirect_uris, list)
|
|
806
|
+
and not allowed_client_redirect_uris
|
|
807
|
+
):
|
|
808
|
+
logger.warning(
|
|
809
|
+
"allowed_client_redirect_uris is empty list; no redirect URIs will be accepted. "
|
|
810
|
+
+ "This will block all OAuth clients."
|
|
811
|
+
)
|
|
812
|
+
self._allowed_client_redirect_uris: list[str] | None = (
|
|
813
|
+
allowed_client_redirect_uris
|
|
313
814
|
)
|
|
314
|
-
self._allowed_client_redirect_uris = allowed_client_redirect_uris
|
|
315
815
|
|
|
316
816
|
# PKCE configuration
|
|
317
|
-
self._forward_pkce = forward_pkce
|
|
817
|
+
self._forward_pkce: bool = forward_pkce
|
|
318
818
|
|
|
319
819
|
# Token endpoint authentication
|
|
320
|
-
self._token_endpoint_auth_method = token_endpoint_auth_method
|
|
820
|
+
self._token_endpoint_auth_method: str | None = token_endpoint_auth_method
|
|
821
|
+
|
|
822
|
+
# Consent screen configuration
|
|
823
|
+
self._require_authorization_consent: bool = require_authorization_consent
|
|
824
|
+
self._consent_csp_policy: str | None = consent_csp_policy
|
|
825
|
+
if not require_authorization_consent:
|
|
826
|
+
logger.warning(
|
|
827
|
+
"Authorization consent screen disabled - only use for local development or testing. "
|
|
828
|
+
+ "In production, this screen protects against confused deputy attacks."
|
|
829
|
+
)
|
|
321
830
|
|
|
322
|
-
#
|
|
323
|
-
self.
|
|
324
|
-
self.
|
|
325
|
-
self._refresh_tokens: dict[str, RefreshToken] = {}
|
|
831
|
+
# Extra parameters for authorization and token endpoints
|
|
832
|
+
self._extra_authorize_params: dict[str, str] = extra_authorize_params or {}
|
|
833
|
+
self._extra_token_params: dict[str, str] = extra_token_params or {}
|
|
326
834
|
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
835
|
+
if jwt_signing_key is None:
|
|
836
|
+
jwt_signing_key = derive_jwt_key(
|
|
837
|
+
high_entropy_material=upstream_client_secret,
|
|
838
|
+
salt="fastmcp-jwt-signing-key",
|
|
839
|
+
)
|
|
840
|
+
|
|
841
|
+
if isinstance(jwt_signing_key, str):
|
|
842
|
+
if len(jwt_signing_key) < 12:
|
|
843
|
+
logger.warning(
|
|
844
|
+
"jwt_signing_key is less than 12 characters; it is recommended to use a longer. "
|
|
845
|
+
+ "string for the key derivation."
|
|
846
|
+
)
|
|
847
|
+
jwt_signing_key = derive_jwt_key(
|
|
848
|
+
low_entropy_material=jwt_signing_key,
|
|
849
|
+
salt="fastmcp-jwt-signing-key",
|
|
850
|
+
)
|
|
851
|
+
|
|
852
|
+
self._jwt_issuer: JWTIssuer = JWTIssuer(
|
|
853
|
+
issuer=str(self.base_url),
|
|
854
|
+
audience=f"{str(self.base_url).rstrip('/')}/mcp",
|
|
855
|
+
signing_key=jwt_signing_key,
|
|
856
|
+
)
|
|
857
|
+
|
|
858
|
+
# If the user does not provide a store, we will provide an encrypted disk store
|
|
859
|
+
if client_storage is None:
|
|
860
|
+
storage_encryption_key = derive_jwt_key(
|
|
861
|
+
high_entropy_material=jwt_signing_key.decode(),
|
|
862
|
+
salt="fastmcp-storage-encryption-key",
|
|
863
|
+
)
|
|
864
|
+
client_storage = FernetEncryptionWrapper(
|
|
865
|
+
key_value=DiskStore(directory=settings.home / "oauth-proxy"),
|
|
866
|
+
fernet=Fernet(key=storage_encryption_key),
|
|
867
|
+
)
|
|
868
|
+
|
|
869
|
+
self._client_storage: AsyncKeyValue = client_storage
|
|
870
|
+
|
|
871
|
+
# Cache HTTPS check to avoid repeated logging
|
|
872
|
+
self._is_https: bool = str(self.base_url).startswith("https://")
|
|
873
|
+
if not self._is_https:
|
|
874
|
+
logger.warning(
|
|
875
|
+
"Using non-secure cookies for development; deploy with HTTPS for production."
|
|
876
|
+
)
|
|
877
|
+
|
|
878
|
+
self._upstream_token_store: PydanticAdapter[UpstreamTokenSet] = PydanticAdapter[
|
|
879
|
+
UpstreamTokenSet
|
|
880
|
+
](
|
|
881
|
+
key_value=self._client_storage,
|
|
882
|
+
pydantic_model=UpstreamTokenSet,
|
|
883
|
+
default_collection="mcp-upstream-tokens",
|
|
884
|
+
raise_on_validation_error=True,
|
|
885
|
+
)
|
|
886
|
+
|
|
887
|
+
self._client_store: PydanticAdapter[ProxyDCRClient] = PydanticAdapter[
|
|
888
|
+
ProxyDCRClient
|
|
889
|
+
](
|
|
890
|
+
key_value=self._client_storage,
|
|
891
|
+
pydantic_model=ProxyDCRClient,
|
|
892
|
+
default_collection="mcp-oauth-proxy-clients",
|
|
893
|
+
raise_on_validation_error=True,
|
|
894
|
+
)
|
|
330
895
|
|
|
331
896
|
# OAuth transaction storage for IdP callback forwarding
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
897
|
+
# Reuse client_storage with different collections for state management
|
|
898
|
+
self._transaction_store: PydanticAdapter[OAuthTransaction] = PydanticAdapter[
|
|
899
|
+
OAuthTransaction
|
|
900
|
+
](
|
|
901
|
+
key_value=self._client_storage,
|
|
902
|
+
pydantic_model=OAuthTransaction,
|
|
903
|
+
default_collection="mcp-oauth-transactions",
|
|
904
|
+
raise_on_validation_error=True,
|
|
905
|
+
)
|
|
906
|
+
|
|
907
|
+
self._code_store: PydanticAdapter[ClientCode] = PydanticAdapter[ClientCode](
|
|
908
|
+
key_value=self._client_storage,
|
|
909
|
+
pydantic_model=ClientCode,
|
|
910
|
+
default_collection="mcp-authorization-codes",
|
|
911
|
+
raise_on_validation_error=True,
|
|
912
|
+
)
|
|
913
|
+
|
|
914
|
+
# Storage for JTI mappings (FastMCP token -> upstream token)
|
|
915
|
+
self._jti_mapping_store: PydanticAdapter[JTIMapping] = PydanticAdapter[
|
|
916
|
+
JTIMapping
|
|
917
|
+
](
|
|
918
|
+
key_value=self._client_storage,
|
|
919
|
+
pydantic_model=JTIMapping,
|
|
920
|
+
default_collection="mcp-jti-mappings",
|
|
921
|
+
raise_on_validation_error=True,
|
|
922
|
+
)
|
|
923
|
+
|
|
924
|
+
# Refresh token metadata storage, keyed by token hash for security.
|
|
925
|
+
# We only store metadata (not the token itself) - if storage is compromised,
|
|
926
|
+
# attackers get hashes they can't reverse into usable tokens.
|
|
927
|
+
self._refresh_token_store: PydanticAdapter[RefreshTokenMetadata] = (
|
|
928
|
+
PydanticAdapter[RefreshTokenMetadata](
|
|
929
|
+
key_value=self._client_storage,
|
|
930
|
+
pydantic_model=RefreshTokenMetadata,
|
|
931
|
+
default_collection="mcp-refresh-tokens",
|
|
932
|
+
raise_on_validation_error=True,
|
|
933
|
+
)
|
|
934
|
+
)
|
|
336
935
|
|
|
337
936
|
# Use the provided token validator
|
|
338
|
-
self._token_validator = token_verifier
|
|
937
|
+
self._token_validator: TokenVerifier = token_verifier
|
|
339
938
|
|
|
340
939
|
logger.debug(
|
|
341
940
|
"Initialized OAuth proxy provider with upstream server %s",
|
|
@@ -365,16 +964,23 @@ class OAuthProxy(OAuthProvider):
|
|
|
365
964
|
# Client Registration (Local Implementation)
|
|
366
965
|
# -------------------------------------------------------------------------
|
|
367
966
|
|
|
967
|
+
@override
|
|
368
968
|
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
|
369
969
|
"""Get client information by ID. This is generally the random ID
|
|
370
970
|
provided to the DCR client during registration, not the upstream client ID.
|
|
371
971
|
|
|
372
972
|
For unregistered clients, returns None (which will raise an error in the SDK).
|
|
373
973
|
"""
|
|
374
|
-
|
|
974
|
+
# Load from storage
|
|
975
|
+
if not (client := await self._client_store.get(key=client_id)):
|
|
976
|
+
return None
|
|
977
|
+
|
|
978
|
+
if client.allowed_redirect_uri_patterns is None:
|
|
979
|
+
client.allowed_redirect_uri_patterns = self._allowed_client_redirect_uris
|
|
375
980
|
|
|
376
981
|
return client
|
|
377
982
|
|
|
983
|
+
@override
|
|
378
984
|
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
|
|
379
985
|
"""Register a client locally
|
|
380
986
|
|
|
@@ -385,19 +991,24 @@ class OAuthProxy(OAuthProvider):
|
|
|
385
991
|
"""
|
|
386
992
|
|
|
387
993
|
# Create a ProxyDCRClient with configured redirect URI validation
|
|
388
|
-
|
|
994
|
+
if client_info.client_id is None:
|
|
995
|
+
raise ValueError("client_id is required for client registration")
|
|
996
|
+
proxy_client: ProxyDCRClient = ProxyDCRClient(
|
|
389
997
|
client_id=client_info.client_id,
|
|
390
998
|
client_secret=client_info.client_secret,
|
|
391
999
|
redirect_uris=client_info.redirect_uris or [AnyUrl("http://localhost")],
|
|
392
1000
|
grant_types=client_info.grant_types
|
|
393
1001
|
or ["authorization_code", "refresh_token"],
|
|
394
|
-
scope=self._default_scope_str,
|
|
1002
|
+
scope=client_info.scope or self._default_scope_str,
|
|
395
1003
|
token_endpoint_auth_method="none",
|
|
396
1004
|
allowed_redirect_uri_patterns=self._allowed_client_redirect_uris,
|
|
1005
|
+
client_name=getattr(client_info, "client_name", None),
|
|
397
1006
|
)
|
|
398
1007
|
|
|
399
|
-
|
|
400
|
-
|
|
1008
|
+
await self._client_store.put(
|
|
1009
|
+
key=client_info.client_id,
|
|
1010
|
+
value=proxy_client,
|
|
1011
|
+
)
|
|
401
1012
|
|
|
402
1013
|
# Log redirect URIs to help users discover what patterns they might need
|
|
403
1014
|
if client_info.redirect_uris:
|
|
@@ -411,25 +1022,28 @@ class OAuthProxy(OAuthProvider):
|
|
|
411
1022
|
logger.debug(
|
|
412
1023
|
"Registered client %s with %d redirect URIs",
|
|
413
1024
|
client_info.client_id,
|
|
414
|
-
len(proxy_client.redirect_uris),
|
|
1025
|
+
len(proxy_client.redirect_uris) if proxy_client.redirect_uris else 0,
|
|
415
1026
|
)
|
|
416
1027
|
|
|
417
1028
|
# -------------------------------------------------------------------------
|
|
418
1029
|
# Authorization Flow (Proxy to Upstream)
|
|
419
1030
|
# -------------------------------------------------------------------------
|
|
420
1031
|
|
|
1032
|
+
@override
|
|
421
1033
|
async def authorize(
|
|
422
1034
|
self,
|
|
423
1035
|
client: OAuthClientInformationFull,
|
|
424
1036
|
params: AuthorizationParams,
|
|
425
1037
|
) -> str:
|
|
426
|
-
"""Start OAuth transaction and
|
|
1038
|
+
"""Start OAuth transaction and route through consent interstitial.
|
|
427
1039
|
|
|
428
|
-
|
|
429
|
-
1. Store transaction with client details and PKCE
|
|
430
|
-
2.
|
|
431
|
-
3.
|
|
432
|
-
|
|
1040
|
+
Flow:
|
|
1041
|
+
1. Store transaction with client details and PKCE (if forwarding)
|
|
1042
|
+
2. Return local /consent URL; browser visits consent first
|
|
1043
|
+
3. Consent handler redirects to upstream IdP if approved/already approved
|
|
1044
|
+
|
|
1045
|
+
If consent is disabled (require_authorization_consent=False), skip the consent screen
|
|
1046
|
+
and redirect directly to the upstream IdP.
|
|
433
1047
|
"""
|
|
434
1048
|
# Generate transaction ID for this authorization request
|
|
435
1049
|
txn_id = secrets.token_urlsafe(32)
|
|
@@ -445,62 +1059,56 @@ class OAuthProxy(OAuthProvider):
|
|
|
445
1059
|
)
|
|
446
1060
|
|
|
447
1061
|
# Store transaction data for IdP callback processing
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
"redirect_uri": f"{str(self.base_url).rstrip('/')}{self._redirect_path}",
|
|
470
|
-
"state": txn_id, # Use txn_id as IdP state
|
|
471
|
-
}
|
|
472
|
-
|
|
473
|
-
# Add scopes - use client scopes or fallback to required scopes
|
|
474
|
-
scopes_to_use = params.scopes or self.required_scopes or []
|
|
475
|
-
|
|
476
|
-
if scopes_to_use:
|
|
477
|
-
query_params["scope"] = " ".join(scopes_to_use)
|
|
1062
|
+
if client.client_id is None:
|
|
1063
|
+
raise AuthorizeError(
|
|
1064
|
+
error="invalid_client", error_description="Client ID is required"
|
|
1065
|
+
)
|
|
1066
|
+
transaction = OAuthTransaction(
|
|
1067
|
+
txn_id=txn_id,
|
|
1068
|
+
client_id=client.client_id,
|
|
1069
|
+
client_redirect_uri=str(params.redirect_uri),
|
|
1070
|
+
client_state=params.state or "",
|
|
1071
|
+
code_challenge=params.code_challenge,
|
|
1072
|
+
code_challenge_method=getattr(params, "code_challenge_method", "S256"),
|
|
1073
|
+
scopes=params.scopes or [],
|
|
1074
|
+
created_at=time.time(),
|
|
1075
|
+
resource=getattr(params, "resource", None),
|
|
1076
|
+
proxy_code_verifier=proxy_code_verifier,
|
|
1077
|
+
)
|
|
1078
|
+
await self._transaction_store.put(
|
|
1079
|
+
key=txn_id,
|
|
1080
|
+
value=transaction,
|
|
1081
|
+
ttl=15 * 60, # Auto-expire after 15 minutes
|
|
1082
|
+
)
|
|
478
1083
|
|
|
479
|
-
#
|
|
480
|
-
if
|
|
481
|
-
|
|
482
|
-
|
|
1084
|
+
# If consent is disabled, skip consent screen and go directly to upstream IdP
|
|
1085
|
+
if not self._require_authorization_consent:
|
|
1086
|
+
upstream_url = self._build_upstream_authorize_url(
|
|
1087
|
+
txn_id, transaction.model_dump()
|
|
1088
|
+
)
|
|
483
1089
|
logger.debug(
|
|
484
|
-
"
|
|
1090
|
+
"Starting OAuth transaction %s for client %s, redirecting directly to upstream IdP (consent disabled, PKCE forwarding: %s)",
|
|
485
1091
|
txn_id,
|
|
1092
|
+
client.client_id,
|
|
1093
|
+
"enabled" if proxy_code_challenge else "disabled",
|
|
486
1094
|
)
|
|
1095
|
+
return upstream_url
|
|
487
1096
|
|
|
488
|
-
|
|
489
|
-
separator = "&" if "?" in self._upstream_authorization_endpoint else "?"
|
|
490
|
-
upstream_url = f"{self._upstream_authorization_endpoint}{separator}{urlencode(query_params)}"
|
|
1097
|
+
consent_url = f"{str(self.base_url).rstrip('/')}/consent?txn_id={txn_id}"
|
|
491
1098
|
|
|
492
1099
|
logger.debug(
|
|
493
|
-
"Starting OAuth transaction %s for client %s, redirecting to
|
|
1100
|
+
"Starting OAuth transaction %s for client %s, redirecting to consent page (PKCE forwarding: %s)",
|
|
494
1101
|
txn_id,
|
|
495
1102
|
client.client_id,
|
|
496
1103
|
"enabled" if proxy_code_challenge else "disabled",
|
|
497
1104
|
)
|
|
498
|
-
return
|
|
1105
|
+
return consent_url
|
|
499
1106
|
|
|
500
1107
|
# -------------------------------------------------------------------------
|
|
501
1108
|
# Authorization Code Handling
|
|
502
1109
|
# -------------------------------------------------------------------------
|
|
503
1110
|
|
|
1111
|
+
@override
|
|
504
1112
|
async def load_authorization_code(
|
|
505
1113
|
self,
|
|
506
1114
|
client: OAuthClientInformationFull,
|
|
@@ -512,111 +1120,248 @@ class OAuthProxy(OAuthProvider):
|
|
|
512
1120
|
with PKCE challenge for validation.
|
|
513
1121
|
"""
|
|
514
1122
|
# Look up client code data
|
|
515
|
-
|
|
516
|
-
if not
|
|
1123
|
+
code_model = await self._code_store.get(key=authorization_code)
|
|
1124
|
+
if not code_model:
|
|
517
1125
|
logger.debug("Authorization code not found: %s", authorization_code)
|
|
518
1126
|
return None
|
|
519
1127
|
|
|
520
1128
|
# Check if code expired
|
|
521
|
-
if time.time() >
|
|
1129
|
+
if time.time() > code_model.expires_at:
|
|
522
1130
|
logger.debug("Authorization code expired: %s", authorization_code)
|
|
523
|
-
self.
|
|
1131
|
+
_ = await self._code_store.delete(key=authorization_code)
|
|
524
1132
|
return None
|
|
525
1133
|
|
|
526
1134
|
# Verify client ID matches
|
|
527
|
-
if
|
|
1135
|
+
if code_model.client_id != client.client_id:
|
|
528
1136
|
logger.debug(
|
|
529
1137
|
"Authorization code client ID mismatch: %s vs %s",
|
|
530
|
-
|
|
1138
|
+
code_model.client_id,
|
|
531
1139
|
client.client_id,
|
|
532
1140
|
)
|
|
533
1141
|
return None
|
|
534
1142
|
|
|
535
1143
|
# Create authorization code object with PKCE challenge
|
|
1144
|
+
if client.client_id is None:
|
|
1145
|
+
raise AuthorizeError(
|
|
1146
|
+
error="invalid_client", error_description="Client ID is required"
|
|
1147
|
+
)
|
|
536
1148
|
return AuthorizationCode(
|
|
537
1149
|
code=authorization_code,
|
|
538
1150
|
client_id=client.client_id,
|
|
539
|
-
redirect_uri=
|
|
1151
|
+
redirect_uri=AnyUrl(url=code_model.redirect_uri),
|
|
540
1152
|
redirect_uri_provided_explicitly=True,
|
|
541
|
-
scopes=
|
|
542
|
-
expires_at=
|
|
543
|
-
code_challenge=
|
|
1153
|
+
scopes=code_model.scopes,
|
|
1154
|
+
expires_at=code_model.expires_at,
|
|
1155
|
+
code_challenge=code_model.code_challenge or "",
|
|
544
1156
|
)
|
|
545
1157
|
|
|
1158
|
+
@override
|
|
546
1159
|
async def exchange_authorization_code(
|
|
547
1160
|
self,
|
|
548
1161
|
client: OAuthClientInformationFull,
|
|
549
1162
|
authorization_code: AuthorizationCode,
|
|
550
1163
|
) -> OAuthToken:
|
|
551
|
-
"""Exchange authorization code for
|
|
1164
|
+
"""Exchange authorization code for FastMCP-issued tokens.
|
|
1165
|
+
|
|
1166
|
+
Implements the token factory pattern:
|
|
1167
|
+
1. Retrieves upstream tokens from stored authorization code
|
|
1168
|
+
2. Extracts user identity from upstream token
|
|
1169
|
+
3. Encrypts and stores upstream tokens
|
|
1170
|
+
4. Issues FastMCP-signed JWT tokens
|
|
1171
|
+
5. Returns FastMCP tokens (NOT upstream tokens)
|
|
552
1172
|
|
|
553
|
-
|
|
554
|
-
during the IdP callback exchange. PKCE validation is handled by the MCP framework.
|
|
1173
|
+
PKCE validation is handled by the MCP framework before this method is called.
|
|
555
1174
|
"""
|
|
556
1175
|
# Look up stored code data
|
|
557
|
-
|
|
558
|
-
if not
|
|
1176
|
+
code_model = await self._code_store.get(key=authorization_code.code)
|
|
1177
|
+
if not code_model:
|
|
559
1178
|
logger.error(
|
|
560
1179
|
"Authorization code not found in client codes: %s",
|
|
561
1180
|
authorization_code.code,
|
|
562
1181
|
)
|
|
563
1182
|
raise TokenError("invalid_grant", "Authorization code not found")
|
|
564
1183
|
|
|
565
|
-
# Get stored
|
|
566
|
-
idp_tokens =
|
|
1184
|
+
# Get stored upstream tokens
|
|
1185
|
+
idp_tokens = code_model.idp_tokens
|
|
567
1186
|
|
|
568
1187
|
# Clean up client code (one-time use)
|
|
569
|
-
self.
|
|
1188
|
+
await self._code_store.delete(key=authorization_code.code)
|
|
570
1189
|
|
|
571
|
-
#
|
|
572
|
-
|
|
573
|
-
|
|
1190
|
+
# Generate IDs for token storage
|
|
1191
|
+
upstream_token_id = secrets.token_urlsafe(32)
|
|
1192
|
+
access_jti = secrets.token_urlsafe(32)
|
|
1193
|
+
refresh_jti = (
|
|
1194
|
+
secrets.token_urlsafe(32) if idp_tokens.get("refresh_token") else None
|
|
1195
|
+
)
|
|
1196
|
+
|
|
1197
|
+
# Calculate token expiry times
|
|
574
1198
|
expires_in = int(
|
|
575
1199
|
idp_tokens.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
|
|
576
1200
|
)
|
|
577
|
-
expires_at = int(time.time() + expires_in)
|
|
578
1201
|
|
|
579
|
-
#
|
|
580
|
-
|
|
581
|
-
|
|
1202
|
+
# Calculate refresh token expiry if provided by upstream
|
|
1203
|
+
# Some providers include refresh_expires_in, some don't
|
|
1204
|
+
refresh_expires_in = None
|
|
1205
|
+
refresh_token_expires_at = None
|
|
1206
|
+
if idp_tokens.get("refresh_token"):
|
|
1207
|
+
if "refresh_expires_in" in idp_tokens:
|
|
1208
|
+
refresh_expires_in = int(idp_tokens["refresh_expires_in"])
|
|
1209
|
+
refresh_token_expires_at = time.time() + refresh_expires_in
|
|
1210
|
+
logger.debug(
|
|
1211
|
+
"Upstream refresh token expires in %d seconds", refresh_expires_in
|
|
1212
|
+
)
|
|
1213
|
+
else:
|
|
1214
|
+
# Default to 30 days if upstream doesn't specify
|
|
1215
|
+
# This is conservative - most providers use longer expiry
|
|
1216
|
+
refresh_expires_in = 60 * 60 * 24 * 30 # 30 days
|
|
1217
|
+
refresh_token_expires_at = time.time() + refresh_expires_in
|
|
1218
|
+
logger.debug(
|
|
1219
|
+
"Upstream refresh token expiry unknown, using 30-day default"
|
|
1220
|
+
)
|
|
1221
|
+
|
|
1222
|
+
# Encrypt and store upstream tokens
|
|
1223
|
+
upstream_token_set = UpstreamTokenSet(
|
|
1224
|
+
upstream_token_id=upstream_token_id,
|
|
1225
|
+
access_token=idp_tokens["access_token"],
|
|
1226
|
+
refresh_token=idp_tokens["refresh_token"]
|
|
1227
|
+
if idp_tokens.get("refresh_token")
|
|
1228
|
+
else None,
|
|
1229
|
+
refresh_token_expires_at=refresh_token_expires_at,
|
|
1230
|
+
expires_at=time.time() + expires_in,
|
|
1231
|
+
token_type=idp_tokens.get("token_type", "Bearer"),
|
|
1232
|
+
scope=" ".join(authorization_code.scopes),
|
|
1233
|
+
client_id=client.client_id or "",
|
|
1234
|
+
created_at=time.time(),
|
|
1235
|
+
raw_token_data=idp_tokens,
|
|
1236
|
+
)
|
|
1237
|
+
await self._upstream_token_store.put(
|
|
1238
|
+
key=upstream_token_id,
|
|
1239
|
+
value=upstream_token_set,
|
|
1240
|
+
ttl=refresh_expires_in
|
|
1241
|
+
or expires_in, # Auto-expire when refresh token, or access token expires
|
|
1242
|
+
)
|
|
1243
|
+
logger.debug("Stored encrypted upstream tokens (jti=%s)", access_jti[:8])
|
|
1244
|
+
|
|
1245
|
+
# Issue minimal FastMCP access token (just a reference via JTI)
|
|
1246
|
+
if client.client_id is None:
|
|
1247
|
+
raise TokenError("invalid_client", "Client ID is required")
|
|
1248
|
+
fastmcp_access_token = self._jwt_issuer.issue_access_token(
|
|
582
1249
|
client_id=client.client_id,
|
|
583
1250
|
scopes=authorization_code.scopes,
|
|
584
|
-
|
|
1251
|
+
jti=access_jti,
|
|
1252
|
+
expires_in=expires_in,
|
|
585
1253
|
)
|
|
586
|
-
self._access_tokens[access_token_value] = access_token
|
|
587
1254
|
|
|
588
|
-
#
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
1255
|
+
# Issue minimal FastMCP refresh token if upstream provided one
|
|
1256
|
+
# Use upstream refresh token expiry to align lifetimes
|
|
1257
|
+
fastmcp_refresh_token = None
|
|
1258
|
+
if refresh_jti and refresh_expires_in:
|
|
1259
|
+
fastmcp_refresh_token = self._jwt_issuer.issue_refresh_token(
|
|
592
1260
|
client_id=client.client_id,
|
|
593
1261
|
scopes=authorization_code.scopes,
|
|
594
|
-
|
|
1262
|
+
jti=refresh_jti,
|
|
1263
|
+
expires_in=refresh_expires_in,
|
|
595
1264
|
)
|
|
596
|
-
self._refresh_tokens[refresh_token_value] = refresh_token
|
|
597
1265
|
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
1266
|
+
# Store JTI mappings
|
|
1267
|
+
await self._jti_mapping_store.put(
|
|
1268
|
+
key=access_jti,
|
|
1269
|
+
value=JTIMapping(
|
|
1270
|
+
jti=access_jti,
|
|
1271
|
+
upstream_token_id=upstream_token_id,
|
|
1272
|
+
created_at=time.time(),
|
|
1273
|
+
),
|
|
1274
|
+
ttl=expires_in, # Auto-expire with access token
|
|
1275
|
+
)
|
|
1276
|
+
if refresh_jti:
|
|
1277
|
+
await self._jti_mapping_store.put(
|
|
1278
|
+
key=refresh_jti,
|
|
1279
|
+
value=JTIMapping(
|
|
1280
|
+
jti=refresh_jti,
|
|
1281
|
+
upstream_token_id=upstream_token_id,
|
|
1282
|
+
created_at=time.time(),
|
|
1283
|
+
),
|
|
1284
|
+
ttl=60 * 60 * 24 * 30, # Auto-expire with refresh token (30 days)
|
|
1285
|
+
)
|
|
1286
|
+
|
|
1287
|
+
# Store refresh token metadata (keyed by hash for security)
|
|
1288
|
+
if fastmcp_refresh_token and refresh_expires_in:
|
|
1289
|
+
await self._refresh_token_store.put(
|
|
1290
|
+
key=_hash_token(fastmcp_refresh_token),
|
|
1291
|
+
value=RefreshTokenMetadata(
|
|
1292
|
+
client_id=client.client_id,
|
|
1293
|
+
scopes=authorization_code.scopes,
|
|
1294
|
+
expires_at=int(time.time()) + refresh_expires_in,
|
|
1295
|
+
created_at=time.time(),
|
|
1296
|
+
),
|
|
1297
|
+
ttl=refresh_expires_in,
|
|
1298
|
+
)
|
|
601
1299
|
|
|
602
1300
|
logger.debug(
|
|
603
|
-
"
|
|
1301
|
+
"Issued FastMCP tokens for client=%s (access_jti=%s, refresh_jti=%s)",
|
|
604
1302
|
client.client_id,
|
|
1303
|
+
access_jti[:8],
|
|
1304
|
+
refresh_jti[:8] if refresh_jti else "none",
|
|
605
1305
|
)
|
|
606
1306
|
|
|
607
|
-
|
|
1307
|
+
# Return FastMCP-issued tokens (NOT upstream tokens!)
|
|
1308
|
+
return OAuthToken(
|
|
1309
|
+
access_token=fastmcp_access_token,
|
|
1310
|
+
token_type="Bearer",
|
|
1311
|
+
expires_in=expires_in,
|
|
1312
|
+
refresh_token=fastmcp_refresh_token,
|
|
1313
|
+
scope=" ".join(authorization_code.scopes),
|
|
1314
|
+
)
|
|
608
1315
|
|
|
609
1316
|
# -------------------------------------------------------------------------
|
|
610
1317
|
# Refresh Token Flow
|
|
611
1318
|
# -------------------------------------------------------------------------
|
|
612
1319
|
|
|
1320
|
+
def _prepare_scopes_for_upstream_refresh(self, scopes: list[str]) -> list[str]:
|
|
1321
|
+
"""Prepare scopes for upstream token refresh request.
|
|
1322
|
+
|
|
1323
|
+
Override this method to transform scopes before sending to upstream provider.
|
|
1324
|
+
For example, Azure needs to prefix scopes and add additional Graph scopes.
|
|
1325
|
+
|
|
1326
|
+
The scopes parameter represents what should be stored in the RefreshToken.
|
|
1327
|
+
This method returns what should be sent to the upstream provider.
|
|
1328
|
+
|
|
1329
|
+
Args:
|
|
1330
|
+
scopes: Base scopes that will be stored in RefreshToken
|
|
1331
|
+
|
|
1332
|
+
Returns:
|
|
1333
|
+
Scopes to send to upstream provider (may be transformed/augmented)
|
|
1334
|
+
"""
|
|
1335
|
+
return scopes
|
|
1336
|
+
|
|
613
1337
|
async def load_refresh_token(
|
|
614
1338
|
self,
|
|
615
1339
|
client: OAuthClientInformationFull,
|
|
616
1340
|
refresh_token: str,
|
|
617
1341
|
) -> RefreshToken | None:
|
|
618
|
-
"""Load refresh token from
|
|
619
|
-
|
|
1342
|
+
"""Load refresh token metadata from distributed storage.
|
|
1343
|
+
|
|
1344
|
+
Looks up by token hash and reconstructs the RefreshToken object.
|
|
1345
|
+
Validates that the token belongs to the requesting client.
|
|
1346
|
+
"""
|
|
1347
|
+
token_hash = _hash_token(refresh_token)
|
|
1348
|
+
metadata = await self._refresh_token_store.get(key=token_hash)
|
|
1349
|
+
if not metadata:
|
|
1350
|
+
return None
|
|
1351
|
+
# Verify token belongs to this client (prevents cross-client token usage)
|
|
1352
|
+
if metadata.client_id != client.client_id:
|
|
1353
|
+
logger.warning(
|
|
1354
|
+
"Refresh token client_id mismatch: expected %s, got %s",
|
|
1355
|
+
client.client_id,
|
|
1356
|
+
metadata.client_id,
|
|
1357
|
+
)
|
|
1358
|
+
return None
|
|
1359
|
+
return RefreshToken(
|
|
1360
|
+
token=refresh_token,
|
|
1361
|
+
client_id=metadata.client_id,
|
|
1362
|
+
scopes=metadata.scopes,
|
|
1363
|
+
expires_at=metadata.expires_at,
|
|
1364
|
+
)
|
|
620
1365
|
|
|
621
1366
|
async def exchange_refresh_token(
|
|
622
1367
|
self,
|
|
@@ -624,9 +1369,45 @@ class OAuthProxy(OAuthProvider):
|
|
|
624
1369
|
refresh_token: RefreshToken,
|
|
625
1370
|
scopes: list[str],
|
|
626
1371
|
) -> OAuthToken:
|
|
627
|
-
"""Exchange refresh token for new access token
|
|
1372
|
+
"""Exchange FastMCP refresh token for new FastMCP access token.
|
|
1373
|
+
|
|
1374
|
+
Implements two-tier refresh:
|
|
1375
|
+
1. Verify FastMCP refresh token
|
|
1376
|
+
2. Look up upstream token via JTI mapping
|
|
1377
|
+
3. Refresh upstream token with upstream provider
|
|
1378
|
+
4. Update stored upstream token
|
|
1379
|
+
5. Issue new FastMCP access token
|
|
1380
|
+
6. Keep same FastMCP refresh token (unless upstream rotates)
|
|
1381
|
+
"""
|
|
1382
|
+
# Verify FastMCP refresh token
|
|
1383
|
+
try:
|
|
1384
|
+
refresh_payload = self._jwt_issuer.verify_token(refresh_token.token)
|
|
1385
|
+
refresh_jti = refresh_payload["jti"]
|
|
1386
|
+
except Exception as e:
|
|
1387
|
+
logger.debug("FastMCP refresh token validation failed: %s", e)
|
|
1388
|
+
raise TokenError("invalid_grant", "Invalid refresh token") from e
|
|
1389
|
+
|
|
1390
|
+
# Look up upstream token via JTI mapping
|
|
1391
|
+
jti_mapping = await self._jti_mapping_store.get(key=refresh_jti)
|
|
1392
|
+
if not jti_mapping:
|
|
1393
|
+
logger.error("JTI mapping not found for refresh token: %s", refresh_jti[:8])
|
|
1394
|
+
raise TokenError("invalid_grant", "Refresh token mapping not found")
|
|
628
1395
|
|
|
629
|
-
|
|
1396
|
+
upstream_token_set = await self._upstream_token_store.get(
|
|
1397
|
+
key=jti_mapping.upstream_token_id
|
|
1398
|
+
)
|
|
1399
|
+
if not upstream_token_set:
|
|
1400
|
+
logger.error(
|
|
1401
|
+
"Upstream token set not found: %s", jti_mapping.upstream_token_id[:8]
|
|
1402
|
+
)
|
|
1403
|
+
raise TokenError("invalid_grant", "Upstream token not found")
|
|
1404
|
+
|
|
1405
|
+
# Decrypt upstream refresh token
|
|
1406
|
+
if not upstream_token_set.refresh_token:
|
|
1407
|
+
logger.error("No upstream refresh token available")
|
|
1408
|
+
raise TokenError("invalid_grant", "Refresh not supported for this token")
|
|
1409
|
+
|
|
1410
|
+
# Refresh upstream token using authlib
|
|
630
1411
|
oauth_client = AsyncOAuth2Client(
|
|
631
1412
|
client_id=self._upstream_client_id,
|
|
632
1413
|
client_secret=self._upstream_client_secret.get_secret_value(),
|
|
@@ -634,78 +1415,209 @@ class OAuthProxy(OAuthProvider):
|
|
|
634
1415
|
timeout=HTTP_TIMEOUT_SECONDS,
|
|
635
1416
|
)
|
|
636
1417
|
|
|
637
|
-
|
|
638
|
-
|
|
1418
|
+
# Allow child classes to transform scopes before sending to upstream
|
|
1419
|
+
# This enables provider-specific scope formatting (e.g., Azure prefixing)
|
|
1420
|
+
# while keeping original scopes in storage
|
|
1421
|
+
upstream_scopes = self._prepare_scopes_for_upstream_refresh(scopes)
|
|
639
1422
|
|
|
640
|
-
|
|
1423
|
+
try:
|
|
1424
|
+
logger.debug("Refreshing upstream token (jti=%s)", refresh_jti[:8])
|
|
641
1425
|
token_response: dict[str, Any] = await oauth_client.refresh_token( # type: ignore[misc]
|
|
642
1426
|
url=self._upstream_token_endpoint,
|
|
643
|
-
refresh_token=refresh_token
|
|
644
|
-
scope=" ".join(
|
|
1427
|
+
refresh_token=upstream_token_set.refresh_token,
|
|
1428
|
+
scope=" ".join(upstream_scopes) if upstream_scopes else None,
|
|
1429
|
+
**self._extra_token_params,
|
|
645
1430
|
)
|
|
646
|
-
|
|
647
|
-
logger.debug(
|
|
648
|
-
"Successfully refreshed access token via authlib (client: %s)",
|
|
649
|
-
client.client_id,
|
|
650
|
-
)
|
|
651
|
-
|
|
1431
|
+
logger.debug("Successfully refreshed upstream token")
|
|
652
1432
|
except Exception as e:
|
|
653
|
-
logger.error("
|
|
654
|
-
raise TokenError(
|
|
655
|
-
"invalid_grant", f"Upstream refresh token exchange failed: {e}"
|
|
656
|
-
) from e
|
|
1433
|
+
logger.error("Upstream token refresh failed: %s", e)
|
|
1434
|
+
raise TokenError("invalid_grant", f"Upstream refresh failed: {e}") from e
|
|
657
1435
|
|
|
658
|
-
# Update
|
|
659
|
-
|
|
660
|
-
expires_in = int(
|
|
1436
|
+
# Update stored upstream token
|
|
1437
|
+
new_expires_in = int(
|
|
661
1438
|
token_response.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
|
|
662
1439
|
)
|
|
1440
|
+
upstream_token_set.access_token = token_response["access_token"]
|
|
1441
|
+
upstream_token_set.expires_at = time.time() + new_expires_in
|
|
1442
|
+
|
|
1443
|
+
# Handle upstream refresh token rotation and expiry
|
|
1444
|
+
new_refresh_expires_in = None
|
|
1445
|
+
if new_upstream_refresh := token_response.get("refresh_token"):
|
|
1446
|
+
if new_upstream_refresh != upstream_token_set.refresh_token:
|
|
1447
|
+
upstream_token_set.refresh_token = new_upstream_refresh
|
|
1448
|
+
logger.debug("Upstream refresh token rotated")
|
|
1449
|
+
|
|
1450
|
+
# Update refresh token expiry if provided
|
|
1451
|
+
if "refresh_expires_in" in token_response:
|
|
1452
|
+
new_refresh_expires_in = int(token_response["refresh_expires_in"])
|
|
1453
|
+
upstream_token_set.refresh_token_expires_at = (
|
|
1454
|
+
time.time() + new_refresh_expires_in
|
|
1455
|
+
)
|
|
1456
|
+
logger.debug(
|
|
1457
|
+
"Upstream refresh token expires in %d seconds",
|
|
1458
|
+
new_refresh_expires_in,
|
|
1459
|
+
)
|
|
1460
|
+
elif upstream_token_set.refresh_token_expires_at:
|
|
1461
|
+
# Keep existing expiry if upstream doesn't provide new one
|
|
1462
|
+
new_refresh_expires_in = int(
|
|
1463
|
+
upstream_token_set.refresh_token_expires_at - time.time()
|
|
1464
|
+
)
|
|
1465
|
+
else:
|
|
1466
|
+
# Default to 30 days if unknown
|
|
1467
|
+
new_refresh_expires_in = 60 * 60 * 24 * 30
|
|
1468
|
+
upstream_token_set.refresh_token_expires_at = (
|
|
1469
|
+
time.time() + new_refresh_expires_in
|
|
1470
|
+
)
|
|
663
1471
|
|
|
664
|
-
|
|
665
|
-
|
|
1472
|
+
upstream_token_set.raw_token_data = token_response
|
|
1473
|
+
await self._upstream_token_store.put(
|
|
1474
|
+
key=upstream_token_set.upstream_token_id,
|
|
1475
|
+
value=upstream_token_set,
|
|
1476
|
+
ttl=new_refresh_expires_in
|
|
1477
|
+
or (
|
|
1478
|
+
int(upstream_token_set.refresh_token_expires_at - time.time())
|
|
1479
|
+
if upstream_token_set.refresh_token_expires_at
|
|
1480
|
+
else 60 * 60 * 24 * 30 # Default to 30 days if unknown
|
|
1481
|
+
), # Auto-expire when refresh token expires
|
|
1482
|
+
)
|
|
1483
|
+
|
|
1484
|
+
# Issue new minimal FastMCP access token (just a reference via JTI)
|
|
1485
|
+
if client.client_id is None:
|
|
1486
|
+
raise TokenError("invalid_client", "Client ID is required")
|
|
1487
|
+
new_access_jti = secrets.token_urlsafe(32)
|
|
1488
|
+
new_fastmcp_access = self._jwt_issuer.issue_access_token(
|
|
666
1489
|
client_id=client.client_id,
|
|
667
1490
|
scopes=scopes,
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
self._refresh_tokens[new_refresh_token] = RefreshToken(
|
|
683
|
-
token=new_refresh_token,
|
|
684
|
-
client_id=client.client_id,
|
|
685
|
-
scopes=scopes,
|
|
686
|
-
expires_at=None,
|
|
687
|
-
)
|
|
688
|
-
self._access_to_refresh[new_access_token] = new_refresh_token
|
|
689
|
-
self._refresh_to_access[new_refresh_token] = new_access_token
|
|
1491
|
+
jti=new_access_jti,
|
|
1492
|
+
expires_in=new_expires_in,
|
|
1493
|
+
)
|
|
1494
|
+
|
|
1495
|
+
# Store new access token JTI mapping
|
|
1496
|
+
await self._jti_mapping_store.put(
|
|
1497
|
+
key=new_access_jti,
|
|
1498
|
+
value=JTIMapping(
|
|
1499
|
+
jti=new_access_jti,
|
|
1500
|
+
upstream_token_id=upstream_token_set.upstream_token_id,
|
|
1501
|
+
created_at=time.time(),
|
|
1502
|
+
),
|
|
1503
|
+
ttl=new_expires_in, # Auto-expire with refreshed access token
|
|
1504
|
+
)
|
|
690
1505
|
|
|
691
|
-
|
|
1506
|
+
# Issue NEW minimal FastMCP refresh token (rotation for security)
|
|
1507
|
+
# Use upstream refresh token expiry to align lifetimes
|
|
1508
|
+
new_refresh_jti = secrets.token_urlsafe(32)
|
|
1509
|
+
new_fastmcp_refresh = self._jwt_issuer.issue_refresh_token(
|
|
1510
|
+
client_id=client.client_id,
|
|
1511
|
+
scopes=scopes,
|
|
1512
|
+
jti=new_refresh_jti,
|
|
1513
|
+
expires_in=new_refresh_expires_in
|
|
1514
|
+
or 60 * 60 * 24 * 30, # Fallback to 30 days
|
|
1515
|
+
)
|
|
1516
|
+
|
|
1517
|
+
# Store new refresh token JTI mapping with aligned expiry
|
|
1518
|
+
refresh_ttl = new_refresh_expires_in or 60 * 60 * 24 * 30
|
|
1519
|
+
await self._jti_mapping_store.put(
|
|
1520
|
+
key=new_refresh_jti,
|
|
1521
|
+
value=JTIMapping(
|
|
1522
|
+
jti=new_refresh_jti,
|
|
1523
|
+
upstream_token_id=upstream_token_set.upstream_token_id,
|
|
1524
|
+
created_at=time.time(),
|
|
1525
|
+
),
|
|
1526
|
+
ttl=refresh_ttl, # Align with upstream refresh token expiry
|
|
1527
|
+
)
|
|
1528
|
+
|
|
1529
|
+
# Invalidate old refresh token (refresh token rotation - enforces one-time use)
|
|
1530
|
+
await self._jti_mapping_store.delete(key=refresh_jti)
|
|
1531
|
+
logger.debug(
|
|
1532
|
+
"Rotated refresh token (old JTI invalidated - one-time use enforced)"
|
|
1533
|
+
)
|
|
1534
|
+
|
|
1535
|
+
# Store new refresh token metadata (keyed by hash)
|
|
1536
|
+
await self._refresh_token_store.put(
|
|
1537
|
+
key=_hash_token(new_fastmcp_refresh),
|
|
1538
|
+
value=RefreshTokenMetadata(
|
|
1539
|
+
client_id=client.client_id,
|
|
1540
|
+
scopes=scopes,
|
|
1541
|
+
expires_at=int(time.time()) + refresh_ttl,
|
|
1542
|
+
created_at=time.time(),
|
|
1543
|
+
),
|
|
1544
|
+
ttl=refresh_ttl,
|
|
1545
|
+
)
|
|
1546
|
+
|
|
1547
|
+
# Delete old refresh token (by hash)
|
|
1548
|
+
await self._refresh_token_store.delete(key=_hash_token(refresh_token.token))
|
|
1549
|
+
|
|
1550
|
+
logger.info(
|
|
1551
|
+
"Issued new FastMCP tokens (rotated refresh) for client=%s (access_jti=%s, refresh_jti=%s)",
|
|
1552
|
+
client.client_id,
|
|
1553
|
+
new_access_jti[:8],
|
|
1554
|
+
new_refresh_jti[:8],
|
|
1555
|
+
)
|
|
1556
|
+
|
|
1557
|
+
# Return new FastMCP tokens (both access AND refresh are new)
|
|
1558
|
+
return OAuthToken(
|
|
1559
|
+
access_token=new_fastmcp_access,
|
|
1560
|
+
token_type="Bearer",
|
|
1561
|
+
expires_in=new_expires_in,
|
|
1562
|
+
refresh_token=new_fastmcp_refresh, # NEW refresh token (rotated)
|
|
1563
|
+
scope=" ".join(scopes),
|
|
1564
|
+
)
|
|
692
1565
|
|
|
693
1566
|
# -------------------------------------------------------------------------
|
|
694
1567
|
# Token Validation
|
|
695
1568
|
# -------------------------------------------------------------------------
|
|
696
1569
|
|
|
697
1570
|
async def load_access_token(self, token: str) -> AccessToken | None:
    """Validate a FastMCP JWT by swapping it for the upstream token.

    This implements the token swap pattern:
    1. Verify the FastMCP JWT with the JWT issuer and read its ``jti`` claim
    2. Look up the JTI mapping to find the upstream token id
    3. Load the stored upstream token set for that id
    4. Validate the upstream access token with the configured token verifier
       (GitHub API, JWT validation, etc.)
    5. Return the upstream validation result

    The FastMCP JWT is a reference token - all authorization data comes
    from validating the upstream token via the TokenVerifier.

    Args:
        token: The FastMCP-issued JWT presented by the client.

    Returns:
        The result of the upstream token verifier, or ``None`` when any step
        fails. Exceptions raised along the way are logged at debug level and
        reported as ``None`` rather than propagated.
    """
    try:
        # 1. Verify FastMCP JWT signature and claims
        # NOTE(review): the refresh path calls this same verify_token() on
        # refresh tokens - confirm the issuer distinguishes token types so a
        # FastMCP refresh JWT cannot be presented here as an access token.
        payload = self._jwt_issuer.verify_token(token)
        jti = payload["jti"]

        # 2. Look up upstream token via JTI mapping
        jti_mapping = await self._jti_mapping_store.get(key=jti)
        if not jti_mapping:
            # Identifiers are truncated in logs, matching the other log sites.
            logger.debug("JTI mapping not found: %s", jti[:8])
            return None

        # 3. Load the upstream token set the mapping points at
        upstream_token_set = await self._upstream_token_store.get(
            key=jti_mapping.upstream_token_id
        )
        if not upstream_token_set:
            logger.debug(
                "Upstream token not found: %s", jti_mapping.upstream_token_id[:8]
            )
            return None

        # 4. Validate with upstream provider (delegated to TokenVerifier)
        # This calls the real token validator (GitHub API, JWKS, etc.)
        validated = await self._token_validator.verify_token(
            upstream_token_set.access_token
        )

        if not validated:
            logger.debug("Upstream token validation failed")
            return None

        logger.debug(
            "Token swap successful for JTI=%s (upstream validated)", jti[:8]
        )
        # 5. Hand back whatever the upstream verifier produced
        return validated

    except Exception as e:
        # Deliberately broad: any failure means "token not valid".
        logger.debug("Token swap validation failed: %s", e)
        return None
|
|
709
1621
|
|
|
710
1622
|
# -------------------------------------------------------------------------
|
|
711
1623
|
# Token Revocation
|
|
@@ -714,24 +1626,13 @@ class OAuthProxy(OAuthProvider):
|
|
|
714
1626
|
async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
|
|
715
1627
|
"""Revoke token locally and with upstream server if supported.
|
|
716
1628
|
|
|
717
|
-
|
|
718
|
-
|
|
1629
|
+
For refresh tokens, removes from local storage by hash.
|
|
1630
|
+
For all tokens, attempts upstream revocation if endpoint is configured.
|
|
1631
|
+
Access token JTI mappings expire via TTL.
|
|
719
1632
|
"""
|
|
720
|
-
#
|
|
721
|
-
if isinstance(token,
|
|
722
|
-
self.
|
|
723
|
-
# Also remove associated refresh token
|
|
724
|
-
paired_refresh = self._access_to_refresh.pop(token.token, None)
|
|
725
|
-
if paired_refresh:
|
|
726
|
-
self._refresh_tokens.pop(paired_refresh, None)
|
|
727
|
-
self._refresh_to_access.pop(paired_refresh, None)
|
|
728
|
-
else: # RefreshToken
|
|
729
|
-
self._refresh_tokens.pop(token.token, None)
|
|
730
|
-
# Also remove associated access token
|
|
731
|
-
paired_access = self._refresh_to_access.pop(token.token, None)
|
|
732
|
-
if paired_access:
|
|
733
|
-
self._access_tokens.pop(paired_access, None)
|
|
734
|
-
self._access_to_refresh.pop(paired_access, None)
|
|
1633
|
+
# For refresh tokens, delete from local storage by hash
|
|
1634
|
+
if isinstance(token, RefreshToken):
|
|
1635
|
+
await self._refresh_token_store.delete(key=_hash_token(token.token))
|
|
735
1636
|
|
|
736
1637
|
# Attempt upstream revocation if endpoint is configured
|
|
737
1638
|
if self._upstream_revocation_endpoint:
|
|
@@ -758,21 +1659,22 @@ class OAuthProxy(OAuthProvider):
|
|
|
758
1659
|
def get_routes(
|
|
759
1660
|
self,
|
|
760
1661
|
mcp_path: str | None = None,
|
|
761
|
-
mcp_endpoint: Any | None = None,
|
|
762
1662
|
) -> list[Route]:
|
|
763
|
-
"""Get OAuth routes with custom
|
|
1663
|
+
"""Get OAuth routes with custom handlers for better error UX.
|
|
764
1664
|
|
|
765
|
-
This method creates standard OAuth routes and replaces
|
|
766
|
-
|
|
1665
|
+
This method creates standard OAuth routes and replaces:
|
|
1666
|
+
- /authorize endpoint: Enhanced error responses for unregistered clients
|
|
1667
|
+
- /token endpoint: OAuth 2.1 compliant error codes
|
|
767
1668
|
|
|
768
1669
|
Args:
|
|
769
1670
|
mcp_path: The path where the MCP endpoint is mounted (e.g., "/mcp")
|
|
770
|
-
|
|
1671
|
+
This is used to advertise the resource URL in metadata.
|
|
771
1672
|
"""
|
|
772
1673
|
# Get standard OAuth routes from parent class
|
|
773
|
-
routes = super().get_routes(mcp_path
|
|
1674
|
+
routes = super().get_routes(mcp_path)
|
|
774
1675
|
custom_routes = []
|
|
775
1676
|
token_route_found = False
|
|
1677
|
+
authorize_route_found = False
|
|
776
1678
|
|
|
777
1679
|
logger.debug(
|
|
778
1680
|
f"get_routes called - configuring OAuth routes in {len(routes)} routes"
|
|
@@ -783,16 +1685,53 @@ class OAuthProxy(OAuthProvider):
|
|
|
783
1685
|
f"Route {i}: {route} - path: {getattr(route, 'path', 'N/A')}, methods: {getattr(route, 'methods', 'N/A')}"
|
|
784
1686
|
)
|
|
785
1687
|
|
|
786
|
-
#
|
|
787
|
-
custom_routes.append(route)
|
|
788
|
-
|
|
1688
|
+
# Replace the authorize endpoint with our enhanced handler for better error UX
|
|
789
1689
|
if (
|
|
1690
|
+
isinstance(route, Route)
|
|
1691
|
+
and route.path == "/authorize"
|
|
1692
|
+
and route.methods is not None
|
|
1693
|
+
and ("GET" in route.methods or "POST" in route.methods)
|
|
1694
|
+
):
|
|
1695
|
+
authorize_route_found = True
|
|
1696
|
+
# Replace with our enhanced authorization handler
|
|
1697
|
+
# Note: self.base_url is guaranteed to be set in parent __init__
|
|
1698
|
+
authorize_handler = AuthorizationHandler(
|
|
1699
|
+
provider=self,
|
|
1700
|
+
base_url=self.base_url, # ty: ignore[invalid-argument-type]
|
|
1701
|
+
server_name=None, # Could be extended to pass server metadata
|
|
1702
|
+
server_icon_url=None,
|
|
1703
|
+
)
|
|
1704
|
+
custom_routes.append(
|
|
1705
|
+
Route(
|
|
1706
|
+
path="/authorize",
|
|
1707
|
+
endpoint=authorize_handler.handle,
|
|
1708
|
+
methods=["GET", "POST"],
|
|
1709
|
+
)
|
|
1710
|
+
)
|
|
1711
|
+
# Replace the token endpoint with our custom handler that returns proper OAuth 2.1 error codes
|
|
1712
|
+
elif (
|
|
790
1713
|
isinstance(route, Route)
|
|
791
1714
|
and route.path == "/token"
|
|
792
1715
|
and route.methods is not None
|
|
793
1716
|
and "POST" in route.methods
|
|
794
1717
|
):
|
|
795
1718
|
token_route_found = True
|
|
1719
|
+
# Replace with our OAuth 2.1 compliant token handler
|
|
1720
|
+
token_handler = TokenHandler(
|
|
1721
|
+
provider=self, client_authenticator=ClientAuthenticator(self)
|
|
1722
|
+
)
|
|
1723
|
+
custom_routes.append(
|
|
1724
|
+
Route(
|
|
1725
|
+
path="/token",
|
|
1726
|
+
endpoint=cors_middleware(
|
|
1727
|
+
token_handler.handle, ["POST", "OPTIONS"]
|
|
1728
|
+
),
|
|
1729
|
+
methods=["POST", "OPTIONS"],
|
|
1730
|
+
)
|
|
1731
|
+
)
|
|
1732
|
+
else:
|
|
1733
|
+
# Keep all other standard OAuth routes unchanged
|
|
1734
|
+
custom_routes.append(route)
|
|
796
1735
|
|
|
797
1736
|
# Add OAuth callback endpoint for forwarding to client callbacks
|
|
798
1737
|
custom_routes.append(
|
|
@@ -803,8 +1742,16 @@ class OAuthProxy(OAuthProvider):
|
|
|
803
1742
|
)
|
|
804
1743
|
)
|
|
805
1744
|
|
|
1745
|
+
# Add consent endpoints
|
|
1746
|
+
# Handle both GET (show page) and POST (submit) at /consent
|
|
1747
|
+
custom_routes.append(
|
|
1748
|
+
Route(
|
|
1749
|
+
path="/consent", endpoint=self._handle_consent, methods=["GET", "POST"]
|
|
1750
|
+
)
|
|
1751
|
+
)
|
|
1752
|
+
|
|
806
1753
|
logger.debug(
|
|
807
|
-
f"✅ OAuth routes configured: token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback)"
|
|
1754
|
+
f"✅ OAuth routes configured: authorize_endpoint={authorize_route_found}, token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback + consent)"
|
|
808
1755
|
)
|
|
809
1756
|
return custom_routes
|
|
810
1757
|
|
|
@@ -812,7 +1759,9 @@ class OAuthProxy(OAuthProvider):
|
|
|
812
1759
|
# IdP Callback Forwarding
|
|
813
1760
|
# -------------------------------------------------------------------------
|
|
814
1761
|
|
|
815
|
-
async def _handle_idp_callback(
|
|
1762
|
+
async def _handle_idp_callback(
|
|
1763
|
+
self, request: Request
|
|
1764
|
+
) -> HTMLResponse | RedirectResponse:
|
|
816
1765
|
"""Handle callback from upstream IdP and forward to client.
|
|
817
1766
|
|
|
818
1767
|
This implements the DCR-compliant callback forwarding:
|
|
@@ -827,32 +1776,38 @@ class OAuthProxy(OAuthProvider):
|
|
|
827
1776
|
error = request.query_params.get("error")
|
|
828
1777
|
|
|
829
1778
|
if error:
|
|
1779
|
+
error_description = request.query_params.get("error_description")
|
|
830
1780
|
logger.error(
|
|
831
1781
|
"IdP callback error: %s - %s",
|
|
832
1782
|
error,
|
|
833
|
-
|
|
1783
|
+
error_description,
|
|
834
1784
|
)
|
|
835
|
-
#
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
1785
|
+
# Show error page to user
|
|
1786
|
+
html_content = create_error_html(
|
|
1787
|
+
error_title="OAuth Error",
|
|
1788
|
+
error_message=f"Authentication failed: {error_description or 'Unknown error'}",
|
|
1789
|
+
error_details={"Error Code": error} if error else None,
|
|
839
1790
|
)
|
|
1791
|
+
return HTMLResponse(content=html_content, status_code=400)
|
|
840
1792
|
|
|
841
1793
|
if not idp_code or not txn_id:
|
|
842
1794
|
logger.error("IdP callback missing code or transaction ID")
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
1795
|
+
html_content = create_error_html(
|
|
1796
|
+
error_title="OAuth Error",
|
|
1797
|
+
error_message="Missing authorization code or transaction ID from the identity provider.",
|
|
846
1798
|
)
|
|
1799
|
+
return HTMLResponse(content=html_content, status_code=400)
|
|
847
1800
|
|
|
848
1801
|
# Look up transaction data
|
|
849
|
-
|
|
850
|
-
if not
|
|
1802
|
+
transaction_model = await self._transaction_store.get(key=txn_id)
|
|
1803
|
+
if not transaction_model:
|
|
851
1804
|
logger.error("IdP callback with invalid transaction ID: %s", txn_id)
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
1805
|
+
html_content = create_error_html(
|
|
1806
|
+
error_title="OAuth Error",
|
|
1807
|
+
error_message="Invalid or expired authorization transaction. Please try authenticating again.",
|
|
855
1808
|
)
|
|
1809
|
+
return HTMLResponse(content=html_content, status_code=400)
|
|
1810
|
+
transaction = transaction_model.model_dump()
|
|
856
1811
|
|
|
857
1812
|
# Exchange IdP code for tokens (server-side)
|
|
858
1813
|
oauth_client = AsyncOAuth2Client(
|
|
@@ -870,56 +1825,70 @@ class OAuthProxy(OAuthProvider):
|
|
|
870
1825
|
f"Exchanging IdP code for tokens with redirect_uri: {idp_redirect_uri}"
|
|
871
1826
|
)
|
|
872
1827
|
|
|
1828
|
+
# Build token exchange parameters
|
|
1829
|
+
token_params = {
|
|
1830
|
+
"url": self._upstream_token_endpoint,
|
|
1831
|
+
"code": idp_code,
|
|
1832
|
+
"redirect_uri": idp_redirect_uri,
|
|
1833
|
+
}
|
|
1834
|
+
|
|
873
1835
|
# Include proxy's code_verifier if we forwarded PKCE
|
|
874
1836
|
proxy_code_verifier = transaction.get("proxy_code_verifier")
|
|
875
1837
|
if proxy_code_verifier:
|
|
1838
|
+
token_params["code_verifier"] = proxy_code_verifier
|
|
876
1839
|
logger.debug(
|
|
877
1840
|
"Including proxy code_verifier in token exchange for transaction %s",
|
|
878
1841
|
txn_id,
|
|
879
1842
|
)
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
url=self._upstream_token_endpoint,
|
|
889
|
-
code=idp_code,
|
|
890
|
-
redirect_uri=idp_redirect_uri,
|
|
1843
|
+
|
|
1844
|
+
# Add any extra token parameters configured for this proxy
|
|
1845
|
+
if self._extra_token_params:
|
|
1846
|
+
token_params.update(self._extra_token_params)
|
|
1847
|
+
logger.debug(
|
|
1848
|
+
"Adding extra token parameters for transaction %s: %s",
|
|
1849
|
+
txn_id,
|
|
1850
|
+
list(self._extra_token_params.keys()),
|
|
891
1851
|
)
|
|
892
1852
|
|
|
1853
|
+
idp_tokens: dict[str, Any] = await oauth_client.fetch_token(
|
|
1854
|
+
**token_params
|
|
1855
|
+
) # type: ignore[misc]
|
|
1856
|
+
|
|
893
1857
|
logger.debug(
|
|
894
1858
|
f"Successfully exchanged IdP code for tokens (transaction: {txn_id}, PKCE: {bool(proxy_code_verifier)})"
|
|
895
1859
|
)
|
|
896
1860
|
|
|
897
1861
|
except Exception as e:
|
|
898
1862
|
logger.error("IdP token exchange failed: %s", e)
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
status_code=302,
|
|
1863
|
+
html_content = create_error_html(
|
|
1864
|
+
error_title="OAuth Error",
|
|
1865
|
+
error_message=f"Token exchange with identity provider failed: {e}",
|
|
903
1866
|
)
|
|
1867
|
+
return HTMLResponse(content=html_content, status_code=500)
|
|
904
1868
|
|
|
905
1869
|
# Generate our own authorization code for the client
|
|
906
1870
|
client_code = secrets.token_urlsafe(32)
|
|
907
1871
|
code_expires_at = int(time.time() + DEFAULT_AUTH_CODE_EXPIRY_SECONDS)
|
|
908
1872
|
|
|
909
1873
|
# Store client code with PKCE challenge and IdP tokens
|
|
910
|
-
self.
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
1874
|
+
await self._code_store.put(
|
|
1875
|
+
key=client_code,
|
|
1876
|
+
value=ClientCode(
|
|
1877
|
+
code=client_code,
|
|
1878
|
+
client_id=transaction["client_id"],
|
|
1879
|
+
redirect_uri=transaction["client_redirect_uri"],
|
|
1880
|
+
code_challenge=transaction["code_challenge"],
|
|
1881
|
+
code_challenge_method=transaction["code_challenge_method"],
|
|
1882
|
+
scopes=transaction["scopes"],
|
|
1883
|
+
idp_tokens=idp_tokens,
|
|
1884
|
+
expires_at=code_expires_at,
|
|
1885
|
+
created_at=time.time(),
|
|
1886
|
+
),
|
|
1887
|
+
ttl=DEFAULT_AUTH_CODE_EXPIRY_SECONDS, # Auto-expire after 5 minutes
|
|
1888
|
+
)
|
|
920
1889
|
|
|
921
1890
|
# Clean up transaction
|
|
922
|
-
self.
|
|
1891
|
+
await self._transaction_store.delete(key=txn_id)
|
|
923
1892
|
|
|
924
1893
|
# Build client callback URL with our code and original state
|
|
925
1894
|
client_redirect_uri = transaction["client_redirect_uri"]
|
|
@@ -942,7 +1911,329 @@ class OAuthProxy(OAuthProvider):
|
|
|
942
1911
|
|
|
943
1912
|
except Exception as e:
|
|
944
1913
|
logger.error("Error in IdP callback handler: %s", e, exc_info=True)
|
|
1914
|
+
html_content = create_error_html(
|
|
1915
|
+
error_title="OAuth Error",
|
|
1916
|
+
error_message="Internal server error during OAuth callback processing. Please try again.",
|
|
1917
|
+
)
|
|
1918
|
+
return HTMLResponse(content=html_content, status_code=500)
|
|
1919
|
+
|
|
1920
|
+
# -------------------------------------------------------------------------
|
|
1921
|
+
# Consent Interstitial
|
|
1922
|
+
# -------------------------------------------------------------------------
|
|
1923
|
+
|
|
1924
|
+
def _normalize_uri(self, uri: str) -> str:
|
|
1925
|
+
"""Normalize a URI to a canonical form for consent tracking."""
|
|
1926
|
+
parsed = urlparse(uri)
|
|
1927
|
+
path = parsed.path or ""
|
|
1928
|
+
normalized = f"{parsed.scheme.lower()}://{parsed.netloc.lower()}{path}"
|
|
1929
|
+
if normalized.endswith("/") and len(path) > 1:
|
|
1930
|
+
normalized = normalized[:-1]
|
|
1931
|
+
return normalized
|
|
1932
|
+
|
|
1933
|
+
def _make_client_key(self, client_id: str, redirect_uri: str | AnyUrl) -> str:
    """Build the consent-tracking key for a client / redirect URI pair.

    Args:
        client_id: Identifier of the OAuth client.
        redirect_uri: The client's redirect URI; it is converted with
            ``str()`` and passed through ``_normalize_uri`` first.

    Returns:
        ``"<client_id>:<normalized redirect URI>"``.
    """
    canonical_uri = self._normalize_uri(str(redirect_uri))
    return "{}:{}".format(client_id, canonical_uri)
|
|
1937
|
+
|
|
1938
|
+
def _cookie_name(self, base_name: str) -> str:
|
|
1939
|
+
"""Return secure cookie name for HTTPS, fallback for HTTP development."""
|
|
1940
|
+
if self._is_https:
|
|
1941
|
+
return f"__Host-{base_name}"
|
|
1942
|
+
return f"__{base_name}"
|
|
1943
|
+
|
|
1944
|
+
def _sign_cookie(self, payload: str) -> str:
|
|
1945
|
+
"""Sign a cookie payload with HMAC-SHA256.
|
|
1946
|
+
|
|
1947
|
+
Returns: base64(payload).base64(signature)
|
|
1948
|
+
"""
|
|
1949
|
+
# Use upstream client secret as signing key
|
|
1950
|
+
key = self._upstream_client_secret.get_secret_value().encode()
|
|
1951
|
+
signature = hmac.new(key, payload.encode(), hashlib.sha256).digest()
|
|
1952
|
+
signature_b64 = base64.b64encode(signature).decode()
|
|
1953
|
+
return f"{payload}.{signature_b64}"
|
|
1954
|
+
|
|
1955
|
+
def _verify_cookie(self, signed_value: str) -> str | None:
|
|
1956
|
+
"""Verify and extract payload from signed cookie.
|
|
1957
|
+
|
|
1958
|
+
Returns: payload if signature valid, None otherwise
|
|
1959
|
+
"""
|
|
1960
|
+
try:
|
|
1961
|
+
if "." not in signed_value:
|
|
1962
|
+
return None
|
|
1963
|
+
payload, signature_b64 = signed_value.rsplit(".", 1)
|
|
1964
|
+
|
|
1965
|
+
# Verify signature
|
|
1966
|
+
key = self._upstream_client_secret.get_secret_value().encode()
|
|
1967
|
+
expected_sig = hmac.new(key, payload.encode(), hashlib.sha256).digest()
|
|
1968
|
+
provided_sig = base64.b64decode(signature_b64.encode())
|
|
1969
|
+
|
|
1970
|
+
# Constant-time comparison
|
|
1971
|
+
if not hmac.compare_digest(expected_sig, provided_sig):
|
|
1972
|
+
return None
|
|
1973
|
+
|
|
1974
|
+
return payload
|
|
1975
|
+
except Exception:
|
|
1976
|
+
return None
|
|
1977
|
+
|
|
1978
|
+
def _decode_list_cookie(self, request: Request, base_name: str) -> list[str]:
    """Read a signed, base64-encoded JSON list from the request's cookies.

    The cookie is looked up under the name from ``_cookie_name`` first and
    then under the plain ``__<base_name>`` name. Its signature is checked
    with ``_verify_cookie`` before the payload is base64- and JSON-decoded.

    Args:
        request: Incoming request whose cookies are inspected.
        base_name: Cookie name without prefix.

    Returns:
        The decoded list with every element converted to ``str``; an empty
        list when the cookie is missing, fails verification, cannot be
        decoded, or does not contain a JSON list.
    """
    # Prefer secure name, but also check non-secure variant for dev
    cookie_name = self._cookie_name(base_name)
    jar = request.cookies
    signed = jar.get(cookie_name) or jar.get(f"__{base_name}")
    if not signed:
        return []

    try:
        encoded = self._verify_cookie(signed)
        if not encoded:
            logger.debug("Cookie signature verification failed for %s", cookie_name)
            return []

        decoded = json.loads(base64.b64decode(encoded.encode()).decode())
        # Anything other than a JSON list is treated as "no entries".
        entries = list(map(str, decoded)) if isinstance(decoded, list) else []
    except Exception:
        logger.debug("Failed to decode cookie %s; treating as empty", cookie_name)
        entries = []
    return entries
|
|
2000
|
+
|
|
2001
|
+
def _encode_list_cookie(self, values: list[str]) -> str:
|
|
2002
|
+
"""Encode values to base64 and sign with HMAC.
|
|
2003
|
+
|
|
2004
|
+
Returns: signed cookie value (payload.signature)
|
|
2005
|
+
"""
|
|
2006
|
+
payload = json.dumps(values, separators=(",", ":")).encode()
|
|
2007
|
+
payload_b64 = base64.b64encode(payload).decode()
|
|
2008
|
+
return self._sign_cookie(payload_b64)
|
|
2009
|
+
|
|
2010
|
+
def _set_list_cookie(
    self,
    response: HTMLResponse | RedirectResponse,
    base_name: str,
    value_b64: str,
    max_age: int,
) -> None:
    """Attach a list cookie to ``response``.

    The cookie is stored under the name from ``_cookie_name`` with
    ``HttpOnly``, ``SameSite=lax`` and ``Path=/``; the ``Secure`` flag
    follows ``self._is_https``.

    Args:
        response: Response object that receives the cookie.
        base_name: Cookie name without prefix.
        value_b64: Cookie value, written as given.
        max_age: Cookie lifetime passed through as ``max_age``.
    """
    cookie_name = self._cookie_name(base_name)
    attributes = {
        "max_age": max_age,
        "secure": self._is_https,
        "httponly": True,
        "samesite": "lax",
        "path": "/",
    }
    response.set_cookie(cookie_name, value_b64, **attributes)
|
|
2027
|
+
|
|
2028
|
+
def _build_upstream_authorize_url(
|
|
2029
|
+
self, txn_id: str, transaction: dict[str, Any]
|
|
2030
|
+
) -> str:
|
|
2031
|
+
"""Construct the upstream IdP authorization URL using stored transaction data."""
|
|
2032
|
+
query_params: dict[str, Any] = {
|
|
2033
|
+
"response_type": "code",
|
|
2034
|
+
"client_id": self._upstream_client_id,
|
|
2035
|
+
"redirect_uri": f"{str(self.base_url).rstrip('/')}{self._redirect_path}",
|
|
2036
|
+
"state": txn_id,
|
|
2037
|
+
}
|
|
2038
|
+
|
|
2039
|
+
scopes_to_use = transaction.get("scopes") or self.required_scopes or []
|
|
2040
|
+
if scopes_to_use:
|
|
2041
|
+
query_params["scope"] = " ".join(scopes_to_use)
|
|
2042
|
+
|
|
2043
|
+
# If PKCE forwarding was enabled, include the proxy challenge
|
|
2044
|
+
proxy_code_verifier = transaction.get("proxy_code_verifier")
|
|
2045
|
+
if proxy_code_verifier:
|
|
2046
|
+
challenge_bytes = hashlib.sha256(proxy_code_verifier.encode()).digest()
|
|
2047
|
+
proxy_code_challenge = (
|
|
2048
|
+
urlsafe_b64encode(challenge_bytes).decode().rstrip("=")
|
|
2049
|
+
)
|
|
2050
|
+
query_params["code_challenge"] = proxy_code_challenge
|
|
2051
|
+
query_params["code_challenge_method"] = "S256"
|
|
2052
|
+
|
|
2053
|
+
# Forward resource indicator if present in transaction
|
|
2054
|
+
if resource := transaction.get("resource"):
|
|
2055
|
+
query_params["resource"] = resource
|
|
2056
|
+
|
|
2057
|
+
# Extra configured parameters
|
|
2058
|
+
if self._extra_authorize_params:
|
|
2059
|
+
query_params.update(self._extra_authorize_params)
|
|
2060
|
+
|
|
2061
|
+
separator = "&" if "?" in self._upstream_authorization_endpoint else "?"
|
|
2062
|
+
return f"{self._upstream_authorization_endpoint}{separator}{urlencode(query_params)}"
|
|
2063
|
+
|
|
2064
|
+
async def _handle_consent(
    self, request: Request
) -> HTMLResponse | RedirectResponse:
    """Route a ``/consent`` request to the matching handler.

    POST requests go to ``_submit_consent``; every other method goes to
    ``_show_consent_page``.

    Args:
        request: The incoming consent request.

    Returns:
        Whatever response the selected handler produces.
    """
    handler = (
        self._submit_consent
        if request.method == "POST"
        else self._show_consent_page
    )
    return await handler(request)
|
|
2071
|
+
|
|
2072
|
+
async def _show_consent_page(
    self, request: Request
) -> HTMLResponse | RedirectResponse:
    """Show the consent screen, or short-circuit on a remembered decision.

    The transaction is looked up from the ``txn_id`` query parameter. If the
    client key is already listed in the ``MCP_APPROVED_CLIENTS`` cookie the
    browser is redirected straight to the upstream authorization URL; if it
    is listed in ``MCP_DENIED_CLIENTS`` the browser is redirected back to the
    client's redirect URI with ``error=access_denied``. Otherwise a fresh
    CSRF token is stored on the transaction and the consent HTML is returned.

    Args:
        request: The incoming request for the consent page.

    Returns:
        A 400 HTML error page when the transaction is missing, a 302
        redirect when a remembered decision applies, or the consent page.
    """
    # Imported locally, as in the original implementation.
    from fastmcp.server.server import FastMCP

    # Seconds; shared by the CSRF expiry, the store TTL and the cookie max-age.
    consent_ttl = 15 * 60

    txn_id = request.query_params.get("txn_id")
    # Skip the store lookup entirely when no transaction id was supplied.
    txn_model = (
        await self._transaction_store.get(key=txn_id) if txn_id else None
    )
    if not txn_model:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
        )

    txn = txn_model.model_dump()
    client_id = txn["client_id"]
    redirect_uri = txn["client_redirect_uri"]
    client_key = self._make_client_key(client_id, redirect_uri)

    previously_approved = set(
        self._decode_list_cookie(request, "MCP_APPROVED_CLIENTS")
    )
    previously_denied = set(
        self._decode_list_cookie(request, "MCP_DENIED_CLIENTS")
    )

    # Remembered approval: continue to the upstream provider immediately.
    if client_key in previously_approved:
        return RedirectResponse(
            url=self._build_upstream_authorize_url(txn_id, txn),
            status_code=302,
        )

    # Remembered denial: bounce back to the client with an OAuth error.
    if client_key in previously_denied:
        denial_query = urlencode(
            {
                "error": "access_denied",
                "state": txn.get("client_state") or "",
            }
        )
        joiner = "?" if "?" not in redirect_uri else "&"
        return RedirectResponse(
            url=f"{redirect_uri}{joiner}{denial_query}",
            status_code=302,
        )

    # No remembered decision: mint a CSRF token and persist it on the
    # transaction so the form submission can be validated against it.
    csrf_token = secrets.token_urlsafe(32)
    csrf_expires_at = time.time() + consent_ttl
    txn_model.csrf_token = csrf_token
    txn_model.csrf_expires_at = csrf_expires_at
    await self._transaction_store.put(
        key=txn_id, value=txn_model, ttl=consent_ttl
    )

    # Keep the dumped dict in step with the stored model.
    txn.update(csrf_token=csrf_token, csrf_expires_at=csrf_expires_at)

    # Use the registered client's display name when one is available.
    client = await self.get_client(client_id)
    client_name = getattr(client, "client_name", None) if client else None

    # Server branding is taken from the FastMCP instance on the app state,
    # when one is present.
    server_name = None
    server_icon_url = None
    server_website_url = None
    server = getattr(request.app.state, "fastmcp_server", None)
    if isinstance(server, FastMCP):
        server_name = server.name
        server_icons = server.icons
        if server_icons:
            server_icon_url = server_icons[0].src
        server_website_url = server.website_url

    response = create_secure_html_response(
        create_consent_html(
            client_id=client_id,
            redirect_uri=redirect_uri,
            scopes=txn.get("scopes") or [],
            txn_id=txn_id,
            csrf_token=csrf_token,
            client_name=client_name,
            server_name=server_name,
            server_icon_url=server_icon_url,
            server_website_url=server_website_url,
            csp_policy=self._consent_csp_policy,
        )
    )
    # Also hand the CSRF token to the browser in a short-lived cookie.
    self._set_list_cookie(
        response,
        "MCP_CONSENT_STATE",
        self._encode_list_cookie([csrf_token]),
        max_age=consent_ttl,
    )
    return response
|
|
2164
|
+
|
|
2165
|
+
async def _submit_consent(
    self, request: Request
) -> RedirectResponse | HTMLResponse:
    """Handle consent approval/denial, set cookies, and redirect appropriately.

    Reads ``txn_id``, ``action`` and ``csrf_token`` from the posted form,
    loads the matching transaction and validates the CSRF token against the
    one stored on it (value and expiry). On ``approve`` the client key is
    added to the ``MCP_APPROVED_CLIENTS`` cookie and the browser is sent to
    the upstream authorization URL; on ``deny`` the key is added to
    ``MCP_DENIED_CLIENTS`` and the browser is sent back to the client's
    redirect URI with ``error=access_denied``.

    Args:
        request: The POST request carrying the consent form.

    Returns:
        A 302 redirect for a valid approve/deny submission, otherwise a 400
        HTML error page (missing/unknown transaction, bad or expired CSRF
        token, or unrecognised action).
    """
    form = await request.form()
    txn_id = str(form.get("txn_id", ""))
    action = str(form.get("action", ""))
    csrf_token = str(form.get("csrf_token", ""))

    if not txn_id:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
        )

    txn_model = await self._transaction_store.get(key=txn_id)
    if not txn_model:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired transaction</p>", status_code=400
        )

    txn = txn_model.model_dump()
    expected_csrf = txn.get("csrf_token")
    expires_at = float(txn.get("csrf_expires_at") or 0)

    # Compare the submitted token in constant time so response timing does
    # not reveal how much of the token matched. Both sides are encoded to
    # bytes because compare_digest rejects non-ASCII str input, and the form
    # value is attacker-controlled.
    # NOTE(review): only the form token is checked against the stored
    # transaction here; the MCP_CONSENT_STATE cookie is not read back.
    csrf_matches = (
        isinstance(expected_csrf, str)
        and bool(expected_csrf)
        and secrets.compare_digest(
            csrf_token.encode("utf-8"), expected_csrf.encode("utf-8")
        )
    )
    if not csrf_matches or time.time() > expires_at:
        return create_secure_html_response(
            "<h1>Error</h1><p>Invalid or expired consent token</p>", status_code=400
        )

    client_key = self._make_client_key(txn["client_id"], txn["client_redirect_uri"])

    if action == "approve":
        # Remember the approval so later requests can skip the consent page.
        approved = set(self._decode_list_cookie(request, "MCP_APPROVED_CLIENTS"))
        approved.add(client_key)
        approved_b64 = self._encode_list_cookie(sorted(approved))

        upstream_url = self._build_upstream_authorize_url(txn_id, txn)
        response = RedirectResponse(url=upstream_url, status_code=302)
        self._set_list_cookie(
            response, "MCP_APPROVED_CLIENTS", approved_b64, max_age=365 * 24 * 3600
        )
        # Clear CSRF cookie by setting empty short-lived value
        self._set_list_cookie(
            response, "MCP_CONSENT_STATE", self._encode_list_cookie([]), max_age=60
        )
        return response

    if action == "deny":
        # Remember the denial so later requests are rejected without a prompt.
        denied = set(self._decode_list_cookie(request, "MCP_DENIED_CLIENTS"))
        denied.add(client_key)
        denied_b64 = self._encode_list_cookie(sorted(denied))

        callback_params = {
            "error": "access_denied",
            "state": txn.get("client_state") or "",
        }
        # Append to an existing query string when the redirect URI has one.
        sep = "&" if "?" in txn["client_redirect_uri"] else "?"
        client_callback_url = (
            f"{txn['client_redirect_uri']}{sep}{urlencode(callback_params)}"
        )
        response = RedirectResponse(url=client_callback_url, status_code=302)
        self._set_list_cookie(
            response, "MCP_DENIED_CLIENTS", denied_b64, max_age=365 * 24 * 3600
        )
        # Clear CSRF cookie by setting empty short-lived value
        self._set_list_cookie(
            response, "MCP_CONSENT_STATE", self._encode_list_cookie([]), max_age=60
        )
        return response

    # Any other action value is rejected.
    return create_secure_html_response(
        "<h1>Error</h1><p>Invalid action</p>", status_code=400
    )
|