fastmcp 2.13.0rc2__py3-none-any.whl → 2.13.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -31,9 +31,11 @@ from urllib.parse import urlencode, urlparse
31
31
  import httpx
32
32
  from authlib.common.security import generate_token
33
33
  from authlib.integrations.httpx_client import AsyncOAuth2Client
34
+ from cryptography.fernet import Fernet
34
35
  from key_value.aio.adapters.pydantic import PydanticAdapter
35
36
  from key_value.aio.protocols import AsyncKeyValue
36
- from key_value.aio.stores.memory import MemoryStore
37
+ from key_value.aio.stores.disk import DiskStore
38
+ from key_value.aio.wrappers.encryption import FernetEncryptionWrapper
37
39
  from mcp.server.auth.handlers.token import TokenErrorResponse, TokenSuccessResponse
38
40
  from mcp.server.auth.handlers.token import TokenHandler as _SDKTokenHandler
39
41
  from mcp.server.auth.json_response import PydanticJSONResponse
@@ -55,11 +57,14 @@ from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, SecretStr
55
57
  from starlette.requests import Request
56
58
  from starlette.responses import HTMLResponse, RedirectResponse
57
59
  from starlette.routing import Route
60
+ from typing_extensions import override
58
61
 
62
+ from fastmcp import settings
59
63
  from fastmcp.server.auth.auth import OAuthProvider, TokenVerifier
64
+ from fastmcp.server.auth.handlers.authorize import AuthorizationHandler
60
65
  from fastmcp.server.auth.jwt_issuer import (
61
66
  JWTIssuer,
62
- TokenEncryption,
67
+ derive_jwt_key,
63
68
  )
64
69
  from fastmcp.server.auth.redirect_validation import (
65
70
  validate_redirect_uri,
@@ -68,9 +73,10 @@ from fastmcp.utilities.logging import get_logger
68
73
  from fastmcp.utilities.ui import (
69
74
  BUTTON_STYLES,
70
75
  DETAIL_BOX_STYLES,
76
+ DETAILS_STYLES,
71
77
  INFO_BOX_STYLES,
78
+ REDIRECT_SECTION_STYLES,
72
79
  TOOLTIP_STYLES,
73
- create_detail_box,
74
80
  create_logo,
75
81
  create_page,
76
82
  create_secure_html_response,
@@ -142,12 +148,13 @@ class UpstreamTokenSet(BaseModel):
142
148
  """Stored upstream OAuth tokens from identity provider.
143
149
 
144
150
  These tokens are obtained from the upstream provider (Google, GitHub, etc.)
145
- and are stored encrypted at rest. They are never exposed to MCP clients.
151
+ and stored in plaintext within this model. Encryption is handled transparently
152
+ at the storage layer via FernetEncryptionWrapper. Tokens are never exposed to MCP clients.
146
153
  """
147
154
 
148
155
  upstream_token_id: str # Unique ID for this token set
149
- access_token: bytes # Encrypted upstream access token
150
- refresh_token: bytes | None # Encrypted upstream refresh token
156
+ access_token: str # Upstream access token
157
+ refresh_token: str | None # Upstream refresh token
151
158
  refresh_token_expires_at: (
152
159
  float | None
153
160
  ) # Unix timestamp when refresh token expires (if known)
@@ -233,16 +240,13 @@ def create_consent_html(
233
240
  txn_id: str,
234
241
  csrf_token: str,
235
242
  client_name: str | None = None,
236
- title: str = "Authorization Consent",
243
+ title: str = "Application Access Request",
237
244
  server_name: str | None = None,
238
245
  server_icon_url: str | None = None,
239
246
  server_website_url: str | None = None,
247
+ client_website_url: str | None = None,
240
248
  ) -> str:
241
249
  """Create a styled HTML consent page for OAuth authorization requests."""
242
- # Format scopes for display
243
- scopes_display = ", ".join(scopes) if scopes else "None"
244
-
245
- # Build warning box with client name if available
246
250
  import html as html_module
247
251
 
248
252
  client_display = html_module.escape(client_name or client_id)
@@ -251,29 +255,58 @@ def create_consent_html(
251
255
  # Make server name a hyperlink if website URL is available
252
256
  if server_website_url:
253
257
  website_url_escaped = html_module.escape(server_website_url)
254
- server_display = f'<a href="{website_url_escaped}" target="_blank" rel="noopener noreferrer">{server_name_escaped}</a>'
258
+ server_display = f'<a href="{website_url_escaped}" target="_blank" rel="noopener noreferrer" class="server-name-link">{server_name_escaped}</a>'
255
259
  else:
256
260
  server_display = server_name_escaped
257
261
 
258
- warning_box = f"""
259
- <div class="warning-box">
260
- <p><strong>{client_display}</strong> is requesting access to <strong>{server_display}</strong>.</p>
261
- <p>Review the details below before approving.</p>
262
+ # Build intro box with call-to-action
263
+ intro_box = f"""
264
+ <div class="info-box">
265
+ <p>The application <strong>{client_display}</strong> wants to access the MCP server <strong>{server_display}</strong>. Please ensure you recognize the callback address below.</p>
266
+ </div>
267
+ """
268
+
269
+ # Build redirect URI section (yellow box, centered)
270
+ redirect_uri_escaped = html_module.escape(redirect_uri)
271
+ redirect_section = f"""
272
+ <div class="redirect-section">
273
+ <span class="label">Credentials will be sent to:</span>
274
+ <div class="value">{redirect_uri_escaped}</div>
262
275
  </div>
263
276
  """
264
277
 
265
- # Build detail box with client information
266
- detail_rows = []
267
- if client_name:
268
- detail_rows.append(("Client Name", client_name))
269
- detail_rows.extend(
278
+ # Build advanced details with collapsible section
279
+ detail_rows = [
280
+ ("Application Name", html_module.escape(client_name or client_id)),
281
+ ("Application Website", html_module.escape(client_website_url or "N/A")),
282
+ ("Application ID", client_id),
283
+ ("Redirect URI", redirect_uri_escaped),
284
+ (
285
+ "Requested Scopes",
286
+ ", ".join(html_module.escape(s) for s in scopes) if scopes else "None",
287
+ ),
288
+ ]
289
+
290
+ detail_rows_html = "\n".join(
270
291
  [
271
- ("Client ID", client_id),
272
- ("Redirect URI", redirect_uri),
273
- ("Requested Scopes", scopes_display),
292
+ f"""
293
+ <div class="detail-row">
294
+ <div class="detail-label">{label}:</div>
295
+ <div class="detail-value">{value}</div>
296
+ </div>
297
+ """
298
+ for label, value in detail_rows
274
299
  ]
275
300
  )
276
- detail_box = create_detail_box(detail_rows)
301
+
302
+ advanced_details = f"""
303
+ <details>
304
+ <summary>Advanced Details</summary>
305
+ <div class="detail-box">
306
+ {detail_rows_html}
307
+ </div>
308
+ </details>
309
+ """
277
310
 
278
311
  # Build form with buttons
279
312
  form = f"""
@@ -281,13 +314,13 @@ def create_consent_html(
281
314
  <input type="hidden" name="txn_id" value="{txn_id}" />
282
315
  <input type="hidden" name="csrf_token" value="{csrf_token}" />
283
316
  <div class="button-group">
284
- <button type="submit" name="action" value="approve" class="btn-approve">Approve</button>
317
+ <button type="submit" name="action" value="approve" class="btn-approve">Allow Access</button>
285
318
  <button type="submit" name="action" value="deny" class="btn-deny">Deny</button>
286
319
  </div>
287
320
  </form>
288
321
  """
289
322
 
290
- # Build help link with tooltip
323
+ # Build help link with tooltip (identical to current implementation)
291
324
  help_link = """
292
325
  <div class="help-link-container">
293
326
  <span class="help-link">
@@ -312,9 +345,10 @@ def create_consent_html(
312
345
  content = f"""
313
346
  <div class="container">
314
347
  {create_logo(icon_url=server_icon_url, alt_text=server_name or "FastMCP")}
315
- <h1>Authorization Consent</h1>
316
- {warning_box}
317
- {detail_box}
348
+ <h1>Application Access Request</h1>
349
+ {intro_box}
350
+ {redirect_section}
351
+ {advanced_details}
318
352
  {form}
319
353
  </div>
320
354
  {help_link}
@@ -322,7 +356,12 @@ def create_consent_html(
322
356
 
323
357
  # Additional styles needed for this page
324
358
  additional_styles = (
325
- INFO_BOX_STYLES + DETAIL_BOX_STYLES + BUTTON_STYLES + TOOLTIP_STYLES
359
+ INFO_BOX_STYLES
360
+ + REDIRECT_SECTION_STYLES
361
+ + DETAILS_STYLES
362
+ + DETAIL_BOX_STYLES
363
+ + BUTTON_STYLES
364
+ + TOOLTIP_STYLES
326
365
  )
327
366
 
328
367
  # Need to allow form-action for form submission
@@ -525,10 +564,10 @@ class OAuthProxy(OAuthProvider):
525
564
  extra_token_params: dict[str, str] | None = None,
526
565
  # Client storage
527
566
  client_storage: AsyncKeyValue | None = None,
528
- # JWT signing key (optional, ephemeral if not provided)
567
+ # JWT signing key
529
568
  jwt_signing_key: str | bytes | None = None,
530
- # Token encryption key (optional, ephemeral if not provided)
531
- token_encryption_key: str | bytes | None = None,
569
+ # Consent screen configuration
570
+ require_authorization_consent: bool = True,
532
571
  ):
533
572
  """Initialize the OAuth proxy provider.
534
573
 
@@ -562,14 +601,18 @@ class OAuthProxy(OAuthProvider):
562
601
  Example: {"audience": "https://api.example.com"}
563
602
  extra_token_params: Additional parameters to forward to the upstream token endpoint.
564
603
  Useful for provider-specific parameters during token exchange.
565
- client_storage: An AsyncKeyValue-compatible store for client registrations, registrations are stored in memory if not provided
566
- jwt_signing_key: Optional secret for signing FastMCP JWT tokens (accepts any string or bytes).
567
- Default: ephemeral (random salt at startup, won't survive restart).
568
- Production: provide explicit key from environment variable.
569
- token_encryption_key: Optional secret for encrypting upstream tokens at rest (accepts any string or bytes).
570
- Default: ephemeral (random salt at startup, won't survive restart).
571
- Production: provide explicit key from environment variable.
604
+ client_storage: Storage backend for OAuth state (client registrations, tokens).
605
+ If None, an encrypted DiskStore will be created in the data directory.
606
+ jwt_signing_key: Secret for signing FastMCP JWT tokens (any string or bytes).
607
+ If bytes are provided, they will be used as-is.
608
+ If a string is provided, it will be derived into a 32-byte key using PBKDF2 (1.2M iterations).
609
+ If not provided, it will be derived from the upstream client secret using HKDF.
610
+ require_authorization_consent: Whether to require user consent before authorizing clients (default True).
611
+ When True, users see a consent screen before being redirected to the upstream IdP.
612
+ When False, authorization proceeds directly without user confirmation.
613
+ SECURITY WARNING: Only disable for local development or testing environments.
572
614
  """
615
+
573
616
  # Always enable DCR since we implement it locally for MCP clients
574
617
  client_registration_options = ClientRegistrationOptions(
575
618
  enabled=True,
@@ -591,12 +634,14 @@ class OAuthProxy(OAuthProvider):
591
634
  )
592
635
 
593
636
  # Store upstream configuration
594
- self._upstream_authorization_endpoint = upstream_authorization_endpoint
595
- self._upstream_token_endpoint = upstream_token_endpoint
596
- self._upstream_client_id = upstream_client_id
597
- self._upstream_client_secret = SecretStr(upstream_client_secret)
598
- self._upstream_revocation_endpoint = upstream_revocation_endpoint
599
- self._default_scope_str = " ".join(self.required_scopes or [])
637
+ self._upstream_authorization_endpoint: str = upstream_authorization_endpoint
638
+ self._upstream_token_endpoint: str = upstream_token_endpoint
639
+ self._upstream_client_id: str = upstream_client_id
640
+ self._upstream_client_secret: SecretStr = SecretStr(
641
+ secret_value=upstream_client_secret
642
+ )
643
+ self._upstream_revocation_endpoint: str | None = upstream_revocation_endpoint
644
+ self._default_scope_str: str = " ".join(self.required_scopes or [])
600
645
 
601
646
  # Store redirect configuration
602
647
  if not redirect_path:
@@ -612,39 +657,85 @@ class OAuthProxy(OAuthProvider):
612
657
  ):
613
658
  logger.warning(
614
659
  "allowed_client_redirect_uris is empty list; no redirect URIs will be accepted. "
615
- "This will block all OAuth clients."
660
+ + "This will block all OAuth clients."
616
661
  )
617
- self._allowed_client_redirect_uris = allowed_client_redirect_uris
662
+ self._allowed_client_redirect_uris: list[str] | None = (
663
+ allowed_client_redirect_uris
664
+ )
618
665
 
619
666
  # PKCE configuration
620
- self._forward_pkce = forward_pkce
667
+ self._forward_pkce: bool = forward_pkce
621
668
 
622
669
  # Token endpoint authentication
623
- self._token_endpoint_auth_method = token_endpoint_auth_method
670
+ self._token_endpoint_auth_method: str | None = token_endpoint_auth_method
671
+
672
+ # Consent screen configuration
673
+ self._require_authorization_consent: bool = require_authorization_consent
674
+ if not require_authorization_consent:
675
+ logger.warning(
676
+ "Authorization consent screen disabled - only use for local development or testing. "
677
+ + "In production, this screen protects against confused deputy attacks."
678
+ )
624
679
 
625
680
  # Extra parameters for authorization and token endpoints
626
- self._extra_authorize_params = extra_authorize_params or {}
627
- self._extra_token_params = extra_token_params or {}
681
+ self._extra_authorize_params: dict[str, str] = extra_authorize_params or {}
682
+ self._extra_token_params: dict[str, str] = extra_token_params or {}
628
683
 
629
- self._client_storage: AsyncKeyValue = client_storage or MemoryStore()
684
+ if jwt_signing_key is None:
685
+ jwt_signing_key = derive_jwt_key(
686
+ high_entropy_material=upstream_client_secret,
687
+ salt="fastmcp-jwt-signing-key",
688
+ )
630
689
 
631
- # Warn if using MemoryStore in production
632
- if isinstance(client_storage, MemoryStore):
633
- logger.warning(
634
- "Using in-memory storage - all OAuth state (clients, tokens) will be lost on restart. "
635
- "Additionally, without explicit jwt_signing_key and token_encryption_key, "
636
- "keys are ephemeral and tokens won't survive restart even with persistent storage. "
637
- "For production, configure persistent storage AND explicit keys."
690
+ if isinstance(jwt_signing_key, str):
691
+ if len(jwt_signing_key) < 12:
692
+ logger.warning(
693
+ "jwt_signing_key is less than 12 characters; it is recommended to use a longer. "
694
+ + "string for the key derivation."
695
+ )
696
+ jwt_signing_key = derive_jwt_key(
697
+ low_entropy_material=jwt_signing_key,
698
+ salt="fastmcp-jwt-signing-key",
638
699
  )
639
700
 
701
+ self._jwt_issuer: JWTIssuer = JWTIssuer(
702
+ issuer=str(self.base_url),
703
+ audience=f"{str(self.base_url).rstrip('/')}/mcp",
704
+ signing_key=jwt_signing_key,
705
+ )
706
+
707
+ # If the user does not provide a store, we will provide an encrypted disk store
708
+ if client_storage is None:
709
+ storage_encryption_key = derive_jwt_key(
710
+ high_entropy_material=jwt_signing_key.decode(),
711
+ salt="fastmcp-storage-encryption-key",
712
+ )
713
+ client_storage = FernetEncryptionWrapper(
714
+ key_value=DiskStore(directory=settings.home / "oauth-proxy"),
715
+ fernet=Fernet(key=storage_encryption_key),
716
+ )
717
+
718
+ self._client_storage: AsyncKeyValue = client_storage
719
+
640
720
  # Cache HTTPS check to avoid repeated logging
641
- self._is_https = str(self.base_url).startswith("https://")
721
+ self._is_https: bool = str(self.base_url).startswith("https://")
642
722
  if not self._is_https:
643
723
  logger.warning(
644
724
  "Using non-secure cookies for development; deploy with HTTPS for production."
645
725
  )
646
726
 
647
- self._client_store = PydanticAdapter[ProxyDCRClient](
727
+ self._upstream_token_store: PydanticAdapter[UpstreamTokenSet] = PydanticAdapter[
728
+ UpstreamTokenSet
729
+ ](
730
+ key_value=self._client_storage,
731
+ pydantic_model=UpstreamTokenSet,
732
+ default_collection="mcp-upstream-tokens",
733
+ raise_on_validation_error=True,
734
+ )
735
+
736
+ self._client_store: PydanticAdapter[ProxyDCRClient] = PydanticAdapter[
737
+ ProxyDCRClient
738
+ ](
648
739
  key_value=self._client_storage,
649
740
  pydantic_model=ProxyDCRClient,
650
741
  default_collection="mcp-oauth-proxy-clients",
@@ -653,43 +744,32 @@ class OAuthProxy(OAuthProvider):
653
744
 
654
745
  # OAuth transaction storage for IdP callback forwarding
655
746
  # Reuse client_storage with different collections for state management
656
- self._transaction_store = PydanticAdapter[OAuthTransaction](
747
+ self._transaction_store: PydanticAdapter[OAuthTransaction] = PydanticAdapter[
748
+ OAuthTransaction
749
+ ](
657
750
  key_value=self._client_storage,
658
751
  pydantic_model=OAuthTransaction,
659
752
  default_collection="mcp-oauth-transactions",
660
753
  raise_on_validation_error=True,
661
754
  )
662
755
 
663
- self._code_store = PydanticAdapter[ClientCode](
756
+ self._code_store: PydanticAdapter[ClientCode] = PydanticAdapter[ClientCode](
664
757
  key_value=self._client_storage,
665
758
  pydantic_model=ClientCode,
666
759
  default_collection="mcp-authorization-codes",
667
760
  raise_on_validation_error=True,
668
761
  )
669
762
 
670
- # Storage for upstream tokens (encrypted at rest)
671
- self._upstream_token_store = PydanticAdapter[UpstreamTokenSet](
672
- key_value=self._client_storage,
673
- pydantic_model=UpstreamTokenSet,
674
- default_collection="mcp-upstream-tokens",
675
- raise_on_validation_error=True,
676
- )
677
-
678
763
  # Storage for JTI mappings (FastMCP token -> upstream token)
679
- self._jti_mapping_store = PydanticAdapter[JTIMapping](
764
+ self._jti_mapping_store: PydanticAdapter[JTIMapping] = PydanticAdapter[
765
+ JTIMapping
766
+ ](
680
767
  key_value=self._client_storage,
681
768
  pydantic_model=JTIMapping,
682
769
  default_collection="mcp-jti-mappings",
683
770
  raise_on_validation_error=True,
684
771
  )
685
772
 
686
- # JWT issuer and encryption (initialized lazily on first use)
687
- self._custom_jwt_key = jwt_signing_key
688
- self._custom_encryption_key = token_encryption_key
689
- self._jwt_issuer: JWTIssuer | None = None
690
- self._token_encryption: TokenEncryption | None = None
691
- self._jwt_initialized = False
692
-
693
773
  # Local state for token bookkeeping only (no client caching)
694
774
  self._access_tokens: dict[str, AccessToken] = {}
695
775
  self._refresh_tokens: dict[str, RefreshToken] = {}
@@ -699,7 +779,7 @@ class OAuthProxy(OAuthProvider):
699
779
  self._refresh_to_access: dict[str, str] = {}
700
780
 
701
781
  # Use the provided token validator
702
- self._token_validator = token_verifier
782
+ self._token_validator: TokenVerifier = token_verifier
703
783
 
704
784
  logger.debug(
705
785
  "Initialized OAuth proxy provider with upstream server %s",
@@ -725,91 +805,11 @@ class OAuthProxy(OAuthProvider):
725
805
 
726
806
  return code_verifier, code_challenge
727
807
 
728
- # -------------------------------------------------------------------------
729
- # JWT Token Factory Initialization
730
- # -------------------------------------------------------------------------
731
-
732
- async def _ensure_jwt_initialized(self) -> None:
733
- """Initialize JWT issuer and token encryption (lazy initialization).
734
-
735
- Key derivation strategy:
736
- - Default: Generate random salt at startup, derive ephemeral keys
737
- → Keys change on restart, all tokens become invalid
738
- → Perfect for development/testing where re-auth is acceptable
739
-
740
- - Production: User provides explicit keys via parameters
741
- → Keys stable across restarts when combined with persistent storage
742
- → Tokens survive restart, seamless client reconnection
743
- """
744
- if self._jwt_initialized:
745
- return
746
-
747
- # Generate random salt for this server instance (NOT persisted)
748
- server_salt = secrets.token_urlsafe(32)
749
-
750
- # Derive or use custom JWT signing key
751
- from fastmcp.server.auth.jwt_issuer import derive_key_from_secret
752
-
753
- if self._custom_jwt_key:
754
- jwt_key = derive_key_from_secret(
755
- secret=self._custom_jwt_key,
756
- salt="fastmcp-jwt-signing-v1",
757
- info=b"HS256",
758
- )
759
- logger.info("Using explicit JWT signing key (will survive restarts)")
760
- else:
761
- # Ephemeral key from random salt + upstream secret
762
- upstream_secret = self._upstream_client_secret.get_secret_value()
763
- jwt_key = derive_key_from_secret(
764
- secret=upstream_secret,
765
- salt=f"fastmcp-jwt-signing-v1-{server_salt}",
766
- info=b"HS256",
767
- )
768
- logger.info(
769
- "Using ephemeral JWT signing key - tokens will NOT survive server restart. "
770
- "For production, provide explicit jwt_signing_key parameter and use persistent storage."
771
- )
772
-
773
- # Initialize JWT issuer
774
- issuer = str(self.base_url)
775
- audience = f"{str(self.base_url).rstrip('/')}/mcp"
776
- self._jwt_issuer = JWTIssuer(
777
- issuer=issuer,
778
- audience=audience,
779
- signing_key=jwt_key,
780
- )
781
-
782
- # Derive or use custom encryption key
783
- if self._custom_encryption_key:
784
- encryption_key = derive_key_from_secret(
785
- secret=self._custom_encryption_key,
786
- salt="fastmcp-token-encryption-v1",
787
- info=b"Fernet",
788
- )
789
- # Fernet needs base64url-encoded key
790
- encryption_key = base64.urlsafe_b64encode(encryption_key)
791
- logger.info("Using explicit token encryption key (will survive restarts)")
792
- else:
793
- # Ephemeral key from random salt + upstream secret
794
- upstream_secret = self._upstream_client_secret.get_secret_value()
795
- key_material = derive_key_from_secret(
796
- secret=upstream_secret,
797
- salt=f"fastmcp-token-encryption-v1-{server_salt}",
798
- info=b"Fernet",
799
- )
800
- encryption_key = base64.urlsafe_b64encode(key_material)
801
- logger.info(
802
- "Using ephemeral token encryption key - encrypted tokens will NOT survive server restart. "
803
- "For production, provide explicit token_encryption_key parameter and use persistent storage."
804
- )
805
-
806
- self._token_encryption = TokenEncryption(encryption_key)
807
- self._jwt_initialized = True
808
-
809
808
  # -------------------------------------------------------------------------
810
809
  # Client Registration (Local Implementation)
811
810
  # -------------------------------------------------------------------------
812
811
 
812
+ @override
813
813
  async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
814
814
  """Get client information by ID. This is generally the random ID
815
815
  provided to the DCR client during registration, not the upstream client ID.
@@ -825,6 +825,7 @@ class OAuthProxy(OAuthProvider):
825
825
 
826
826
  return client
827
827
 
828
+ @override
828
829
  async def register_client(self, client_info: OAuthClientInformationFull) -> None:
829
830
  """Register a client locally
830
831
 
@@ -871,6 +872,7 @@ class OAuthProxy(OAuthProvider):
871
872
  # Authorization Flow (Proxy to Upstream)
872
873
  # -------------------------------------------------------------------------
873
874
 
875
+ @override
874
876
  async def authorize(
875
877
  self,
876
878
  client: OAuthClientInformationFull,
@@ -882,6 +884,9 @@ class OAuthProxy(OAuthProvider):
882
884
  1. Store transaction with client details and PKCE (if forwarding)
883
885
  2. Return local /consent URL; browser visits consent first
884
886
  3. Consent handler redirects to upstream IdP if approved/already approved
887
+
888
+ If consent is disabled (require_authorization_consent=False), skip the consent screen
889
+ and redirect directly to the upstream IdP.
885
890
  """
886
891
  # Generate transaction ID for this authorization request
887
892
  txn_id = secrets.token_urlsafe(32)
@@ -897,23 +902,37 @@ class OAuthProxy(OAuthProvider):
897
902
  )
898
903
 
899
904
  # Store transaction data for IdP callback processing
905
+ transaction = OAuthTransaction(
906
+ txn_id=txn_id,
907
+ client_id=client.client_id,
908
+ client_redirect_uri=str(params.redirect_uri),
909
+ client_state=params.state or "",
910
+ code_challenge=params.code_challenge,
911
+ code_challenge_method=getattr(params, "code_challenge_method", "S256"),
912
+ scopes=params.scopes or [],
913
+ created_at=time.time(),
914
+ resource=getattr(params, "resource", None),
915
+ proxy_code_verifier=proxy_code_verifier,
916
+ )
900
917
  await self._transaction_store.put(
901
918
  key=txn_id,
902
- value=OAuthTransaction(
903
- txn_id=txn_id,
904
- client_id=client.client_id,
905
- client_redirect_uri=str(params.redirect_uri),
906
- client_state=params.state or "",
907
- code_challenge=params.code_challenge,
908
- code_challenge_method=getattr(params, "code_challenge_method", "S256"),
909
- scopes=params.scopes or [],
910
- created_at=time.time(),
911
- resource=getattr(params, "resource", None),
912
- proxy_code_verifier=proxy_code_verifier,
913
- ),
919
+ value=transaction,
914
920
  ttl=15 * 60, # Auto-expire after 15 minutes
915
921
  )
916
922
 
923
+ # If consent is disabled, skip consent screen and go directly to upstream IdP
924
+ if not self._require_authorization_consent:
925
+ upstream_url = self._build_upstream_authorize_url(
926
+ txn_id, transaction.model_dump()
927
+ )
928
+ logger.debug(
929
+ "Starting OAuth transaction %s for client %s, redirecting directly to upstream IdP (consent disabled, PKCE forwarding: %s)",
930
+ txn_id,
931
+ client.client_id,
932
+ "enabled" if proxy_code_challenge else "disabled",
933
+ )
934
+ return upstream_url
935
+
917
936
  consent_url = f"{str(self.base_url).rstrip('/')}/consent?txn_id={txn_id}"
918
937
 
919
938
  logger.debug(
@@ -928,6 +947,7 @@ class OAuthProxy(OAuthProvider):
928
947
  # Authorization Code Handling
929
948
  # -------------------------------------------------------------------------
930
949
 
950
+ @override
931
951
  async def load_authorization_code(
932
952
  self,
933
953
  client: OAuthClientInformationFull,
@@ -947,7 +967,7 @@ class OAuthProxy(OAuthProvider):
947
967
  # Check if code expired
948
968
  if time.time() > code_model.expires_at:
949
969
  logger.debug("Authorization code expired: %s", authorization_code)
950
- await self._code_store.delete(key=authorization_code)
970
+ _ = await self._code_store.delete(key=authorization_code)
951
971
  return None
952
972
 
953
973
  # Verify client ID matches
@@ -963,13 +983,14 @@ class OAuthProxy(OAuthProvider):
963
983
  return AuthorizationCode(
964
984
  code=authorization_code,
965
985
  client_id=client.client_id,
966
- redirect_uri=code_model.redirect_uri,
986
+ redirect_uri=AnyUrl(url=code_model.redirect_uri),
967
987
  redirect_uri_provided_explicitly=True,
968
988
  scopes=code_model.scopes,
969
989
  expires_at=code_model.expires_at,
970
990
  code_challenge=code_model.code_challenge or "",
971
991
  )
972
992
 
993
+ @override
973
994
  async def exchange_authorization_code(
974
995
  self,
975
996
  client: OAuthClientInformationFull,
@@ -986,11 +1007,6 @@ class OAuthProxy(OAuthProvider):
986
1007
 
987
1008
  PKCE validation is handled by the MCP framework before this method is called.
988
1009
  """
989
- # Ensure JWT issuer is initialized
990
- await self._ensure_jwt_initialized()
991
- assert self._jwt_issuer is not None
992
- assert self._token_encryption is not None
993
-
994
1010
  # Look up stored code data
995
1011
  code_model = await self._code_store.get(key=authorization_code.code)
996
1012
  if not code_model:
@@ -1041,8 +1057,8 @@ class OAuthProxy(OAuthProvider):
1041
1057
  # Encrypt and store upstream tokens
1042
1058
  upstream_token_set = UpstreamTokenSet(
1043
1059
  upstream_token_id=upstream_token_id,
1044
- access_token=self._token_encryption.encrypt(idp_tokens["access_token"]),
1045
- refresh_token=self._token_encryption.encrypt(idp_tokens["refresh_token"])
1060
+ access_token=idp_tokens["access_token"],
1061
+ refresh_token=idp_tokens["refresh_token"]
1046
1062
  if idp_tokens.get("refresh_token")
1047
1063
  else None,
1048
1064
  refresh_token_expires_at=refresh_token_expires_at,
@@ -1164,11 +1180,6 @@ class OAuthProxy(OAuthProvider):
1164
1180
  5. Issue new FastMCP access token
1165
1181
  6. Keep same FastMCP refresh token (unless upstream rotates)
1166
1182
  """
1167
- # Ensure JWT issuer is initialized
1168
- await self._ensure_jwt_initialized()
1169
- assert self._jwt_issuer is not None
1170
- assert self._token_encryption is not None
1171
-
1172
1183
  # Verify FastMCP refresh token
1173
1184
  try:
1174
1185
  refresh_payload = self._jwt_issuer.verify_token(refresh_token.token)
@@ -1197,10 +1208,6 @@ class OAuthProxy(OAuthProvider):
1197
1208
  logger.error("No upstream refresh token available")
1198
1209
  raise TokenError("invalid_grant", "Refresh not supported for this token")
1199
1210
 
1200
- upstream_refresh_token = self._token_encryption.decrypt(
1201
- upstream_token_set.refresh_token
1202
- )
1203
-
1204
1211
  # Refresh upstream token using authlib
1205
1212
  oauth_client = AsyncOAuth2Client(
1206
1213
  client_id=self._upstream_client_id,
@@ -1213,7 +1220,7 @@ class OAuthProxy(OAuthProvider):
1213
1220
  logger.debug("Refreshing upstream token (jti=%s)", refresh_jti[:8])
1214
1221
  token_response: dict[str, Any] = await oauth_client.refresh_token( # type: ignore[misc]
1215
1222
  url=self._upstream_token_endpoint,
1216
- refresh_token=upstream_refresh_token,
1223
+ refresh_token=upstream_token_set.refresh_token,
1217
1224
  scope=" ".join(scopes) if scopes else None,
1218
1225
  )
1219
1226
  logger.debug("Successfully refreshed upstream token")
@@ -1225,18 +1232,14 @@ class OAuthProxy(OAuthProvider):
1225
1232
  new_expires_in = int(
1226
1233
  token_response.get("expires_in", DEFAULT_ACCESS_TOKEN_EXPIRY_SECONDS)
1227
1234
  )
1228
- upstream_token_set.access_token = self._token_encryption.encrypt(
1229
- token_response["access_token"]
1230
- )
1235
+ upstream_token_set.access_token = token_response["access_token"]
1231
1236
  upstream_token_set.expires_at = time.time() + new_expires_in
1232
1237
 
1233
1238
  # Handle upstream refresh token rotation and expiry
1234
1239
  new_refresh_expires_in = None
1235
1240
  if new_upstream_refresh := token_response.get("refresh_token"):
1236
- if new_upstream_refresh != upstream_refresh_token:
1237
- upstream_token_set.refresh_token = self._token_encryption.encrypt(
1238
- new_upstream_refresh
1239
- )
1241
+ if new_upstream_refresh != upstream_token_set.refresh_token:
1242
+ upstream_token_set.refresh_token = new_upstream_refresh
1240
1243
  logger.debug("Upstream refresh token rotated")
1241
1244
 
1242
1245
  # Update refresh token expiry if provided
@@ -1375,11 +1378,6 @@ class OAuthProxy(OAuthProvider):
1375
1378
  The FastMCP JWT is a reference token - all authorization data comes
1376
1379
  from validating the upstream token via the TokenVerifier.
1377
1380
  """
1378
- # Ensure JWT issuer and encryption are initialized
1379
- await self._ensure_jwt_initialized()
1380
- assert self._jwt_issuer is not None
1381
- assert self._token_encryption is not None
1382
-
1383
1381
  try:
1384
1382
  # 1. Verify FastMCP JWT signature and claims
1385
1383
  payload = self._jwt_issuer.verify_token(token)
@@ -1400,15 +1398,12 @@ class OAuthProxy(OAuthProvider):
1400
1398
  )
1401
1399
  return None
1402
1400
 
1403
- # 3. Decrypt upstream token
1404
- upstream_token = self._token_encryption.decrypt(
1401
+ # 3. Validate with upstream provider (delegated to TokenVerifier)
1402
+ # This calls the real token validator (GitHub API, JWKS, etc.)
1403
+ validated = await self._token_validator.verify_token(
1405
1404
  upstream_token_set.access_token
1406
1405
  )
1407
1406
 
1408
- # 4. Validate with upstream provider (delegated to TokenVerifier)
1409
- # This calls the real token validator (GitHub API, JWKS, etc.)
1410
- validated = await self._token_validator.verify_token(upstream_token)
1411
-
1412
1407
  if not validated:
1413
1408
  logger.debug("Upstream token validation failed")
1414
1409
  return None
@@ -1474,10 +1469,11 @@ class OAuthProxy(OAuthProvider):
1474
1469
  self,
1475
1470
  mcp_path: str | None = None,
1476
1471
  ) -> list[Route]:
1477
- """Get OAuth routes with custom proxy token handler.
1472
+ """Get OAuth routes with custom handlers for better error UX.
1478
1473
 
1479
- This method creates standard OAuth routes and replaces the token endpoint
1480
- with our proxy handler that forwards requests to the upstream OAuth server.
1474
+ This method creates standard OAuth routes and replaces:
1475
+ - /authorize endpoint: Enhanced error responses for unregistered clients
1476
+ - /token endpoint: OAuth 2.1 compliant error codes
1481
1477
 
1482
1478
  Args:
1483
1479
  mcp_path: The path where the MCP endpoint is mounted (e.g., "/mcp")
@@ -1487,6 +1483,7 @@ class OAuthProxy(OAuthProvider):
1487
1483
  routes = super().get_routes(mcp_path)
1488
1484
  custom_routes = []
1489
1485
  token_route_found = False
1486
+ authorize_route_found = False
1490
1487
 
1491
1488
  logger.debug(
1492
1489
  f"get_routes called - configuring OAuth routes in {len(routes)} routes"
@@ -1497,8 +1494,30 @@ class OAuthProxy(OAuthProvider):
1497
1494
  f"Route {i}: {route} - path: {getattr(route, 'path', 'N/A')}, methods: {getattr(route, 'methods', 'N/A')}"
1498
1495
  )
1499
1496
 
1500
- # Replace the token endpoint with our custom handler that returns proper OAuth 2.1 error codes
1497
+ # Replace the authorize endpoint with our enhanced handler for better error UX
1501
1498
  if (
1499
+ isinstance(route, Route)
1500
+ and route.path == "/authorize"
1501
+ and route.methods is not None
1502
+ and ("GET" in route.methods or "POST" in route.methods)
1503
+ ):
1504
+ authorize_route_found = True
1505
+ # Replace with our enhanced authorization handler
1506
+ authorize_handler = AuthorizationHandler(
1507
+ provider=self,
1508
+ base_url=self.base_url,
1509
+ server_name=None, # Could be extended to pass server metadata
1510
+ server_icon_url=None,
1511
+ )
1512
+ custom_routes.append(
1513
+ Route(
1514
+ path="/authorize",
1515
+ endpoint=authorize_handler.handle,
1516
+ methods=["GET", "POST"],
1517
+ )
1518
+ )
1519
+ # Replace the token endpoint with our custom handler that returns proper OAuth 2.1 error codes
1520
+ elif (
1502
1521
  isinstance(route, Route)
1503
1522
  and route.path == "/token"
1504
1523
  and route.methods is not None
@@ -1542,7 +1561,7 @@ class OAuthProxy(OAuthProvider):
1542
1561
  )
1543
1562
 
1544
1563
  logger.debug(
1545
- f"✅ OAuth routes configured: token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback + consent)"
1564
+ f"✅ OAuth routes configured: authorize_endpoint={authorize_route_found}, token_endpoint={token_route_found}, total routes={len(custom_routes)} (includes OAuth callback + consent)"
1546
1565
  )
1547
1566
  return custom_routes
1548
1567